- Add count selector (1-10) for generating multiple mask variations - Each mask gets a unique random seed - Add left/right arrow navigation in mask preview modal when multiple masks exist - Batch storage system for tracking multiple concurrent extractions - Webhook handler now uses batch_id:mask_index for routing responses
195 lines
6.6 KiB
Python
195 lines
6.6 KiB
Python
"""ComfyUI integration service for mask extraction."""
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import random
|
|
import threading
|
|
import time
|
|
import urllib.error
|
|
import urllib.parse
|
|
import urllib.request
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import logging
|
|
from PIL import Image, ImageDraw
|
|
|
|
# Module-level logger named after this module; ComfyUIService reports
# polling errors through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchStorage:
|
|
"""Thread-safe storage for batch mask extraction results."""
|
|
|
|
def __init__(self):
|
|
self._batches: dict[str, dict[str, Any]] = {}
|
|
self._lock = threading.Lock()
|
|
|
|
def create_batch(self, batch_id: str, count: int) -> None:
|
|
"""Initialize a batch with expected count."""
|
|
with self._lock:
|
|
self._batches[batch_id] = {
|
|
'expected': count,
|
|
'masks': {},
|
|
'errors': []
|
|
}
|
|
|
|
def add_mask(self, batch_id: str, index: int, path: Path) -> None:
|
|
"""Add a completed mask to the batch."""
|
|
with self._lock:
|
|
if batch_id in self._batches:
|
|
self._batches[batch_id]['masks'][index] = str(path)
|
|
|
|
def add_error(self, batch_id: str, error: str) -> None:
|
|
"""Add an error to the batch."""
|
|
with self._lock:
|
|
if batch_id in self._batches:
|
|
self._batches[batch_id]['errors'].append(error)
|
|
|
|
def get_batch(self, batch_id: str) -> dict | None:
|
|
"""Get batch status."""
|
|
with self._lock:
|
|
return self._batches.get(batch_id)
|
|
|
|
def is_complete(self, batch_id: str) -> bool:
|
|
"""Check if batch has all expected masks."""
|
|
with self._lock:
|
|
batch = self._batches.get(batch_id)
|
|
if not batch:
|
|
return False
|
|
return len(batch['masks']) >= batch['expected']
|
|
|
|
def clear_batch(self, batch_id: str) -> None:
|
|
"""Remove a batch from storage."""
|
|
with self._lock:
|
|
self._batches.pop(batch_id, None)
|
|
|
|
|
|
# Module-level shared instance. ComfyUIService.poll_for_batch_completion reads
# results from it; the writers (add_mask / add_error) are presumably the webhook
# handler elsewhere in the app -- verify against callers.
batch_storage = BatchStorage()
|
|
|
|
|
|
class ComfyUIService:
    """Service for interacting with the ComfyUI HTTP API.

    Each request method accepts an optional ``comfy_url`` that overrides
    ``base_url`` for that call. Either value may be a bare ``host:port``
    (``http://`` is then assumed) or a URL that already carries a scheme.
    """

    # Per-request socket timeout, in seconds, for ComfyUI HTTP calls.
    _REQUEST_TIMEOUT = 30
    # Seconds to wait between /history polls.
    _POLL_INTERVAL = 2
    # Seconds to wait between checks of the shared batch storage.
    _BATCH_POLL_INTERVAL = 1

    def __init__(self, base_url: str):
        self.base_url = base_url

    def _endpoint(self, comfy_url: str | None, path: str) -> str:
        """Build the absolute URL of API ``path`` on ``comfy_url`` (or ``base_url``).

        A bare ``host:port`` gets ``http://`` prepended. A value that already
        contains a scheme (``http://``/``https://``) is used as-is instead of
        being double-prefixed. Trailing slashes are trimmed so the joined
        URL has exactly one separator.
        """
        root = (comfy_url or self.base_url).rstrip('/')
        if '://' not in root:
            root = f'http://{root}'
        return f'{root}/{path.lstrip("/")}'

    def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str:
        """Queue ``workflow`` on ComfyUI's ``/prompt`` endpoint and return its ``prompt_id``.

        Raises:
            urllib.error.HTTPError / urllib.error.URLError: the request failed
                (e.g. ComfyUI rejected the workflow or is unreachable).
            RuntimeError: the response carried no ``prompt_id``.
        """
        payload = json.dumps({"prompt": workflow}).encode('utf-8')
        req = urllib.request.Request(
            self._endpoint(comfy_url, 'prompt'),
            data=payload,
            headers={'Content-Type': 'application/json'},
            method='POST',
        )

        with urllib.request.urlopen(req, timeout=self._REQUEST_TIMEOUT) as response:
            result = json.loads(response.read().decode())

        prompt_id = result.get('prompt_id')
        if not prompt_id:
            raise RuntimeError("No prompt_id returned from ComfyUI")
        return prompt_id

    def _fetch_history_status(self, url: str, prompt_id: str) -> str | None:
        """GET the ``/history/<prompt_id>`` URL and return the entry's ``status_str``.

        Returns ``None`` while ``prompt_id`` has no history entry yet, or when
        the entry carries no status. Network and JSON errors propagate to the
        caller.
        """
        req = urllib.request.Request(url, method='GET')
        with urllib.request.urlopen(req, timeout=self._REQUEST_TIMEOUT) as response:
            history = json.loads(response.read().decode())

        if prompt_id not in history:
            return None
        status = history[prompt_id].get('status') or {}
        return status.get('status_str')

    def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool:
        """Poll ComfyUI's history until ``prompt_id`` finishes or ``timeout`` seconds pass.

        Returns:
            True once the history entry reports ``status_str == 'success'``.
            False if it reports ``'error'`` (the run failed; further polling
            cannot succeed) or the deadline expires first.

        Raises:
            urllib.error.HTTPError: for any HTTP status other than 404. A 404
                is treated as "not in history yet".

        Other failures (connection errors, malformed responses, ...) are
        logged and retried until the deadline.
        """
        # Quote the id so an unexpected value can't alter the request path.
        url = self._endpoint(comfy_url, f'history/{urllib.parse.quote(prompt_id, safe="")}')
        # Monotonic clock: the deadline is immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                status_str = self._fetch_history_status(url, prompt_id)
            except urllib.error.HTTPError as e:
                if e.code != 404:
                    raise
            except Exception as e:
                logger.error("Error polling history: %s", e)
            else:
                if status_str == 'success':
                    return True
                if status_str == 'error':
                    logger.error("ComfyUI reported an execution error for prompt %s", prompt_id)
                    return False
            time.sleep(self._POLL_INTERVAL)

        return False

    def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]:
        """Wait until every mask of ``batch_id`` has been recorded in ``batch_storage``.

        Returns:
            The stored mask paths ordered by mask index, or ``[]`` if the
            batch does not complete within ``timeout`` seconds.
        """
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            # Check completeness *before* reading the record. Any record read
            # afterwards then holds at least the expected number of masks,
            # whether get_batch returns the live dict or a snapshot.
            if batch_storage.is_complete(batch_id):
                batch = batch_storage.get_batch(batch_id)
                if batch:
                    # Copy first: webhook threads may still write to the live dict.
                    masks = dict(batch['masks'])
                    return [masks[i] for i in sorted(masks)]
            time.sleep(self._BATCH_POLL_INTERVAL)

        return []

    def prepare_mask_workflow(
        self,
        base_image: Image.Image,
        subject: str,
        webhook_url: str,
        seed: int,
        batch_id: str | None = None,
        mask_index: int = 0,
        polygon_points: list | None = None,
        polygon_color: str = '#FF0000',
        polygon_width: int = 2,
        workflow_template: dict | None = None
    ) -> dict:
        """Build a ready-to-submit mask-extraction workflow from ``workflow_template``.

        The template is deep-copied and never mutated; ``base_image`` is
        copied too. If ``polygon_points`` holds three or more points (dicts
        with ``'x'``/``'y'`` values scaled by the image width/height,
        presumably fractions in 0..1), their outline is drawn onto the copy.
        The copy is then PNG-encoded to base64.

        These template node ids are filled in when present:

        * ``"87"``   -> ``inputs.image`` (the base64 PNG)
        * ``"1:68"`` -> ``inputs.prompt`` (text built from ``subject``)
        * ``"50"``   -> ``inputs.seed``
        * ``"96"``   -> ``inputs.webhook_url`` and ``inputs.external_uid``
          (``"<batch_id>:<mask_index>"``, or just ``"<mask_index>"`` when no
          batch id is given)

        Returns:
            The filled workflow dict, or ``{}`` when no template is supplied.
        """
        # JSON round-trip = deep copy, so the caller's template stays untouched.
        workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {}

        img = base_image.copy()

        if polygon_points and len(polygon_points) >= 3:
            w, h = img.size
            pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points]
            # Pillow parses #rgb / #rrggbb / #rrggbbaa / named colours itself,
            # so the value is passed through unchanged. An empty value falls
            # back to the default red.
            draw = ImageDraw.Draw(img)
            draw.polygon(pixel_points, outline=polygon_color or '#FF0000', width=polygon_width)

        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

        if "87" in workflow:
            workflow["87"]["inputs"]["image"] = base64_image

        if "1:68" in workflow and 'inputs' in workflow["1:68"]:
            workflow["1:68"]["inputs"]["prompt"] = f"Create a black and white alpha mask of {subject}, leaving everything else black"

        if "50" in workflow and 'inputs' in workflow["50"]:
            workflow["50"]["inputs"]["seed"] = seed

        if "96" in workflow and 'inputs' in workflow["96"]:
            webhook_inputs = workflow["96"]["inputs"]
            webhook_inputs["webhook_url"] = webhook_url
            # Echoed back by the webhook so responses can be routed to the
            # right batch slot.
            webhook_inputs["external_uid"] = f"{batch_id}:{mask_index}" if batch_id else str(mask_index)

        return workflow
|