Files
ai-game-2/tools/ora_editor/services/comfyui.py
Bryce 7c9a25dd91 Restructure SAM rough mask workflow for sidebar preview
- Add roughMaskThumbnailScale state with $watch to sync with main scale slider
- Update sidebar thumbnail to use transform:scale() for consistent zoom between views
- Modify openRoughMaskInNewWindow() to create HTML page with matching scale
- Add denoise strength slider (10-100%) visible only when rough mask exists
- Backend already supports denoise_strength parameter in prepare_mask_workflow_with_start()
- Rough mask auto-clears after successful extraction
- Add Playwright tests for UI changes and API parameter acceptance
2026-03-28 10:42:27 -07:00

331 lines
12 KiB
Python

"""ComfyUI integration service for mask extraction."""
import base64
import io
import json
import random
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Any
import logging
from PIL import Image, ImageDraw
logger = logging.getLogger(__name__)
class BatchStorage:
    """Thread-safe storage for batch mask extraction results.

    Every batch is kept as a dict with three keys: ``'expected'`` (the number
    of masks the batch waits for), ``'masks'`` (mask index -> path string) and
    ``'errors'`` (list of error messages). All reads and writes happen while
    holding a single lock.
    """

    def __init__(self):
        """Create an empty store together with the lock that guards it."""
        self._batches: dict[str, dict[str, Any]] = {}
        self._lock = threading.Lock()

    def create_batch(self, batch_id: str, count: int) -> None:
        """Initialize a batch with expected count.

        An existing batch stored under the same ``batch_id`` is replaced.
        """
        with self._lock:
            self._batches[batch_id] = {
                'expected': count,
                'masks': {},
                'errors': [],
            }

    def add_mask(self, batch_id: str, index: int, path: Path) -> None:
        """Add a completed mask to the batch.

        The path is stored as a string. Unknown batch ids are ignored.
        """
        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is not None:
                batch['masks'][index] = str(path)

    def add_error(self, batch_id: str, error: str) -> None:
        """Add an error to the batch. Unknown batch ids are ignored."""
        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is not None:
                batch['errors'].append(error)

    def get_batch(self, batch_id: str) -> dict | None:
        """Get batch status.

        Returns:
            A snapshot of the batch, or ``None`` when the id is unknown. The
            ``'masks'`` dict and ``'errors'`` list are copies taken under the
            lock, so the caller can iterate them while other threads keep
            adding results, and cannot alter the stored state by accident.
        """
        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is None:
                return None
            # Copy the mutable containers; handing out the live objects would
            # let them be read or modified outside the lock.
            return {
                **batch,
                'masks': dict(batch['masks']),
                'errors': list(batch['errors']),
            }

    def is_complete(self, batch_id: str) -> bool:
        """Check if batch has all expected masks.

        Returns ``False`` for an unknown batch id. Recorded errors do not
        count towards completion.
        """
        with self._lock:
            batch = self._batches.get(batch_id)
            if not batch:
                return False
            return len(batch['masks']) >= batch['expected']

    def clear_batch(self, batch_id: str) -> None:
        """Remove a batch from storage; a missing id is not an error."""
        with self._lock:
            self._batches.pop(batch_id, None)
