"""ComfyUI integration service for mask extraction.""" import base64 import io import json import random import threading import time import urllib.error import urllib.parse import urllib.request from pathlib import Path from typing import Any import logging from PIL import Image, ImageDraw logger = logging.getLogger(__name__) class ComfyUIService: """Service for interacting with ComfyUI API.""" def __init__(self, base_url: str): self.base_url = base_url self._webhook_response: dict | None = None self._webhook_ready = threading.Event() def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str: """Submit a workflow to ComfyUI and return the prompt_id.""" url = comfy_url or self.base_url headers = {'Content-Type': 'application/json'} payload = json.dumps({"prompt": workflow}).encode('utf-8') req = urllib.request.Request( f'http://{url}/prompt', data=payload, headers=headers, method='POST' ) with urllib.request.urlopen(req, timeout=30) as response: result = json.loads(response.read().decode()) prompt_id = result.get('prompt_id') if not prompt_id: raise RuntimeError("No prompt_id returned from ComfyUI") return prompt_id def poll_for_completion(self, prompt_id: str, comfy_url: str | None = None, timeout: int = 240) -> bool: """Poll ComfyUI history for workflow completion.""" url = comfy_url or self.base_url headers = {'Content-Type': 'application/json'} start_time = time.time() while time.time() - start_time < timeout: try: req = urllib.request.Request( f'http://{url}/history/{prompt_id}', headers=headers, method='GET' ) with urllib.request.urlopen(req, timeout=30) as response: history = json.loads(response.read().decode()) if prompt_id in history: status = history[prompt_id].get('status', {}) if status.get('status_str') == 'success': return True time.sleep(2) except urllib.error.HTTPError as e: if e.code == 404: time.sleep(2) else: raise except Exception as e: logger.error(f"Error polling history: {e}") time.sleep(2) return False def wait_for_webhook(self, timeout: float = 60.0) -> dict | None: """Wait for webhook callback from ComfyUI.""" if self._webhook_ready.is_set() and self._webhook_response is not None: return self._webhook_response self._webhook_ready.clear() self._webhook_response = None webhook_received = self._webhook_ready.wait(timeout=timeout) if webhook_received: return self._webhook_response return None def handle_webhook(self, request_files, request_form, request_data, temp_dir: Path) -> dict: """Handle incoming webhook from ComfyUI.""" self._webhook_response = None self._webhook_ready.clear() try: img_file = None if 'file' in request_files: img_file = request_files['file'] elif 'image' in request_files: img_file = request_files['image'] elif request_files: img_file = list(request_files.values())[0] if img_file: timestamp = str(int(time.time())) final_mask_path = temp_dir / f"mask_{timestamp}.png" img = Image.open(img_file).convert('RGBA') img.save(str(final_mask_path), format='PNG') logger.info(f"[WEBHOOK] Image saved to {final_mask_path}") self._webhook_response = {'success': True, 'path': final_mask_path} elif request_data: timestamp = str(int(time.time())) final_mask_path = temp_dir / f"mask_{timestamp}.png" with open(final_mask_path, 'wb') as f: f.write(request_data) logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}") self._webhook_response = {'success': True, 'path': final_mask_path} else: logger.error("[WEBHOOK] No image data in request") self._webhook_response = {'success': False, 'error': 'No image data received'} self._webhook_ready.set() return self._webhook_response except Exception as e: logger.error(f"[WEBHOOK] Error: {e}") self._webhook_response = {'success': False, 'error': str(e)} self._webhook_ready.set() return self._webhook_response def prepare_mask_workflow( self, base_image: Image.Image, subject: str, webhook_url: str, polygon_points: list | None = None, polygon_color: str = '#FF0000', polygon_width: int = 2, workflow_template: dict | None = None ) -> dict: """Prepare the mask extraction workflow.""" workflow = workflow_template.copy() if workflow_template else {} img = base_image.copy() if polygon_points and len(polygon_points) >= 3: w, h = img.size pixel_points = [(int(p['x'] * w), int(p['y'] * h)) for p in polygon_points] draw = ImageDraw.Draw(img) hex_color = polygon_color if len(polygon_color) == 7 else polygon_color + 'FF' draw.polygon(pixel_points, outline=hex_color, width=polygon_width) img_io = io.BytesIO() img.save(img_io, format='PNG') img_io.seek(0) base64_image = base64.b64encode(img_io.read()).decode('utf-8') if "87" in workflow: workflow["87"]["inputs"]["image"] = base64_image if "1:68" in workflow and 'inputs' in workflow["1:68"]: workflow["1:68"]["inputs"]["prompt"] = f"Create a black and white alpha mask of {subject}, leaving everything else black" if "96" in workflow and 'inputs' in workflow["96"]: workflow["96"]["inputs"]["webhook_url"] = webhook_url if "50" in workflow and 'inputs' in workflow["50"]: workflow["50"]["inputs"]["seed"] = random.randint(0, 2**31-1) return workflow