- Add roughMaskThumbnailScale state with $watch to sync with main scale slider
- Update sidebar thumbnail to use transform:scale() for consistent zoom between views
- Modify openRoughMaskInNewWindow() to create HTML page with matching scale
- Add denoise strength slider (10-100%) visible only when rough mask exists
- Backend already supports denoise_strength parameter in prepare_mask_workflow_with_start()
- Rough mask auto-clears after successful extraction
- Add Playwright tests for UI changes and API parameter acceptance
254 lines
9.9 KiB
Python
254 lines
9.9 KiB
Python
"""Mask extraction routes for ORA Editor."""
|
|
|
|
import io
|
|
import json
|
|
import random
|
|
import string
|
|
import time
|
|
import zipfile
|
|
from pathlib import Path
|
|
from flask import Blueprint, request, jsonify, make_response
|
|
|
|
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
|
|
from ora_editor.services import polygon_storage
|
|
from ora_editor.services.comfyui import ComfyUIService, batch_storage
|
|
from ora_editor.ora_ops import parse_stack_xml
|
|
|
|
import logging
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Flask blueprint on which the routes in this module are registered
# (via the ``@mask_bp.route`` decorators).
mask_bp = Blueprint('mask', __name__)
|
|
# Shared ComfyUI client, constructed with the configured base URL; the routes
# use it to prepare, submit and poll workflows.
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
|
|
|
|
|
|
def generate_batch_id(length: int = 8) -> str:
    """Generate a random batch ID made of lowercase letters and digits.

    The ID is drawn with the ``random`` module, so it is meant for telling
    batches apart, not as a security token.

    Args:
        length: Number of characters in the ID. Defaults to 8, the length
            that was hard-coded before this parameter existed.

    Returns:
        A string of ``length`` characters from ``a-z`` and ``0-9``.

    Raises:
        ValueError: If ``length`` is less than 1.
    """
    if length < 1:
        raise ValueError("length must be at least 1")
    alphabet = string.ascii_lowercase + string.digits
    return ''.join(random.choices(alphabet, k=length))
|
|
|
|
|
|
def _parse_external_uid(external_uid):
    """Split a ``"<batch_id>:<mask_index>"`` UID into ``(batch_id, mask_index)``.

    Returns ``(None, 0)`` when the UID is empty or contains no colon.

    Raises:
        ValueError: If the part after the first colon is not an integer.
    """
    if not external_uid or ':' not in external_uid:
        return None, 0
    parts = external_uid.split(':')
    return parts[0], int(parts[1])


def _webhook_mask_path(batch_id, mask_index):
    """Return the path under ``TEMP_DIR`` where an incoming mask is written.

    Batched masks are named ``mask_<batch_id>_<mask_index>.png``; masks that
    carry no batch ID are named after the current time in milliseconds.

    Raises:
        ValueError: If the file name built from ``batch_id`` contains a path
            separator and would therefore not stay directly inside
            ``TEMP_DIR``.
    """
    if batch_id:
        filename = f"mask_{batch_id}_{mask_index}.png"
    else:
        filename = f"mask_{int(time.time() * 1000)}.png"

    # batch_id comes straight from the request, so refuse anything that
    # would turn the file name into a multi-component path.
    if Path(filename).name != filename:
        raise ValueError("batch id must not contain path separators")
    return TEMP_DIR / filename