# Module-level BatchStorage instance; ComfyUIService.poll_for_batch_completion
# reads batch state from it.
batch_storage = BatchStorage()
class ComfyUIService:
    """Service for interacting with ComfyUI API.

    Covers three things: submitting a workflow, polling for its completion
    (through the history endpoint or through the module-level batch storage),
    and filling workflow templates (dicts keyed by node id) with per-request
    values.
    """

    def __init__(self, base_url: str):
        """Remember the default ComfyUI address.

        Args:
            base_url: Address without scheme; requests go to
                ``http://{base_url}/...``.
        """
        self.base_url = base_url

    def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str:
        """Submit a workflow to ComfyUI and return the prompt_id.

        Args:
            workflow: Workflow dict, sent as ``{"prompt": workflow}``.
            comfy_url: Optional address overriding ``self.base_url``.

        Raises:
            RuntimeError: The response carried no ``prompt_id``.
            urllib.error.URLError: The request itself failed.
        """
        url = comfy_url or self.base_url
        payload = json.dumps({"prompt": workflow}).encode('utf-8')
        req = urllib.request.Request(
            f'http://{url}/prompt',
            data=payload,
            headers={'Content-Type': 'application/json'},
            method='POST',
        )
        with urllib.request.urlopen(req, timeout=30) as response:
            result = json.loads(response.read().decode())

        prompt_id = result.get('prompt_id')
        if not prompt_id:
            raise RuntimeError("No prompt_id returned from ComfyUI")
        return prompt_id

    def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool:
        """Poll ComfyUI history for workflow completion.

        The history endpoint is queried every two seconds.

        Args:
            prompt_id: Id returned by ``submit_workflow``.
            comfy_url: Optional address overriding ``self.base_url``.
            timeout: Maximum number of seconds to keep polling.

        Returns:
            ``True`` once the history entry reports ``status_str == 'success'``;
            ``False`` when it reports ``'error'`` or when the timeout expires.

        Raises:
            urllib.error.HTTPError: For any HTTP error other than 404. A 404
                and every non-HTTP exception are retried.
        """
        url = comfy_url or self.base_url
        headers = {'Content-Type': 'application/json'}
        # Monotonic clock: a wall-clock adjustment must not stretch or cut
        # the timeout.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                req = urllib.request.Request(
                    f'http://{url}/history/{prompt_id}',
                    headers=headers,
                    method='GET',
                )
                with urllib.request.urlopen(req, timeout=30) as response:
                    history = json.loads(response.read().decode())

                if prompt_id in history:
                    state = history[prompt_id].get('status', {}).get('status_str')
                    if state == 'success':
                        return True
                    if state == 'error':
                        # The run has already failed; waiting out the rest of
                        # the timeout would only delay the same False result.
                        logger.error("ComfyUI reported an error for prompt %s", prompt_id)
                        return False
            except urllib.error.HTTPError as e:
                # 404 is treated as "not available yet" and retried.
                if e.code != 404:
                    raise
            except Exception as e:
                logger.error("Error polling history: %s", e)
            time.sleep(2)

        return False

    def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]:
        """Poll until all masks in a batch are received via webhook.

        The module-level ``batch_storage`` is checked once per second.

        Returns:
            The stored mask paths ordered by mask index, or an empty list when
            the batch did not complete within ``timeout`` seconds.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            batch = batch_storage.get_batch(batch_id)
            if batch and batch_storage.is_complete(batch_id):
                masks = batch['masks']
                return [masks[i] for i in sorted(masks)]
            time.sleep(1)
        return []

    @staticmethod
    def _clone_template(workflow_template: dict | None) -> dict:
        """Return an independent copy of the template, or ``{}`` when there is none."""
        # A JSON round trip gives a deep copy, so the caller's template is
        # never modified.
        return json.loads(json.dumps(workflow_template)) if workflow_template else {}

    @staticmethod
    def _png_base64(image: Image.Image) -> str:
        """Encode ``image`` as PNG and return the bytes as base64 text."""
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    @staticmethod
    def _with_polygon_outline(
        base_image: Image.Image,
        polygon_points: list | None,
        polygon_color: str,
        polygon_width: int
    ) -> Image.Image:
        """Return a copy of ``base_image`` with the polygon outline drawn on it.

        Points are dicts with ``'x'`` and ``'y'`` that get multiplied by the
        image width and height. Fewer than three points leaves the copy
        untouched.
        """
        img = base_image.copy()
        if polygon_points and len(polygon_points) >= 3:
            w, h = img.size
            pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points]
            # NOTE(review): any colour string that is not exactly 7 characters
            # gets 'FF' appended (logic kept unchanged); confirm which colour
            # formats callers actually send.
            hex_color = polygon_color if len(polygon_color) == 7 else polygon_color + 'FF'
            ImageDraw.Draw(img).polygon(pixel_points, outline=hex_color, width=polygon_width)
        return img

    @staticmethod
    def _set_input(workflow: dict, node_id: str, name: str, value: Any) -> None:
        """Set ``workflow[node_id]['inputs'][name]`` when that node has inputs.

        Nodes missing from the workflow, or present without an ``'inputs'``
        dict, are skipped.
        """
        node = workflow.get(node_id)
        if isinstance(node, dict) and 'inputs' in node:
            node['inputs'][name] = value

    def _apply_mask_request(
        self,
        workflow: dict,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None,
        mask_index: int,
        denoise_strength: float | None = None
    ) -> None:
        """Write prompt, webhook, seed, optional denoise and uid into the mask workflow."""
        self._set_input(
            workflow, "1:68", "prompt",
            f"Create a black and white alpha mask of {subject}, leaving everything else black",
        )
        self._set_input(workflow, "96", "webhook_url", webhook_url)
        self._set_input(workflow, "50", "seed", seed)
        if denoise_strength is not None:
            self._set_input(workflow, "1:65", "denoise", denoise_strength)
        # The uid carries "<batch_id>:<mask_index>", or only the index when
        # no batch id was given.
        metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index)
        self._set_input(workflow, "96", "external_uid", metadata)

    def prepare_mask_workflow(
        self,
        base_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None = None,
        mask_index: int = 0,
        polygon_points: list | None = None,
        polygon_color: str = '#FF0000',
        polygon_width: int = 2,
        workflow_template: dict | None = None
    ) -> dict:
        """Prepare the mask extraction workflow.

        Args:
            base_image: Image to extract the mask from; it is not modified.
            subject: Text inserted into the mask prompt.
            webhook_url: Written to node "96" as ``webhook_url``.
            seed: Written to node "50" as ``seed``.
            batch_id: Optional batch id, combined with ``mask_index`` into the
                ``external_uid`` of node "96".
            mask_index: Index of this mask inside the batch.
            polygon_points: Optional outline (three or more ``{'x', 'y'}``
                dicts) drawn on the image before it is encoded.
            polygon_color: Outline colour.
            polygon_width: Outline width passed to the polygon draw call.
            workflow_template: Node-id keyed template; copied, never mutated.

        Returns:
            The filled-in copy of the template (``{}`` without a template).
            Template nodes that are absent are skipped.
        """
        workflow = self._clone_template(workflow_template)

        img = self._with_polygon_outline(base_image, polygon_points, polygon_color, polygon_width)
        self._set_input(workflow, "87", "image", self._png_base64(img))

        self._apply_mask_request(workflow, subject, webhook_url, seed, batch_id, mask_index)
        return workflow

    def prepare_sam_workflow(
        self,
        base_image: Image.Image,
        include_points: list,
        exclude_points: list,
        webhook_url: str,
        batch_id: str | None = None,
        workflow_template: dict | None = None
    ) -> dict:
        """Prepare the SAM3 rough mask workflow with user-provided points.

        One ``SAM3CreatePoint`` node is added per point, with ids counting up
        from "100": include points first (``is_foreground`` True), then
        exclude points (False). When at least one point exists, node "8"
        (``SAM3CombinePoints``) is created to collect them.

        Args:
            base_image: Image written to node "9" as base64 PNG.
            include_points: ``{'x', 'y'}`` dicts marked as foreground.
            exclude_points: ``{'x', 'y'}`` dicts marked as background.
            webhook_url: Written to node "11" as ``webhook_url``.
            batch_id: When given, node "11" gets ``external_uid`` set to
                ``"<batch_id>:0"``.
            workflow_template: Node-id keyed template; copied, never mutated.

        Returns:
            The filled-in copy of the template.
        """
        workflow = self._clone_template(workflow_template)
        self._set_input(workflow, "9", "image", self._png_base64(base_image))

        tagged_points = [(pt, True) for pt in include_points]
        tagged_points += [(pt, False) for pt in exclude_points]

        point_node_ids = []
        for i, (pt, is_foreground) in enumerate(tagged_points):
            node_id = str(100 + i)
            workflow[node_id] = {
                "inputs": {
                    "x": pt['x'],
                    "y": pt['y'],
                    "is_foreground": is_foreground,
                },
                "class_type": "SAM3CreatePoint",
                "_meta": {"title": f"SAM3 Point {i+1}"},
            }
            point_node_ids.append(node_id)

        if point_node_ids:
            workflow["8"] = {
                "inputs": {
                    f"point_{i}": [node_id, 0]
                    for i, node_id in enumerate(point_node_ids, start=1)
                },
                "class_type": "SAM3CombinePoints",
                "_meta": {"title": "SAM3 Combine Points"},
            }

        # Wire node "1" to the combine node only when that node exists, so an
        # empty point list never leaves a reference to a missing node.
        if "8" in workflow:
            self._set_input(workflow, "1", "positive_points", ["8", 0])

        self._set_input(workflow, "11", "webhook_url", webhook_url)
        if batch_id:
            self._set_input(workflow, "11", "external_uid", f"{batch_id}:0")

        return workflow

    def prepare_mask_workflow_with_start(
        self,
        base_image: Image.Image,
        start_mask_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None = None,
        mask_index: int = 0,
        polygon_points: list | None = None,
        polygon_color: str = '#FF0000',
        polygon_width: int = 2,
        denoise_strength: float = 0.8,
        workflow_template: dict | None = None
    ) -> dict:
        """Prepare the mask extraction workflow with a starting mask (lower denoise).

        Works like ``prepare_mask_workflow`` and additionally writes
        ``start_mask_image`` (base64 PNG) to node "200" and
        ``denoise_strength`` to node "1:65" as ``denoise``.

        Args:
            base_image: Image to extract the mask from; it is not modified.
            start_mask_image: Starting mask image.
            subject: Text inserted into the mask prompt.
            webhook_url: Written to node "96" as ``webhook_url``.
            seed: Written to node "50" as ``seed``.
            batch_id: Optional batch id, combined with ``mask_index`` into the
                ``external_uid`` of node "96".
            mask_index: Index of this mask inside the batch.
            polygon_points: Optional outline (three or more ``{'x', 'y'}``
                dicts) drawn on the base image before it is encoded.
            polygon_color: Outline colour.
            polygon_width: Outline width passed to the polygon draw call.
            denoise_strength: Value for the ``denoise`` input.
            workflow_template: Node-id keyed template; copied, never mutated.

        Returns:
            The filled-in copy of the template (``{}`` without a template).
        """
        workflow = self._clone_template(workflow_template)

        img = self._with_polygon_outline(base_image, polygon_points, polygon_color, polygon_width)
        self._set_input(workflow, "87", "image", self._png_base64(img))
        self._set_input(workflow, "200", "image", self._png_base64(start_mask_image))

        self._apply_mask_request(
            workflow, subject, webhook_url, seed, batch_id, mask_index,
            denoise_strength=denoise_strength,
        )
        return workflow