From c8932fdbf8a7874cb8d6102554c5731382460b57 Mon Sep 17 00:00:00 2001 From: Bryce Date: Fri, 27 Mar 2026 21:36:20 -0700 Subject: [PATCH] Add multi-mask extraction with count selector and navigation - Add count selector (1-10) for generating multiple mask variations - Each mask gets a unique random seed - Add left/right arrow navigation in mask preview modal when multiple masks exist - Batch storage system for tracking multiple concurrent extractions - Webhook handler now uses batch_id:mask_index for routing responses --- tools/ora_editor/routes/mask.py | 196 ++++++++++++------ tools/ora_editor/services/comfyui.py | 131 ++++++------ .../templates/components/sidebar.html | 24 ++- tools/ora_editor/templates/editor.html | 44 +++- .../templates/modals/mask_preview.html | 53 ++++- 5 files changed, 306 insertions(+), 142 deletions(-) diff --git a/tools/ora_editor/routes/mask.py b/tools/ora_editor/routes/mask.py index 8eac804..dd38a3b 100644 --- a/tools/ora_editor/routes/mask.py +++ b/tools/ora_editor/routes/mask.py @@ -2,6 +2,8 @@ import io import json +import random +import string import time import zipfile from pathlib import Path @@ -9,7 +11,7 @@ from flask import Blueprint, request, jsonify, make_response from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR from ora_editor.services import polygon_storage -from ora_editor.services.comfyui import ComfyUIService +from ora_editor.services.comfyui import ComfyUIService, batch_storage from ora_editor.ora_ops import parse_stack_xml import logging @@ -19,31 +21,88 @@ mask_bp = Blueprint('mask', __name__) comfy_service = ComfyUIService(COMFYUI_BASE_URL) +def generate_batch_id() -> str: + """Generate a unique batch ID.""" + return ''.join(random.choices(string.ascii_lowercase + string.digits, k=8)) + + @mask_bp.route('/api/webhook/comfyui', methods=['POST']) def api_webhook_comfyui(): """Webhook endpoint for ComfyUI to post completed mask images.""" logger.info("[WEBHOOK] Received webhook from ComfyUI") - result = comfy_service.handle_webhook( - request.files, - request.form if request.form else {}, - request.data, - TEMP_DIR - ) - - if result.get('success'): - response = make_response(jsonify({'status': 'ok', 'message': 'Image received'})) - response.status_code = 200 - else: - response = make_response(jsonify({'status': 'error', 'message': result.get('error', 'Unknown error')})) - response.status_code = 500 - - return response + try: + metadata = request.form.get('metadata', '') if request.form else '' + external_uid = request.form.get('external_uid', '') if request.form else '' + + logger.info(f"[WEBHOOK] Metadata: {metadata}") + logger.info(f"[WEBHOOK] External UID: {external_uid}") + + batch_id = None + mask_index = 0 + + if external_uid and ':' in external_uid: + parts = external_uid.split(':') + batch_id = parts[0] + mask_index = int(parts[1]) if len(parts) > 1 else 0 + + img_file = None + if 'file' in request.files: + img_file = request.files['file'] + elif 'image' in request.files: + img_file = request.files['image'] + elif request.files: + img_file = list(request.files.values())[0] + + if img_file: + timestamp = str(int(time.time() * 1000)) + if batch_id: + final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png" + else: + final_mask_path = TEMP_DIR / f"mask_{timestamp}.png" + + from PIL import Image + img = Image.open(img_file).convert('RGBA') + img.save(str(final_mask_path), format='PNG') + + logger.info(f"[WEBHOOK] Image saved to {final_mask_path}") + + if batch_id: + batch_storage.add_mask(batch_id, mask_index, final_mask_path) + + return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200) + + elif request.data: + timestamp = str(int(time.time() * 1000)) + if batch_id: + final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png" + else: + final_mask_path = TEMP_DIR / f"mask_{timestamp}.png" + + with open(final_mask_path, 'wb') as f: + f.write(request.data) + + logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}") + + if batch_id: + batch_storage.add_mask(batch_id, mask_index, final_mask_path) + + return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200) + + else: + logger.error("[WEBHOOK] No image data in request") + return make_response(jsonify({'status': 'error', 'message': 'No image data'}), 400) + + except Exception as e: + logger.error(f"[WEBHOOK] Error: {e}") + import traceback + traceback.print_exc() + return make_response(jsonify({'status': 'error', 'message': str(e)}), 500) @mask_bp.route('/api/mask/extract', methods=['POST']) def api_mask_extract(): - """Extract a mask from the current base image using ComfyUI.""" + """Extract one or more masks from the current base image using ComfyUI.""" data = request.get_json() required = ['subject', 'ora_path'] @@ -55,11 +114,13 @@ def api_mask_extract(): use_polygon = data.get('use_polygon', False) ora_path = data['ora_path'] comfy_url = data.get('comfy_url', COMFYUI_BASE_URL) + count = min(max(data.get('count', 1), 1), 10) logger.info(f"[MASK EXTRACT] Subject: {subject}") logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}") logger.info(f"[MASK EXTRACT] ORA path: {ora_path}") logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}") + logger.info(f"[MASK EXTRACT] Count: {count}") workflow_path = APP_DIR.parent / "image_mask_extraction.json" if not workflow_path.exists(): @@ -96,66 +157,63 @@ def api_mask_extract(): logger.warning(f"[MASK EXTRACT] Use polygon requested but no valid polygon points stored") return jsonify({'success': False, 'error': 'No valid polygon points. Please draw polygon first.'}), 400 + batch_id = generate_batch_id() + batch_storage.create_batch(batch_id, count) + webhook_url = f"http://localhost:5001/api/webhook/comfyui" - workflow = comfy_service.prepare_mask_workflow( - base_image=base_img, - subject=subject, - webhook_url=webhook_url, - polygon_points=polygon_points, - polygon_color=polygon_color, - polygon_width=polygon_width, - workflow_template=workflow_template - ) + seeds = [random.randint(0, 2**31-1) for _ in range(count)] + prompt_ids = [] - logger.info(f"[MASK EXTRACT] Workflow prepared, sending to ComfyUI at http://{comfy_url}") + for i in range(count): + workflow = comfy_service.prepare_mask_workflow( + base_image=base_img, + subject=subject, + webhook_url=webhook_url, + seed=seeds[i], + batch_id=batch_id, + mask_index=i, + polygon_points=polygon_points, + polygon_color=polygon_color, + polygon_width=polygon_width, + workflow_template=workflow_template + ) + + logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}") + + try: + prompt_id = comfy_service.submit_workflow(workflow, comfy_url) + prompt_ids.append(prompt_id) + logger.info(f"[MASK EXTRACT] Prompt {i} submitted with ID: {prompt_id}, seed: {seeds[i]}") + except Exception as e: + logger.error(f"[MASK EXTRACT] Error submitting workflow {i}: {e}") + return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500 - try: - prompt_id = comfy_service.submit_workflow(workflow, comfy_url) - logger.info(f"[MASK EXTRACT] Prompt submitted with ID: {prompt_id}") - except Exception as e: - logger.error(f"[MASK EXTRACT] Error submitting to ComfyUI: {e}") - return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500 + all_completed = True + for prompt_id in prompt_ids: + completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240) + if not completed: + all_completed = False + logger.warning(f"[MASK EXTRACT] Prompt {prompt_id} did not complete in time") - completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240) - - if not completed: - logger.error("[MASK EXTRACT] Timeout waiting for ComfyUI to complete") + if not all_completed: + logger.error("[MASK EXTRACT] Timeout waiting for some ComfyUI workflows to complete") return jsonify({'success': False, 'error': 'Mask extraction timed out'}), 500 - logger.info(f"[MASK EXTRACT] Checking/waiting for webhook callback from ComfyUI...") + logger.info(f"[MASK EXTRACT] All workflows completed, waiting for webhooks...") - webhook_result = comfy_service.wait_for_webhook(timeout=60.0) + mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=120.0) - if not webhook_result: - logger.error(f"[MASK EXTRACT] Timeout waiting for webhook callback") - return jsonify({'success': False, 'error': 'Webhook timeout - mask extraction may have failed'}), 500 + if not mask_paths: + logger.error(f"[MASK EXTRACT] No masks received via webhook") + return jsonify({'success': False, 'error': 'No masks received from ComfyUI'}), 500 - logger.info(f"[MASK EXTRACT] Webhook received: {webhook_result}") + logger.info(f"[MASK EXTRACT] Received {len(mask_paths)} masks") - if not webhook_result.get('success'): - error_msg = webhook_result.get('error', 'Unknown error') - logger.error(f"[MASK EXTRACT] Webhook failed: {error_msg}") - return jsonify({'success': False, 'error': f'Webhook error: {error_msg}'}), 500 - - final_mask_path = webhook_result.get('path') - - if not final_mask_path: - logger.error("[MASK EXTRACT] No mask path in webhook response") - return jsonify({'success': False, 'error': 'No mask path returned from webhook'}), 500 - - try: - if not Path(final_mask_path).exists(): - logger.error(f"[MASK EXTRACT] Mask file not found at {final_mask_path}") - return jsonify({'success': False, 'error': f'Mask file not found: {final_mask_path}'}), 500 - - logger.info(f"[MASK EXTRACT] Mask received via webhook at {final_mask_path}") - - return jsonify({ - 'success': True, - 'mask_path': str(final_mask_path), - 'mask_url': f'/api/file/mask?path={final_mask_path}' - }) - except Exception as e: - logger.error(f"[MASK EXTRACT] Error processing mask file: {e}") - return jsonify({'success': False, 'error': f'Error accessing mask: {str(e)}'}), 500 + return jsonify({ + 'success': True, + 'batch_id': batch_id, + 'mask_paths': [str(p) for p in mask_paths], + 'mask_urls': [f'/api/file/mask?path={p}' for p in mask_paths], + 'count': len(mask_paths) + }) diff --git a/tools/ora_editor/services/comfyui.py b/tools/ora_editor/services/comfyui.py index bfc11dd..59e0cd9 100644 --- a/tools/ora_editor/services/comfyui.py +++ b/tools/ora_editor/services/comfyui.py @@ -18,13 +18,61 @@ from PIL import Image, ImageDraw logger = logging.getLogger(__name__) +class BatchStorage: + """Thread-safe storage for batch mask extraction results.""" + + def __init__(self): + self._batches: dict[str, dict[str, Any]] = {} + self._lock = threading.Lock() + + def create_batch(self, batch_id: str, count: int) -> None: + """Initialize a batch with expected count.""" + with self._lock: + self._batches[batch_id] = { + 'expected': count, + 'masks': {}, + 'errors': [] + } + + def add_mask(self, batch_id: str, index: int, path: Path) -> None: + """Add a completed mask to the batch.""" + with self._lock: + if batch_id in self._batches: + self._batches[batch_id]['masks'][index] = str(path) + + def add_error(self, batch_id: str, error: str) -> None: + """Add an error to the batch.""" + with self._lock: + if batch_id in self._batches: + self._batches[batch_id]['errors'].append(error) + + def get_batch(self, batch_id: str) -> dict | None: + """Get batch status.""" + with self._lock: + return self._batches.get(batch_id) + + def is_complete(self, batch_id: str) -> bool: + """Check if batch has all expected masks.""" + with self._lock: + batch = self._batches.get(batch_id) + if not batch: + return False + return len(batch['masks']) >= batch['expected'] + + def clear_batch(self, batch_id: str) -> None: + """Remove a batch from storage.""" + with self._lock: + self._batches.pop(batch_id, None) + + +batch_storage = BatchStorage() + + class ComfyUIService: """Service for interacting with ComfyUI API.""" def __init__(self, base_url: str): self.base_url = base_url - self._webhook_response: dict | None = None - self._webhook_ready = threading.Event() def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str: """Submit a workflow to ComfyUI and return the prompt_id.""" @@ -83,79 +131,34 @@ class ComfyUIService: return False - def wait_for_webhook(self, timeout: float = 60.0) -> dict | None: - """Wait for webhook callback from ComfyUI.""" - if self._webhook_ready.is_set() and self._webhook_response is not None: - return self._webhook_response + def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]: + """Poll until all masks in a batch are received via webhook.""" + start_time = time.time() - self._webhook_ready.clear() - self._webhook_response = None + while time.time() - start_time < timeout: + batch = batch_storage.get_batch(batch_id) + if batch and batch_storage.is_complete(batch_id): + masks = batch['masks'] + return [masks[i] for i in sorted(masks.keys())] + time.sleep(1) - webhook_received = self._webhook_ready.wait(timeout=timeout) - - if webhook_received: - return self._webhook_response - return None - - def handle_webhook(self, request_files, request_form, request_data, temp_dir: Path) -> dict: - """Handle incoming webhook from ComfyUI.""" - self._webhook_response = None - self._webhook_ready.clear() - - try: - img_file = None - if 'file' in request_files: - img_file = request_files['file'] - elif 'image' in request_files: - img_file = request_files['image'] - elif request_files: - img_file = list(request_files.values())[0] - - if img_file: - timestamp = str(int(time.time())) - final_mask_path = temp_dir / f"mask_{timestamp}.png" - - img = Image.open(img_file).convert('RGBA') - img.save(str(final_mask_path), format='PNG') - - logger.info(f"[WEBHOOK] Image saved to {final_mask_path}") - self._webhook_response = {'success': True, 'path': final_mask_path} - - elif request_data: - timestamp = str(int(time.time())) - final_mask_path = temp_dir / f"mask_{timestamp}.png" - - with open(final_mask_path, 'wb') as f: - f.write(request_data) - - logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}") - self._webhook_response = {'success': True, 'path': final_mask_path} - - else: - logger.error("[WEBHOOK] No image data in request") - self._webhook_response = {'success': False, 'error': 'No image data received'} - - self._webhook_ready.set() - return self._webhook_response - - except Exception as e: - logger.error(f"[WEBHOOK] Error: {e}") - self._webhook_response = {'success': False, 'error': str(e)} - self._webhook_ready.set() - return self._webhook_response + return [] def prepare_mask_workflow( self, base_image: Image.Image, subject: str, webhook_url: str, + seed: int, + batch_id: str = None, + mask_index: int = 0, polygon_points: list | None = None, polygon_color: str = '#FF0000', polygon_width: int = 2, workflow_template: dict | None = None ) -> dict: """Prepare the mask extraction workflow.""" - workflow = workflow_template.copy() if workflow_template else {} + workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {} img = base_image.copy() @@ -182,6 +185,10 @@ class ComfyUIService: workflow["96"]["inputs"]["webhook_url"] = webhook_url if "50" in workflow and 'inputs' in workflow["50"]: - workflow["50"]["inputs"]["seed"] = random.randint(0, 2**31-1) + workflow["50"]["inputs"]["seed"] = seed + + if "96" in workflow and 'inputs' in workflow["96"]: + metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index) + workflow["96"]["inputs"]["external_uid"] = metadata return workflow diff --git a/tools/ora_editor/templates/components/sidebar.html b/tools/ora_editor/templates/components/sidebar.html index 6e16109..ece41d2 100644 --- a/tools/ora_editor/templates/components/sidebar.html +++ b/tools/ora_editor/templates/components/sidebar.html @@ -110,13 +110,35 @@ Use polygon hint + +
+ + + (different seeds) +
+ diff --git a/tools/ora_editor/templates/editor.html b/tools/ora_editor/templates/editor.html index 9a15677..f5ff262 100644 --- a/tools/ora_editor/templates/editor.html +++ b/tools/ora_editor/templates/editor.html @@ -135,11 +135,14 @@ function oraEditor() { // Mask extraction maskSubject: '', usePolygonHint: true, + maskCount: 3, isExtracting: false, maskViewMode: 'with-bg', showMaskModal: false, tempMaskPath: null, tempMaskUrl: null, + tempMaskPaths: [], + currentMaskIndex: 0, lastError: '', // Settings @@ -586,7 +589,7 @@ function oraEditor() { async extractMask() { if (!this.maskSubject.trim()) return; - console.log('[ORA EDITOR] Extracting mask for:', this.maskSubject); + console.log('[ORA EDITOR] Extracting', this.maskCount, 'masks for:', this.maskSubject); this.isExtracting = true; this.lastError = ''; @@ -598,7 +601,8 @@ function oraEditor() { subject: this.maskSubject, use_polygon: this.usePolygonHint && this.polygonPoints.length >= 3, ora_path: this.oraPath, - comfy_url: this.comfyUrl + comfy_url: this.comfyUrl, + count: this.maskCount }) }); @@ -608,8 +612,10 @@ function oraEditor() { if (!data.success) throw new Error(data.error || 'Failed'); - this.tempMaskPath = data.mask_path; - this.tempMaskUrl = data.mask_url || '/api/file/mask?path=' + encodeURIComponent(data.mask_path); + this.tempMaskPaths = data.mask_paths || [data.mask_path]; + this.currentMaskIndex = 0; + this.tempMaskPath = this.tempMaskPaths[0]; + this.tempMaskUrl = data.mask_urls?.[0] || '/api/file/mask?path=' + encodeURIComponent(this.tempMaskPath); this.showMaskModal = true; } catch (e) { console.error('[ORA EDITOR] Error extracting mask:', e); @@ -618,13 +624,31 @@ function oraEditor() { this.isExtracting = false; } }, + + get currentMaskPath() { + return this.tempMaskPaths[this.currentMaskIndex] || null; + }, - rerollMask() { + previousMask() { + if (this.currentMaskIndex > 0) { + this.currentMaskIndex--; + this.tempMaskPath = this.tempMaskPaths[this.currentMaskIndex]; + } + }, + + nextMask() { + if (this.currentMaskIndex < this.tempMaskPaths.length - 1) { + this.currentMaskIndex++; + this.tempMaskPath = this.tempMaskPaths[this.currentMaskIndex]; + } + }, + + rerollMasks() { this.extractMask(); }, - async useMask() { - if (!this.tempMaskPath) return; + async useCurrentMask() { + if (!this.currentMaskPath) return; let finalEntityName = this.entityName || 'element'; const existingLayers = this.layers.filter(l => l.name.startsWith(finalEntityName + '_')); @@ -634,7 +658,7 @@ function oraEditor() { } finalEntityName = `${finalEntityName}_${counter}`; - console.log('[ORA EDITOR] Adding layer:', finalEntityName); + console.log('[ORA EDITOR] Adding layer:', finalEntityName, 'with mask:', this.currentMaskPath); await fetch('/api/layer/add', { method: 'POST', @@ -642,7 +666,7 @@ function oraEditor() { body: JSON.stringify({ ora_path: this.oraPath, entity_name: finalEntityName, - mask_path: this.tempMaskPath + mask_path: this.currentMaskPath }) }); @@ -663,6 +687,8 @@ function oraEditor() { this.showMaskModal = false; this.tempMaskPath = null; this.tempMaskUrl = null; + this.tempMaskPaths = []; + this.currentMaskIndex = 0; }, // === Krita Integration === diff --git a/tools/ora_editor/templates/modals/mask_preview.html b/tools/ora_editor/templates/modals/mask_preview.html index dc6a76f..fe22010 100644 --- a/tools/ora_editor/templates/modals/mask_preview.html +++ b/tools/ora_editor/templates/modals/mask_preview.html @@ -4,7 +4,12 @@ class="fixed inset-0 z-50 flex items-center justify-center bg-black bg-opacity-75">
-

Extracted Mask

+

+ Extracted Mask + + ( / ) + +

+
+ + + +
+ + + + +
+ + + +
+ +
+ + + +
+
+
+ +