@mask_bp.route('/api/webhook/comfyui', methods=['POST'])
def api_webhook_comfyui():
    """Webhook endpoint for ComfyUI to post completed mask images.

    Reads the optional ``metadata`` and ``external_uid`` form fields. An
    ``external_uid`` of the form ``"<batch_id>:<mask_index>"`` ties the image
    to a batch; the saved file is then registered with ``batch_storage``.

    The image is taken from the uploaded file named ``file``, else ``image``,
    else the first uploaded file; it is converted to RGBA and saved as PNG.
    If no file was uploaded, a non-empty raw request body is written to disk
    unchanged.

    Returns:
        200 with ``{'status': 'ok'}`` once the image is stored,
        400 if ``external_uid`` is malformed or the request has no image data,
        500 with the exception text on any other failure.
    """
    logger.info("[WEBHOOK] Received webhook from ComfyUI")

    try:
        metadata = request.form.get('metadata', '')
        external_uid = request.form.get('external_uid', '')

        logger.info(f"[WEBHOOK] Metadata: {metadata}")
        logger.info(f"[WEBHOOK] External UID: {external_uid}")

        try:
            batch_id, mask_index = _parse_external_uid(external_uid)
            final_mask_path = _webhook_mask_path(batch_id, mask_index)
        except ValueError as e:
            # Malformed client input is a 400, not a server error.
            logger.error(f"[WEBHOOK] Invalid external_uid {external_uid!r}: {e}")
            return make_response(
                jsonify({'status': 'error', 'message': 'Invalid external_uid'}), 400
            )

        img_file = None
        if 'file' in request.files:
            img_file = request.files['file']
        elif 'image' in request.files:
            img_file = request.files['image']
        elif request.files:
            img_file = next(iter(request.files.values()))

        if img_file:
            from PIL import Image

            with Image.open(img_file) as uploaded:
                uploaded.convert('RGBA').save(str(final_mask_path), format='PNG')
            logger.info(f"[WEBHOOK] Image saved to {final_mask_path}")
        elif request.data:
            with open(final_mask_path, 'wb') as f:
                f.write(request.data)
            logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}")
        else:
            logger.error("[WEBHOOK] No image data in request")
            return make_response(
                jsonify({'status': 'error', 'message': 'No image data'}), 400
            )

        if batch_id:
            batch_storage.add_mask(batch_id, mask_index, final_mask_path)

        return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200)

    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logger.exception(f"[WEBHOOK] Error: {e}")
        return make_response(jsonify({'status': 'error', 'message': str(e)}), 500)
|
|
|
|
|
|
def _load_ora_merged_image(ora_path):
    """Return ``mergedimage.png`` from the ORA archive as an RGBA PIL image."""
    from PIL import Image

    with zipfile.ZipFile(ora_path, 'r') as zf:
        img_data = zf.read('mergedimage.png')
    return Image.open(io.BytesIO(img_data)).convert('RGBA')


def _load_start_mask(start_mask_path, size):
    """Return the start mask as an RGBA PIL image, resized to ``size`` if needed."""
    from PIL import Image

    mask = Image.open(start_mask_path).convert('RGBA')
    if mask.size != size:
        mask = mask.resize(size, Image.LANCZOS)
    return mask


