"""ComfyUI integration service for mask extraction.""" import base64 import io import json import random import threading import time import urllib.error import urllib.parse import urllib.request from pathlib import Path from typing import Any import logging from PIL import Image, ImageDraw logger = logging.getLogger(__name__) class BatchStorage: """Thread-safe storage for batch mask extraction results.""" def __init__(self): self._batches: dict[str, dict[str, Any]] = {} self._lock = threading.Lock() def create_batch(self, batch_id: str, count: int) -> None: """Initialize a batch with expected count.""" with self._lock: self._batches[batch_id] = { 'expected': count, 'masks': {}, 'errors': [] } def add_mask(self, batch_id: str, index: int, path: Path) -> None: """Add a completed mask to the batch.""" with self._lock: if batch_id in self._batches: self._batches[batch_id]['masks'][index] = str(path) def add_error(self, batch_id: str, error: str) -> None: """Add an error to the batch.""" with self._lock: if batch_id in self._batches: self._batches[batch_id]['errors'].append(error) def get_batch(self, batch_id: str) -> dict | None: """Get batch status.""" with self._lock: return self._batches.get(batch_id) def is_complete(self, batch_id: str) -> bool: """Check if batch has all expected masks.""" with self._lock: batch = self._batches.get(batch_id) if not batch: return False return len(batch['masks']) >= batch['expected'] def clear_batch(self, batch_id: str) -> None: """Remove a batch from storage.""" with self._lock: self._batches.pop(batch_id, None) batch_storage = BatchStorage() class ComfyUIService: """Service for interacting with ComfyUI API.""" def __init__(self, base_url: str): self.base_url = base_url def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str: """Submit a workflow to ComfyUI and return the prompt_id.""" url = comfy_url or self.base_url headers = {'Content-Type': 'application/json'} payload = json.dumps({"prompt": workflow}).encode('utf-8') req = urllib.request.Request( f'http://{url}/prompt', data=payload, headers=headers, method='POST' ) with urllib.request.urlopen(req, timeout=30) as response: result = json.loads(response.read().decode()) prompt_id = result.get('prompt_id') if not prompt_id: raise RuntimeError("No prompt_id returned from ComfyUI") return prompt_id def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool: """Poll ComfyUI history for workflow completion.""" url = comfy_url or self.base_url headers = {'Content-Type': 'application/json'} start_time = time.time() while time.time() - start_time < timeout: try: req = urllib.request.Request( f'http://{url}/history/{prompt_id}', headers=headers, method='GET' ) with urllib.request.urlopen(req, timeout=30) as response: history = json.loads(response.read().decode()) if prompt_id in history: status = history[prompt_id].get('status', {}) if status.get('status_str') == 'success': return True time.sleep(2) except urllib.error.HTTPError as e: if e.code == 404: time.sleep(2) else: raise except Exception as e: logger.error(f"Error polling history: {e}") time.sleep(2) return False def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]: """Poll until all masks in a batch are received via webhook.""" start_time = time.time() while time.time() - start_time < timeout: batch = batch_storage.get_batch(batch_id) if batch and batch_storage.is_complete(batch_id): masks = batch['masks'] return [masks[i] for i in sorted(masks.keys())] time.sleep(1) return [] def prepare_mask_workflow( self, base_image: Image.Image, subject: str, webhook_url: str, seed: int, batch_id: str = None, mask_index: int = 0, polygon_points: list | None = None, polygon_color: str = '#FF0000', polygon_width: int = 2, workflow_template: dict | None = None ) -> dict: """Prepare the mask extraction workflow.""" workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {} img = base_image.copy() if polygon_points and len(polygon_points) >= 3: w, h = img.size pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points] draw = ImageDraw.Draw(img) hex_color = polygon_color if len(polygon_color) == 7 else polygon_color + 'FF' draw.polygon(pixel_points, outline=hex_color, width=polygon_width) img_io = io.BytesIO() img.save(img_io, format='PNG') img_io.seek(0) base64_image = base64.b64encode(img_io.read()).decode('utf-8') if "87" in workflow: workflow["87"]["inputs"]["image"] = base64_image if "1:68" in workflow and 'inputs' in workflow["1:68"]: workflow["1:68"]["inputs"]["prompt"] = f"Create a black and white alpha mask of {subject}, leaving everything else black" if "96" in workflow and 'inputs' in workflow["96"]: workflow["96"]["inputs"]["webhook_url"] = webhook_url if "50" in workflow and 'inputs' in workflow["50"]: workflow["50"]["inputs"]["seed"] = seed if "96" in workflow and 'inputs' in workflow["96"]: metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index) workflow["96"]["inputs"]["external_uid"] = metadata return workflow def prepare_sam_workflow( self, base_image: Image.Image, include_points: list, exclude_points: list, webhook_url: str, batch_id: str = None, workflow_template: dict | None = None ) -> dict: """Prepare the SAM3 rough mask workflow with user-provided points.""" workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {} img_io = io.BytesIO() base_image.save(img_io, format='PNG') img_io.seek(0) base64_image = base64.b64encode(img_io.read()).decode('utf-8') if "9" in workflow: workflow["9"]["inputs"]["image"] = base64_image all_points = [] for pt in include_points: all_points.append({ 'x': pt['x'], 'y': pt['y'], 'is_foreground': True }) for pt in exclude_points: all_points.append({ 'x': pt['x'], 'y': pt['y'], 'is_foreground': False }) point_nodes = {} point_node_ids = [] for i, pt in enumerate(all_points): node_id = str(100 + i) point_nodes[node_id] = { "inputs": { "x": pt['x'], "y": pt['y'], "is_foreground": pt['is_foreground'] }, "class_type": "SAM3CreatePoint", "_meta": {"title": f"SAM3 Point {i+1}"} } point_node_ids.append(node_id) for node_id, node_data in point_nodes.items(): workflow[node_id] = node_data if point_node_ids: combine_inputs = {} for i, node_id in enumerate(point_node_ids): combine_inputs[f"point_{i+1}"] = [node_id, 0] workflow["8"] = { "inputs": combine_inputs, "class_type": "SAM3CombinePoints", "_meta": {"title": "SAM3 Combine Points"} } if "1" in workflow: workflow["1"]["inputs"]["positive_points"] = ["8", 0] if "11" in workflow: workflow["11"]["inputs"]["webhook_url"] = webhook_url if batch_id: workflow["11"]["inputs"]["external_uid"] = f"{batch_id}:0" return workflow def prepare_mask_workflow_with_start( self, base_image: Image.Image, start_mask_image: Image.Image, subject: str, webhook_url: str, seed: int, batch_id: str = None, mask_index: int = 0, polygon_points: list | None = None, polygon_color: str = '#FF0000', polygon_width: int = 2, denoise_strength: float = 0.8, workflow_template: dict | None = None ) -> dict: """Prepare the mask extraction workflow with a starting mask (lower denoise).""" workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {} img = base_image.copy() if polygon_points and len(polygon_points) >= 3: w, h = img.size pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points] draw = ImageDraw.Draw(img) hex_color = polygon_color if len(polygon_color) == 7 else polygon_color + 'FF' draw.polygon(pixel_points, outline=hex_color, width=polygon_width) img_io = io.BytesIO() img.save(img_io, format='PNG') img_io.seek(0) base64_image = base64.b64encode(img_io.read()).decode('utf-8') if "87" in workflow: workflow["87"]["inputs"]["image"] = base64_image start_mask_io = io.BytesIO() start_mask_image.save(start_mask_io, format='PNG') start_mask_io.seek(0) start_mask_base64 = base64.b64encode(start_mask_io.read()).decode('utf-8') if "200" in workflow: workflow["200"]["inputs"]["image"] = start_mask_base64 if "1:68" in workflow and 'inputs' in workflow["1:68"]: workflow["1:68"]["inputs"]["prompt"] = f"Create a black and white alpha mask of {subject}, leaving everything else black" if "96" in workflow and 'inputs' in workflow["96"]: workflow["96"]["inputs"]["webhook_url"] = webhook_url if "50" in workflow and 'inputs' in workflow["50"]: workflow["50"]["inputs"]["seed"] = seed if "1:65" in workflow and 'inputs' in workflow["1:65"]: workflow["1:65"]["inputs"]["denoise"] = denoise_strength if "96" in workflow and 'inputs' in workflow["96"]: metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index) workflow["96"]["inputs"]["external_uid"] = metadata return workflow