"""Mask extraction routes for ORA Editor.""" import io import json import random import string import time import zipfile from pathlib import Path from flask import Blueprint, request, jsonify, make_response from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR from ora_editor.services import polygon_storage from ora_editor.services.comfyui import ComfyUIService, batch_storage from ora_editor.ora_ops import parse_stack_xml import logging logger = logging.getLogger(__name__) mask_bp = Blueprint('mask', __name__) comfy_service = ComfyUIService(COMFYUI_BASE_URL) def generate_batch_id() -> str: """Generate a unique batch ID.""" return ''.join(random.choices(string.ascii_lowercase + string.digits, k=8)) @mask_bp.route('/api/webhook/comfyui', methods=['POST']) def api_webhook_comfyui(): """Webhook endpoint for ComfyUI to post completed mask images.""" logger.info("[WEBHOOK] Received webhook from ComfyUI") try: metadata = request.form.get('metadata', '') if request.form else '' external_uid = request.form.get('external_uid', '') if request.form else '' logger.info(f"[WEBHOOK] Metadata: {metadata}") logger.info(f"[WEBHOOK] External UID: {external_uid}") batch_id = None mask_index = 0 if external_uid and ':' in external_uid: parts = external_uid.split(':') batch_id = parts[0] mask_index = int(parts[1]) if len(parts) > 1 else 0 img_file = None if 'file' in request.files: img_file = request.files['file'] elif 'image' in request.files: img_file = request.files['image'] elif request.files: img_file = list(request.files.values())[0] if img_file: timestamp = str(int(time.time() * 1000)) if batch_id: final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png" else: final_mask_path = TEMP_DIR / f"mask_{timestamp}.png" from PIL import Image img = Image.open(img_file).convert('RGBA') img.save(str(final_mask_path), format='PNG') logger.info(f"[WEBHOOK] Image saved to {final_mask_path}") if batch_id: batch_storage.add_mask(batch_id, mask_index, final_mask_path) return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200) elif request.data: timestamp = str(int(time.time() * 1000)) if batch_id: final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png" else: final_mask_path = TEMP_DIR / f"mask_{timestamp}.png" with open(final_mask_path, 'wb') as f: f.write(request.data) logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}") if batch_id: batch_storage.add_mask(batch_id, mask_index, final_mask_path) return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200) else: logger.error("[WEBHOOK] No image data in request") return make_response(jsonify({'status': 'error', 'message': 'No image data'}), 400) except Exception as e: logger.error(f"[WEBHOOK] Error: {e}") import traceback traceback.print_exc() return make_response(jsonify({'status': 'error', 'message': str(e)}), 500) @mask_bp.route('/api/mask/extract', methods=['POST']) def api_mask_extract(): """Extract one or more masks from the current base image using ComfyUI.""" data = request.get_json() required = ['subject', 'ora_path'] for field in required: if field not in data: return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400 subject = data['subject'] use_polygon = data.get('use_polygon', False) ora_path = data['ora_path'] comfy_url = data.get('comfy_url', COMFYUI_BASE_URL) count = min(max(data.get('count', 1), 1), 10) start_mask_path = data.get('start_mask_path', None) denoise_strength = data.get('denoise_strength', 0.8) logger.info(f"[MASK EXTRACT] Subject: {subject}") logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}") logger.info(f"[MASK EXTRACT] ORA path: {ora_path}") logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}") logger.info(f"[MASK EXTRACT] Count: {count}") logger.info(f"[MASK EXTRACT] Start mask path: {start_mask_path}") if start_mask_path: workflow_path = APP_DIR.parent / "image_mask_extraction_with_start.json" else: workflow_path = APP_DIR.parent / "image_mask_extraction.json" if not workflow_path.exists(): logger.error(f"[MASK EXTRACT] Workflow file not found: {workflow_path}") return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500 with open(workflow_path) as f: workflow_template = json.load(f) logger.info(f"[MASK EXTRACT] Loaded workflow") base_img = None try: with zipfile.ZipFile(ora_path, 'r') as zf: img_data = zf.read('mergedimage.png') base_img = __import__('PIL').Image.open(io.BytesIO(img_data)).convert('RGBA') logger.info(f"[MASK EXTRACT] Loaded base image: {base_img.size}") except Exception as e: logger.error(f"[MASK EXTRACT] Error loading base image: {e}") return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500 start_mask_img = None if start_mask_path: try: start_mask_img = __import__('PIL').Image.open(start_mask_path).convert('RGBA') if base_img.size != start_mask_img.size: start_mask_img = start_mask_img.resize(base_img.size, __import__('PIL').Image.LANCZOS) logger.info(f"[MASK EXTRACT] Loaded start mask: {start_mask_img.size}") except Exception as e: logger.error(f"[MASK EXTRACT] Error loading start mask: {e}") return jsonify({'success': False, 'error': f'Error loading start mask: {str(e)}'}), 500 polygon_points = None polygon_color = '#FF0000' polygon_width = 2 if use_polygon: subject = subject + ' outlined in red' poly_data = polygon_storage.get(ora_path) if poly_data: polygon_points = poly_data.get('points', []) polygon_color = poly_data.get('color', '#FF0000') polygon_width = poly_data.get('width', 2) if not polygon_points or len(polygon_points) < 3: logger.warning(f"[MASK EXTRACT] Use polygon requested but no valid polygon points stored") return jsonify({'success': False, 'error': 'No valid polygon points. Please draw polygon first.'}), 400 batch_id = generate_batch_id() batch_storage.create_batch(batch_id, count) webhook_url = f"http://localhost:5001/api/webhook/comfyui" seeds = [random.randint(0, 2**31-1) for _ in range(count)] prompt_ids = [] for i in range(count): if start_mask_img: workflow = comfy_service.prepare_mask_workflow_with_start( base_image=base_img, start_mask_image=start_mask_img, subject=subject, webhook_url=webhook_url, seed=seeds[i], batch_id=batch_id, mask_index=i, polygon_points=polygon_points, polygon_color=polygon_color, polygon_width=polygon_width, denoise_strength=denoise_strength, workflow_template=workflow_template ) else: workflow = comfy_service.prepare_mask_workflow( base_image=base_img, subject=subject, webhook_url=webhook_url, seed=seeds[i], batch_id=batch_id, mask_index=i, polygon_points=polygon_points, polygon_color=polygon_color, polygon_width=polygon_width, workflow_template=workflow_template ) logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}") try: prompt_id = comfy_service.submit_workflow(workflow, comfy_url) prompt_ids.append(prompt_id) logger.info(f"[MASK EXTRACT] Prompt {i} submitted with ID: {prompt_id}, seed: {seeds[i]}") except Exception as e: logger.error(f"[MASK EXTRACT] Error submitting workflow {i}: {e}") return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500 all_completed = True for prompt_id in prompt_ids: completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240) if not completed: all_completed = False logger.warning(f"[MASK EXTRACT] Prompt {prompt_id} did not complete in time") if not all_completed: logger.error("[MASK EXTRACT] Timeout waiting for some ComfyUI workflows to complete") return jsonify({'success': False, 'error': 'Mask extraction timed out'}), 500 logger.info(f"[MASK EXTRACT] All workflows completed, waiting for webhooks...") mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=120.0) if not mask_paths: logger.error(f"[MASK EXTRACT] No masks received via webhook") return jsonify({'success': False, 'error': 'No masks received from ComfyUI'}), 500 logger.info(f"[MASK EXTRACT] Received {len(mask_paths)} masks") return jsonify({ 'success': True, 'batch_id': batch_id, 'mask_paths': [str(p) for p in mask_paths], 'mask_urls': [f'/api/file/mask?path={p}' for p in mask_paths], 'count': len(mask_paths) })