Files
ai-game-2/tools/ora_editor/services/comfyui.py
Bryce c8932fdbf8 Add multi-mask extraction with count selector and navigation
- Add count selector (1-10) for generating multiple mask variations
- Each mask gets a unique random seed
- Add left/right arrow navigation in mask preview modal when multiple masks exist
- Batch storage system for tracking multiple concurrent extractions
- Webhook handler now uses batch_id:mask_index for routing responses
2026-03-27 21:36:20 -07:00

195 lines
6.6 KiB
Python

"""ComfyUI integration service for mask extraction."""
import base64
import io
import json
import random
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Any
import logging
from PIL import Image, ImageDraw
# Module-level logger; ComfyUIService.poll_for_completion uses it to report
# errors raised while polling the ComfyUI history endpoint.
logger = logging.getLogger(__name__)
class BatchStorage:
"""Thread-safe storage for batch mask extraction results."""
def __init__(self):
self._batches: dict[str, dict[str, Any]] = {}
self._lock = threading.Lock()
def create_batch(self, batch_id: str, count: int) -> None:
"""Initialize a batch with expected count."""
with self._lock:
self._batches[batch_id] = {
'expected': count,
'masks': {},
'errors': []
}
def add_mask(self, batch_id: str, index: int, path: Path) -> None:
"""Add a completed mask to the batch."""
with self._lock:
if batch_id in self._batches:
self._batches[batch_id]['masks'][index] = str(path)
def add_error(self, batch_id: str, error: str) -> None:
"""Add an error to the batch."""
with self._lock:
if batch_id in self._batches:
self._batches[batch_id]['errors'].append(error)
def get_batch(self, batch_id: str) -> dict | None:
"""Get batch status."""
with self._lock:
return self._batches.get(batch_id)
def is_complete(self, batch_id: str) -> bool:
"""Check if batch has all expected masks."""
with self._lock:
batch = self._batches.get(batch_id)
if not batch:
return False
return len(batch['masks']) >= batch['expected']
def clear_batch(self, batch_id: str) -> None:
"""Remove a batch from storage."""
with self._lock:
self._batches.pop(batch_id, None)
# Process-wide BatchStorage instance. ComfyUIService.poll_for_batch_completion
# reads from it; presumably the webhook handler writes completed masks into it
# via add_mask/add_error (NOTE(review): writer not visible in this file — confirm).
batch_storage = BatchStorage()
class ComfyUIService:
"""Service for interacting with ComfyUI API."""
def __init__(self, base_url: str):
self.base_url = base_url
def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str:
"""Submit a workflow to ComfyUI and return the prompt_id."""
url = comfy_url or self.base_url
headers = {'Content-Type': 'application/json'}
payload = json.dumps({"prompt": workflow}).encode('utf-8')
req = urllib.request.Request(
f'http://{url}/prompt',
data=payload,
headers=headers,
method='POST'
)
with urllib.request.urlopen(req, timeout=30) as response:
result = json.loads(response.read().decode())
prompt_id = result.get('prompt_id')
if not prompt_id:
raise RuntimeError("No prompt_id returned from ComfyUI")
return prompt_id
def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool:
"""Poll ComfyUI history for workflow completion."""
url = comfy_url or self.base_url
headers = {'Content-Type': 'application/json'}
start_time = time.time()
while time.time() - start_time < timeout:
try:
req = urllib.request.Request(
f'http://{url}/history/{prompt_id}',
headers=headers,
method='GET'
)
with urllib.request.urlopen(req, timeout=30) as response:
history = json.loads(response.read().decode())
if prompt_id in history:
status = history[prompt_id].get('status', {})
if status.get('status_str') == 'success':
return True
time.sleep(2)
except urllib.error.HTTPError as e:
if e.code == 404:
time.sleep(2)
else:
raise
except Exception as e:
logger.error(f"Error polling history: {e}")
time.sleep(2)
return False
def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]:
"""Poll until all masks in a batch are received via webhook."""
start_time = time.time()
while time.time() - start_time < timeout:
batch = batch_storage.get_batch(batch_id)
if batch and batch_storage.is_complete(batch_id):
masks = batch['masks']
return [masks[i] for i in sorted(masks.keys())]
time.sleep(1)
return []
def prepare_mask_workflow(
self,
base_image: Image.Image,
subject: str,
webhook_url: str,
seed: int,
batch_id: str = None,
mask_index: int = 0,
polygon_points: list | None = None,
polygon_color: str = '#FF0000',
polygon_width: int = 2,
workflow_template: dict | None = None
) -> dict:
"""Prepare the mask extraction workflow."""
workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {}
img = base_image.copy()
if polygon_points and len(polygon_points) >= 3:
w, h = img.size
pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points]
draw = ImageDraw.Draw(img)
hex_color = polygon_color if len(polygon_color) == 7 else polygon_color + 'FF'
draw.polygon(pixel_points, outline=hex_color, width=polygon_width)
img_io = io.BytesIO()
img.save(img_io, format='PNG')
img_io.seek(0)
base64_image = base64.b64encode(img_io.read()).decode('utf-8')
if "87" in workflow:
workflow["87"]["inputs"]["image"] = base64_image
if "1:68" in workflow and 'inputs' in workflow["1:68"]:
workflow["1:68"]["inputs"]["prompt"] = f"Create a black and white alpha mask of {subject}, leaving everything else black"
if "96" in workflow and 'inputs' in workflow["96"]:
workflow["96"]["inputs"]["webhook_url"] = webhook_url
if "50" in workflow and 'inputs' in workflow["50"]:
workflow["50"]["inputs"]["seed"] = seed
if "96" in workflow and 'inputs' in workflow["96"]:
metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index)
workflow["96"]["inputs"]["external_uid"] = metadata
return workflow