@mask_bp.route('/api/mask/extract', methods=['POST'])
def api_mask_extract():
    """Extract one or more masks from the current base image using ComfyUI.

    Expects a JSON object with:
        subject (required): passed to the workflow as ``subject``.
        ora_path (required): ORA archive whose ``mergedimage.png`` is used
            as the base image.
        use_polygon: when true, polygon data stored for ``ora_path`` is
            required (at least 3 points) and passed to the workflow.
        comfy_url: ComfyUI address; defaults to ``COMFYUI_BASE_URL``.
        count: number of masks to request, clamped to 1..10 (default 1).
        start_mask_path: optional image used as a starting mask; selects
            the "with start" workflow and is resized to the base image size.
        denoise_strength: number forwarded to the "with start" workflow
            (default 0.8).

    One workflow is submitted per requested mask, each with its own random
    seed. The route then waits for every prompt to complete and for the
    masks to arrive through the webhook.

    Returns:
        JSON with ``success``, ``batch_id``, ``mask_paths``, ``mask_urls``
        and ``count`` on success; otherwise ``success: False`` with an
        ``error`` message and status 400 (bad input) or 500.
    """
    from urllib.parse import quote

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400

    for field in ('subject', 'ora_path'):
        if field not in data:
            return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400

    subject = data['subject']
    use_polygon = data.get('use_polygon', False)
    # NOTE(review): ora_path and start_mask_path are client-supplied
    # filesystem paths and are opened as given — confirm that is intended.
    ora_path = data['ora_path']
    comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)
    start_mask_path = data.get('start_mask_path', None)

    raw_count = data.get('count', 1)
    try:
        count = min(max(int(1 if raw_count is None else raw_count), 1), 10)
    except (TypeError, ValueError, OverflowError):
        return jsonify({'success': False, 'error': 'count must be an integer'}), 400

    raw_denoise = data.get('denoise_strength', 0.8)
    try:
        denoise_strength = float(0.8 if raw_denoise is None else raw_denoise)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'error': 'denoise_strength must be a number'}), 400

    logger.info(f"[MASK EXTRACT] Subject: {subject}")
    logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}")
    logger.info(f"[MASK EXTRACT] ORA path: {ora_path}")
    logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}")
    logger.info(f"[MASK EXTRACT] Count: {count}")
    logger.info(f"[MASK EXTRACT] Start mask path: {start_mask_path}")

    if start_mask_path:
        workflow_path = APP_DIR.parent / "image_mask_extraction_with_start.json"
    else:
        workflow_path = APP_DIR.parent / "image_mask_extraction.json"

    if not workflow_path.exists():
        logger.error(f"[MASK EXTRACT] Workflow file not found: {workflow_path}")
        return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500

    with open(workflow_path, encoding='utf-8') as f:
        workflow_template = json.load(f)

    logger.info("[MASK EXTRACT] Loaded workflow")

    try:
        base_img = _load_ora_merged_image(ora_path)
        logger.info(f"[MASK EXTRACT] Loaded base image: {base_img.size}")
    except Exception as e:
        logger.error(f"[MASK EXTRACT] Error loading base image: {e}")
        return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500

    start_mask_img = None
    if start_mask_path:
        try:
            start_mask_img = _load_start_mask(start_mask_path, base_img.size)
            logger.info(f"[MASK EXTRACT] Loaded start mask: {start_mask_img.size}")
        except Exception as e:
            logger.error(f"[MASK EXTRACT] Error loading start mask: {e}")
            return jsonify({'success': False, 'error': f'Error loading start mask: {str(e)}'}), 500

    polygon_points = None
    polygon_color = '#FF0000'
    polygon_width = 2

    if use_polygon:
        poly_data = polygon_storage.get(ora_path)
        if poly_data:
            polygon_points = poly_data.get('points', [])
            polygon_color = poly_data.get('color', '#FF0000')
            polygon_width = poly_data.get('width', 2)

        if not polygon_points or len(polygon_points) < 3:
            logger.warning("[MASK EXTRACT] Use polygon requested but no valid polygon points stored")
            return jsonify({'success': False, 'error': 'No valid polygon points. Please draw polygon first.'}), 400

    batch_id = generate_batch_id()
    batch_storage.create_batch(batch_id, count)

    webhook_url = "http://localhost:5001/api/webhook/comfyui"

    # One seed per requested mask, all drawn up front.
    seeds = [random.randint(0, 2**31 - 1) for _ in range(count)]
    prompt_ids = []

    for i, seed in enumerate(seeds):
        if start_mask_img is not None:
            workflow = comfy_service.prepare_mask_workflow_with_start(
                base_image=base_img,
                start_mask_image=start_mask_img,
                subject=subject,
                webhook_url=webhook_url,
                seed=seed,
                batch_id=batch_id,
                mask_index=i,
                polygon_points=polygon_points,
                polygon_color=polygon_color,
                polygon_width=polygon_width,
                denoise_strength=denoise_strength,
                workflow_template=workflow_template
            )
        else:
            workflow = comfy_service.prepare_mask_workflow(
                base_image=base_img,
                subject=subject,
                webhook_url=webhook_url,
                seed=seed,
                batch_id=batch_id,
                mask_index=i,
                polygon_points=polygon_points,
                polygon_color=polygon_color,
                polygon_width=polygon_width,
                workflow_template=workflow_template
            )

        logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}")

        try:
            prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
            prompt_ids.append(prompt_id)
            logger.info(f"[MASK EXTRACT] Prompt {i} submitted with ID: {prompt_id}, seed: {seed}")
        except Exception as e:
            logger.error(f"[MASK EXTRACT] Error submitting workflow {i}: {e}")
            return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500

    # Poll every prompt (even after one times out) so each timeout is logged.
    all_completed = True
    for prompt_id in prompt_ids:
        completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240)
        if not completed:
            all_completed = False
            logger.warning(f"[MASK EXTRACT] Prompt {prompt_id} did not complete in time")

    if not all_completed:
        logger.error("[MASK EXTRACT] Timeout waiting for some ComfyUI workflows to complete")
        return jsonify({'success': False, 'error': 'Mask extraction timed out'}), 500

    logger.info("[MASK EXTRACT] All workflows completed, waiting for webhooks...")

    mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=120.0)

    if not mask_paths:
        logger.error("[MASK EXTRACT] No masks received via webhook")
        return jsonify({'success': False, 'error': 'No masks received from ComfyUI'}), 500

    logger.info(f"[MASK EXTRACT] Received {len(mask_paths)} masks")

    return jsonify({
        'success': True,
        'batch_id': batch_id,
        'mask_paths': [str(p) for p in mask_paths],
        # Percent-encode so spaces, '&', '#', '+' etc. in a path survive
        # the query string.
        'mask_urls': [f'/api/file/mask?path={quote(str(p))}' for p in mask_paths],
        'count': len(mask_paths)
    })
|