Add multi-mask extraction with count selector and navigation

- Add count selector (1-10) for generating multiple mask variations
- Each mask gets a unique random seed
- Add left/right arrow navigation in mask preview modal when multiple masks exist
- Batch storage system for tracking multiple concurrent extractions
- Webhook handler now uses batch_id:mask_index for routing responses
This commit is contained in:
2026-03-27 21:36:20 -07:00
parent fb812e57bc
commit c8932fdbf8
5 changed files with 306 additions and 142 deletions

View File

@@ -2,6 +2,8 @@
import io
import json
import random
import string
import time
import zipfile
from pathlib import Path
@@ -9,7 +11,7 @@ from flask import Blueprint, request, jsonify, make_response
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
from ora_editor.services import polygon_storage
from ora_editor.services.comfyui import ComfyUIService
from ora_editor.services.comfyui import ComfyUIService, batch_storage
from ora_editor.ora_ops import parse_stack_xml
import logging
@@ -19,31 +21,88 @@ mask_bp = Blueprint('mask', __name__)
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
def generate_batch_id() -> str:
"""Generate a unique batch ID."""
return ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
@mask_bp.route('/api/webhook/comfyui', methods=['POST'])
def api_webhook_comfyui():
"""Webhook endpoint for ComfyUI to post completed mask images."""
logger.info("[WEBHOOK] Received webhook from ComfyUI")
result = comfy_service.handle_webhook(
request.files,
request.form if request.form else {},
request.data,
TEMP_DIR
)
if result.get('success'):
response = make_response(jsonify({'status': 'ok', 'message': 'Image received'}))
response.status_code = 200
else:
response = make_response(jsonify({'status': 'error', 'message': result.get('error', 'Unknown error')}))
response.status_code = 500
return response
try:
metadata = request.form.get('metadata', '') if request.form else ''
external_uid = request.form.get('external_uid', '') if request.form else ''
logger.info(f"[WEBHOOK] Metadata: {metadata}")
logger.info(f"[WEBHOOK] External UID: {external_uid}")
batch_id = None
mask_index = 0
if external_uid and ':' in external_uid:
parts = external_uid.split(':')
batch_id = parts[0]
mask_index = int(parts[1]) if len(parts) > 1 else 0
img_file = None
if 'file' in request.files:
img_file = request.files['file']
elif 'image' in request.files:
img_file = request.files['image']
elif request.files:
img_file = list(request.files.values())[0]
if img_file:
timestamp = str(int(time.time() * 1000))
if batch_id:
final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png"
else:
final_mask_path = TEMP_DIR / f"mask_{timestamp}.png"
from PIL import Image
img = Image.open(img_file).convert('RGBA')
img.save(str(final_mask_path), format='PNG')
logger.info(f"[WEBHOOK] Image saved to {final_mask_path}")
if batch_id:
batch_storage.add_mask(batch_id, mask_index, final_mask_path)
return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200)
elif request.data:
timestamp = str(int(time.time() * 1000))
if batch_id:
final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png"
else:
final_mask_path = TEMP_DIR / f"mask_{timestamp}.png"
with open(final_mask_path, 'wb') as f:
f.write(request.data)
logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}")
if batch_id:
batch_storage.add_mask(batch_id, mask_index, final_mask_path)
return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200)
else:
logger.error("[WEBHOOK] No image data in request")
return make_response(jsonify({'status': 'error', 'message': 'No image data'}), 400)
except Exception as e:
logger.error(f"[WEBHOOK] Error: {e}")
import traceback
traceback.print_exc()
return make_response(jsonify({'status': 'error', 'message': str(e)}), 500)
@mask_bp.route('/api/mask/extract', methods=['POST'])
def api_mask_extract():
"""Extract a mask from the current base image using ComfyUI."""
"""Extract one or more masks from the current base image using ComfyUI."""
data = request.get_json()
required = ['subject', 'ora_path']
@@ -55,11 +114,13 @@ def api_mask_extract():
use_polygon = data.get('use_polygon', False)
ora_path = data['ora_path']
comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)
count = min(max(data.get('count', 1), 1), 10)
logger.info(f"[MASK EXTRACT] Subject: {subject}")
logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}")
logger.info(f"[MASK EXTRACT] ORA path: {ora_path}")
logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}")
logger.info(f"[MASK EXTRACT] Count: {count}")
workflow_path = APP_DIR.parent / "image_mask_extraction.json"
if not workflow_path.exists():
@@ -96,66 +157,63 @@ def api_mask_extract():
logger.warning(f"[MASK EXTRACT] Use polygon requested but no valid polygon points stored")
return jsonify({'success': False, 'error': 'No valid polygon points. Please draw polygon first.'}), 400
batch_id = generate_batch_id()
batch_storage.create_batch(batch_id, count)
webhook_url = f"http://localhost:5001/api/webhook/comfyui"
workflow = comfy_service.prepare_mask_workflow(
base_image=base_img,
subject=subject,
webhook_url=webhook_url,
polygon_points=polygon_points,
polygon_color=polygon_color,
polygon_width=polygon_width,
workflow_template=workflow_template
)
seeds = [random.randint(0, 2**31-1) for _ in range(count)]
prompt_ids = []
logger.info(f"[MASK EXTRACT] Workflow prepared, sending to ComfyUI at http://{comfy_url}")
for i in range(count):
workflow = comfy_service.prepare_mask_workflow(
base_image=base_img,
subject=subject,
webhook_url=webhook_url,
seed=seeds[i],
batch_id=batch_id,
mask_index=i,
polygon_points=polygon_points,
polygon_color=polygon_color,
polygon_width=polygon_width,
workflow_template=workflow_template
)
logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}")
try:
prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
prompt_ids.append(prompt_id)
logger.info(f"[MASK EXTRACT] Prompt {i} submitted with ID: {prompt_id}, seed: {seeds[i]}")
except Exception as e:
logger.error(f"[MASK EXTRACT] Error submitting workflow {i}: {e}")
return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500
try:
prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
logger.info(f"[MASK EXTRACT] Prompt submitted with ID: {prompt_id}")
except Exception as e:
logger.error(f"[MASK EXTRACT] Error submitting to ComfyUI: {e}")
return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500
all_completed = True
for prompt_id in prompt_ids:
completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240)
if not completed:
all_completed = False
logger.warning(f"[MASK EXTRACT] Prompt {prompt_id} did not complete in time")
completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240)
if not completed:
logger.error("[MASK EXTRACT] Timeout waiting for ComfyUI to complete")
if not all_completed:
logger.error("[MASK EXTRACT] Timeout waiting for some ComfyUI workflows to complete")
return jsonify({'success': False, 'error': 'Mask extraction timed out'}), 500
logger.info(f"[MASK EXTRACT] Checking/waiting for webhook callback from ComfyUI...")
logger.info(f"[MASK EXTRACT] All workflows completed, waiting for webhooks...")
webhook_result = comfy_service.wait_for_webhook(timeout=60.0)
mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=120.0)
if not webhook_result:
logger.error(f"[MASK EXTRACT] Timeout waiting for webhook callback")
return jsonify({'success': False, 'error': 'Webhook timeout - mask extraction may have failed'}), 500
if not mask_paths:
logger.error(f"[MASK EXTRACT] No masks received via webhook")
return jsonify({'success': False, 'error': 'No masks received from ComfyUI'}), 500
logger.info(f"[MASK EXTRACT] Webhook received: {webhook_result}")
logger.info(f"[MASK EXTRACT] Received {len(mask_paths)} masks")
if not webhook_result.get('success'):
error_msg = webhook_result.get('error', 'Unknown error')
logger.error(f"[MASK EXTRACT] Webhook failed: {error_msg}")
return jsonify({'success': False, 'error': f'Webhook error: {error_msg}'}), 500
final_mask_path = webhook_result.get('path')
if not final_mask_path:
logger.error("[MASK EXTRACT] No mask path in webhook response")
return jsonify({'success': False, 'error': 'No mask path returned from webhook'}), 500
try:
if not Path(final_mask_path).exists():
logger.error(f"[MASK EXTRACT] Mask file not found at {final_mask_path}")
return jsonify({'success': False, 'error': f'Mask file not found: {final_mask_path}'}), 500
logger.info(f"[MASK EXTRACT] Mask received via webhook at {final_mask_path}")
return jsonify({
'success': True,
'mask_path': str(final_mask_path),
'mask_url': f'/api/file/mask?path={final_mask_path}'
})
except Exception as e:
logger.error(f"[MASK EXTRACT] Error processing mask file: {e}")
return jsonify({'success': False, 'error': f'Error accessing mask: {str(e)}'}), 500
return jsonify({
'success': True,
'batch_id': batch_id,
'mask_paths': [str(p) for p in mask_paths],
'mask_urls': [f'/api/file/mask?path={p}' for p in mask_paths],
'count': len(mask_paths)
})