- Add roughMaskThumbnailScale state with $watch to sync with the main scale slider
- Update the sidebar thumbnail to use transform: scale() for consistent zoom between views
- Modify openRoughMaskInNewWindow() to create an HTML page with a matching scale
- Add a denoise strength slider (10-100%), visible only when a rough mask exists
- Backend already supports the denoise_strength parameter in prepare_mask_workflow_with_start()
- Rough mask auto-clears after successful extraction
- Add Playwright tests for UI changes and API parameter acceptance
331 lines
12 KiB
Python
331 lines
12 KiB
Python
"""ComfyUI integration service for mask extraction."""
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import random
|
|
import threading
|
|
import time
|
|
import urllib.error
|
|
import urllib.parse
|
|
import urllib.request
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import logging
|
|
from PIL import Image, ImageDraw
|
|
|
|
# Module-level logger named after this module; ComfyUIService uses it to
# report polling errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchStorage:
|
|
"""Thread-safe storage for batch mask extraction results."""
|
|
|
|
def __init__(self):
|
|
self._batches: dict[str, dict[str, Any]] = {}
|
|
self._lock = threading.Lock()
|
|
|
|
def create_batch(self, batch_id: str, count: int) -> None:
|
|
"""Initialize a batch with expected count."""
|
|
with self._lock:
|
|
self._batches[batch_id] = {
|
|
'expected': count,
|
|
'masks': {},
|
|
'errors': []
|
|
}
|
|
|
|
def add_mask(self, batch_id: str, index: int, path: Path) -> None:
|
|
"""Add a completed mask to the batch."""
|
|
with self._lock:
|
|
if batch_id in self._batches:
|
|
self._batches[batch_id]['masks'][index] = str(path)
|
|
|
|
def add_error(self, batch_id: str, error: str) -> None:
|
|
"""Add an error to the batch."""
|
|
with self._lock:
|
|
if batch_id in self._batches:
|
|
self._batches[batch_id]['errors'].append(error)
|
|
|
|
def get_batch(self, batch_id: str) -> dict | None:
|
|
"""Get batch status."""
|
|
with self._lock:
|
|
return self._batches.get(batch_id)
|
|
|
|
def is_complete(self, batch_id: str) -> bool:
|
|
"""Check if batch has all expected masks."""
|
|
with self._lock:
|
|
batch = self._batches.get(batch_id)
|
|
if not batch:
|
|
return False
|
|
return len(batch['masks']) >= batch['expected']
|
|
|
|
def clear_batch(self, batch_id: str) -> None:
|
|
"""Remove a batch from storage."""
|
|
with self._lock:
|
|
self._batches.pop(batch_id, None)
|
|
|
|
|
|
# Shared module-level BatchStorage instance;
# ComfyUIService.poll_for_batch_completion reads batch results from it.
# NOTE(review): presumably filled via add_mask/add_error by a webhook handler
# outside this module - verify against the caller.
batch_storage = BatchStorage()
|
|
|
|
|
|
class ComfyUIService:
    """Client for the ComfyUI HTTP API and builder for mask-extraction workflows.

    The ``prepare_*`` methods deep-copy a workflow template (a dict mapping
    node id -> node, where each node carries an ``inputs`` dict). They then
    patch inputs on node ids that are hard-wired to this project's templates.
    A node id that is missing from the template, or has no ``inputs`` dict, is
    skipped rather than raising.
    """

    def __init__(self, base_url: str):
        # Host (and optional port) of the ComfyUI server, without a scheme:
        # request URLs are built as f'http://{base_url}/...'.
        self.base_url = base_url

    def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str:
        """Queue ``workflow`` on ComfyUI via ``POST /prompt`` and return its prompt_id.

        Args:
            workflow: API-format workflow to queue.
            comfy_url: Server to use instead of ``self.base_url``.

        Raises:
            urllib.error.HTTPError: ComfyUI rejected the request. The response
                body is logged, then the original error is re-raised.
            RuntimeError: The response did not contain a ``prompt_id``.
        """
        url = comfy_url or self.base_url
        req = urllib.request.Request(
            f'http://{url}/prompt',
            data=json.dumps({"prompt": workflow}).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST',
        )

        try:
            with urllib.request.urlopen(req, timeout=30) as response:
                result = json.loads(response.read().decode())
        except urllib.error.HTTPError as e:
            # The body of a rejected prompt explains why (e.g. node validation
            # errors); without logging it callers only see the status code.
            try:
                detail = e.read().decode('utf-8', errors='replace')
            except OSError:
                detail = '<unreadable response body>'
            logger.error("ComfyUI rejected prompt (HTTP %s): %s", e.code, detail)
            raise

        prompt_id = result.get('prompt_id')
        if not prompt_id:
            raise RuntimeError("No prompt_id returned from ComfyUI")
        return prompt_id

    def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool:
        """Poll ``GET /history/<prompt_id>`` until the prompt finishes.

        Args:
            prompt_id: Id returned by :meth:`submit_workflow`.
            comfy_url: Server to use instead of ``self.base_url``.
            timeout: Overall time budget in seconds. A single request may
                still take up to its own 30 s timeout, so the total can
                overshoot slightly.

        Returns:
            True once the history entry reports ``status_str == 'success'``.
            False if it reports ``'error'`` or if ``timeout`` elapses first.

        Raises:
            urllib.error.HTTPError: For any HTTP error status other than 404.
                A 404 and any other failure (network error, bad JSON) are
                retried every 2 seconds.
        """
        url = comfy_url or self.base_url
        headers = {'Content-Type': 'application/json'}
        # Monotonic clock: wall-clock adjustments must not stretch or cut the timeout.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                req = urllib.request.Request(
                    f'http://{url}/history/{prompt_id}',
                    headers=headers,
                    method='GET',
                )
                with urllib.request.urlopen(req, timeout=30) as response:
                    history = json.loads(response.read().decode())

                if prompt_id in history:
                    status = history[prompt_id].get('status', {})
                    status_str = status.get('status_str')
                    if status_str == 'success':
                        return True
                    if status_str == 'error':
                        # A failed run will never turn into a success; stop
                        # instead of waiting out the full timeout.
                        logger.error("ComfyUI reported an error for prompt %s: %s", prompt_id, status)
                        return False
            except urllib.error.HTTPError as e:
                if e.code != 404:
                    raise
            except Exception as e:
                # Best-effort polling: log and retry on transient failures.
                logger.error("Error polling history: %s", e)

            time.sleep(2)

        return False

    def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]:
        """Wait until every mask of ``batch_id`` has arrived in ``batch_storage``.

        Args:
            batch_id: Batch registered with ``batch_storage.create_batch``.
            timeout: Time budget in seconds; checks happen once per second.

        Returns:
            The mask paths ordered by mask index, or ``[]`` on timeout (or if
            the batch was cleared meanwhile).
        """
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            # Check completeness before fetching, so the fetched state already
            # contains every expected mask (masks are only ever added).
            if batch_storage.is_complete(batch_id):
                batch = batch_storage.get_batch(batch_id)
                if batch:
                    masks = batch['masks']
                    return [masks[i] for i in sorted(masks)]
            time.sleep(1)

        return []

    def prepare_mask_workflow(
        self,
        base_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None = None,
        mask_index: int = 0,
        polygon_points: list | None = None,
        polygon_color: str = '#FF0000',
        polygon_width: int = 2,
        workflow_template: dict | None = None
    ) -> dict:
        """Build the mask-extraction workflow from ``workflow_template``.

        Args:
            base_image: Image to extract the mask from. It is not modified;
                the polygon, if any, is drawn on a copy.
            subject: Inserted into the text prompt on node "1:68".
            webhook_url: Set on node "96".
            seed: Set on node "50".
            batch_id: Combined with ``mask_index`` into node "96"'s ``external_uid``.
            mask_index: Index of this mask within the batch.
            polygon_points: Optional outline with ``x``/``y`` scaled by image
                width/height. Drawn only when it has at least 3 points.
            polygon_color: Outline colour, any string Pillow's ImageColor accepts.
            polygon_width: Outline width in pixels.
            workflow_template: Template to deep-copy; ``None``/empty gives ``{}``.

        Returns:
            A new workflow dict; the template itself is left untouched.
        """
        workflow = self._copy_template(workflow_template)
        self._apply_mask_inputs(
            workflow,
            base_image=base_image,
            subject=subject,
            webhook_url=webhook_url,
            seed=seed,
            batch_id=batch_id,
            mask_index=mask_index,
            polygon_points=polygon_points,
            polygon_color=polygon_color,
            polygon_width=polygon_width,
        )
        return workflow

    def prepare_sam_workflow(
        self,
        base_image: Image.Image,
        include_points: list,
        exclude_points: list,
        webhook_url: str,
        batch_id: str | None = None,
        workflow_template: dict | None = None
    ) -> dict:
        """Build the SAM3 rough-mask workflow with user-provided points.

        Each include point (``is_foreground=True``) and then each exclude
        point (``is_foreground=False``) becomes a ``SAM3CreatePoint`` node
        with ids "100", "101", ... All of them are fed into a
        ``SAM3CombinePoints`` node "8", which is wired to node "1"'s
        ``positive_points``.

        Args:
            base_image: Image set on node "9" (PNG, base64).
            include_points: Dicts with ``x`` and ``y`` keys for foreground points.
            exclude_points: Dicts with ``x`` and ``y`` keys for background points.
            webhook_url: Set on node "11".
            batch_id: When truthy, node "11" gets ``external_uid`` ``"<batch_id>:0"``.
            workflow_template: Template to deep-copy; ``None``/empty gives ``{}``.

        Returns:
            A new workflow dict; the template itself is left untouched.
        """
        workflow = self._copy_template(workflow_template)
        self._set_input(workflow, "9", "image", self._encode_png_base64(base_image))

        labelled_points = (
            [(pt, True) for pt in include_points]
            + [(pt, False) for pt in exclude_points]
        )

        point_node_ids = []
        for i, (pt, is_foreground) in enumerate(labelled_points):
            node_id = str(100 + i)
            workflow[node_id] = {
                "inputs": {
                    "x": pt['x'],
                    "y": pt['y'],
                    "is_foreground": is_foreground
                },
                "class_type": "SAM3CreatePoint",
                "_meta": {"title": f"SAM3 Point {i+1}"}
            }
            point_node_ids.append(node_id)

        if point_node_ids:
            workflow["8"] = {
                "inputs": {
                    f"point_{i+1}": [node_id, 0]
                    for i, node_id in enumerate(point_node_ids)
                },
                "class_type": "SAM3CombinePoints",
                "_meta": {"title": "SAM3 Combine Points"}
            }

        # Link node "1" to node "8" only when "8" actually exists (created
        # above or shipped with the template). Otherwise the link would point
        # at a missing node.
        if "8" in workflow:
            self._set_input(workflow, "1", "positive_points", ["8", 0])

        self._set_input(workflow, "11", "webhook_url", webhook_url)
        if batch_id:
            self._set_input(workflow, "11", "external_uid", f"{batch_id}:0")

        return workflow

    def prepare_mask_workflow_with_start(
        self,
        base_image: Image.Image,
        start_mask_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None = None,
        mask_index: int = 0,
        polygon_points: list | None = None,
        polygon_color: str = '#FF0000',
        polygon_width: int = 2,
        denoise_strength: float = 0.8,
        workflow_template: dict | None = None
    ) -> dict:
        """Build the mask-extraction workflow seeded with a starting mask.

        This applies the same inputs as :meth:`prepare_mask_workflow`. In
        addition, it sets ``start_mask_image`` (PNG, base64) on node "200" and
        ``denoise_strength`` on node "1:65"'s ``denoise`` input, unchanged
        (no range check is done here).

        Returns:
            A new workflow dict; the template itself is left untouched.
        """
        workflow = self._copy_template(workflow_template)
        self._apply_mask_inputs(
            workflow,
            base_image=base_image,
            subject=subject,
            webhook_url=webhook_url,
            seed=seed,
            batch_id=batch_id,
            mask_index=mask_index,
            polygon_points=polygon_points,
            polygon_color=polygon_color,
            polygon_width=polygon_width,
        )
        self._set_input(workflow, "200", "image", self._encode_png_base64(start_mask_image))
        self._set_input(workflow, "1:65", "denoise", denoise_strength)
        return workflow

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _copy_template(workflow_template: dict | None) -> dict:
        """Deep-copy a template via a JSON round-trip; ``None``/empty gives ``{}``."""
        return json.loads(json.dumps(workflow_template)) if workflow_template else {}

    @staticmethod
    def _set_input(workflow: dict, node_id: str, name: str, value: Any) -> None:
        """Set ``workflow[node_id]['inputs'][name] = value`` if that node has an inputs dict."""
        node = workflow.get(node_id)
        if node is not None and 'inputs' in node:
            node['inputs'][name] = value

    @staticmethod
    def _encode_png_base64(image: Image.Image) -> str:
        """Encode ``image`` as PNG and return the bytes as a base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    @staticmethod
    def _draw_polygon_outline(
        image: Image.Image,
        polygon_points: list | None,
        polygon_color: str,
        polygon_width: int,
    ) -> Image.Image:
        """Return a copy of ``image`` with the polygon outline drawn on it.

        The outline is drawn only when ``polygon_points`` has at least 3
        points. Each point's ``x``/``y`` is multiplied by the image
        width/height and truncated to an int pixel coordinate.
        """
        img = image.copy()
        if polygon_points and len(polygon_points) >= 3:
            w, h = img.size
            pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points]
            # The colour string goes to Pillow as-is; Pillow parses #rgb,
            # #rgba, #rrggbb, #rrggbbaa and colour names itself.
            ImageDraw.Draw(img).polygon(pixel_points, outline=polygon_color, width=polygon_width)
        return img

    def _apply_mask_inputs(
        self,
        workflow: dict,
        *,
        base_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None,
        mask_index: int,
        polygon_points: list | None,
        polygon_color: str,
        polygon_width: int,
    ) -> None:
        """Patch, in place, the inputs shared by both mask-extraction workflows.

        Sets:
          * node "87": the (optionally polygon-annotated) base image,
          * node "1:68": the text prompt,
          * node "50": the seed,
          * node "96": the webhook URL and ``external_uid``. The uid is
            ``"<batch_id>:<mask_index>"``, or just ``"<mask_index>"`` when
            there is no batch id.
        """
        annotated = self._draw_polygon_outline(base_image, polygon_points, polygon_color, polygon_width)
        self._set_input(workflow, "87", "image", self._encode_png_base64(annotated))
        self._set_input(
            workflow,
            "1:68",
            "prompt",
            f"Create a black and white alpha mask of {subject}, leaving everything else black",
        )
        self._set_input(workflow, "96", "webhook_url", webhook_url)
        self._set_input(workflow, "50", "seed", seed)
        external_uid = f"{batch_id}:{mask_index}" if batch_id else str(mask_index)
        self._set_input(workflow, "96", "external_uid", external_uid)
|