Files
ai-game-2/tools/ora_editor/routes/mask.py
2026-04-07 10:50:55 -07:00

255 lines
9.9 KiB
Python

"""Mask extraction routes for ORA Editor."""
import io
import json
import random
import string
import time
import zipfile
from pathlib import Path
from flask import Blueprint, request, jsonify, make_response
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
from ora_editor.services import polygon_storage
from ora_editor.services.comfyui import ComfyUIService, batch_storage
from ora_editor.ora_ops import parse_stack_xml
import logging
# Module logger; route handlers tag their messages with "[WEBHOOK]" / "[MASK EXTRACT]".
logger = logging.getLogger(name=__name__)

# Blueprint that groups this module's mask-extraction routes.
mask_bp = Blueprint(name='mask', import_name=__name__)

# Shared ComfyUI client created with the configured default base URL
# (individual requests may still pass their own ComfyUI URL to it).
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
def generate_batch_id(length: int = 8) -> str:
    """Generate a unique batch ID of lowercase letters and digits.

    The ID is embedded in the ``external_uid`` that ComfyUI echoes back to the
    unauthenticated webhook and in temporary mask file names, so it is drawn
    from :mod:`secrets` rather than :mod:`random` to make it hard to guess.

    Args:
        length: Number of characters in the ID (default 8).

    Returns:
        A string of ``length`` characters from ``[a-z0-9]``.
    """
    import secrets

    alphabet = string.ascii_lowercase + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
@mask_bp.route('/api/webhook/comfyui', methods=['POST'])
def api_webhook_comfyui():
    """Webhook endpoint for ComfyUI to post completed mask images.

    The image is taken from the first available upload field (``file``, then
    ``image``, then any other file field) or, failing that, from the raw
    request body. Either way it is decoded with PIL and saved to ``TEMP_DIR``
    as an RGBA PNG.

    An optional ``external_uid`` form field of the form
    ``"<batch_id>:<mask_index>"`` routes the image into a pending batch via
    ``batch_storage.add_mask``. Without it, the image is saved under a
    millisecond-timestamp name only.

    Returns:
        200 ``{'status': 'ok'}`` when the image was saved; 400 when the request
        carries no image or a malformed ``external_uid``; 500 on any other
        failure (message is the exception text).
    """
    import re

    logger.info("[WEBHOOK] Received webhook from ComfyUI")
    try:
        from PIL import Image

        # request.form is always a (possibly empty) MultiDict, so .get() is safe.
        metadata = request.form.get('metadata', '')
        external_uid = request.form.get('external_uid', '')
        logger.info(f"[WEBHOOK] Metadata: {metadata}")
        logger.info(f"[WEBHOOK] External UID: {external_uid}")

        batch_id = None
        mask_index = 0
        if external_uid and ':' in external_uid:
            parts = external_uid.split(':')
            # An empty batch part (":3") means "no batch", as before.
            batch_id = parts[0] or None
            try:
                mask_index = int(parts[1])
            except ValueError:
                mask_index = -1
            # batch_id comes from an unauthenticated request and is spliced
            # into a file name, so only allow characters that cannot escape
            # TEMP_DIR (generate_batch_id only produces [a-z0-9]).
            bad_batch = batch_id is not None and not re.fullmatch(r'[A-Za-z0-9_-]+', batch_id)
            if mask_index < 0 or bad_batch:
                logger.error(f"[WEBHOOK] Malformed external_uid: {external_uid!r}")
                return make_response(jsonify({'status': 'error', 'message': 'Malformed external_uid'}), 400)

        # Compare with None explicitly: werkzeug's FileStorage is falsy when
        # the upload has no filename, which would wrongly discard the image.
        img_file = request.files.get('file')
        if img_file is None:
            img_file = request.files.get('image')
        if img_file is None and request.files:
            img_file = next(iter(request.files.values()))

        if img_file is not None:
            source = img_file
            saved_label = "Image saved to"
        elif request.data:
            source = io.BytesIO(request.data)
            saved_label = "Raw data saved to"
        else:
            logger.error("[WEBHOOK] No image data in request")
            return make_response(jsonify({'status': 'error', 'message': 'No image data'}), 400)

        if batch_id:
            final_mask_path = TEMP_DIR / f"mask_{batch_id}_{mask_index}.png"
        else:
            timestamp = str(int(time.time() * 1000))
            final_mask_path = TEMP_DIR / f"mask_{timestamp}.png"

        # Decode both upload styles through PIL so the file on disk is always
        # a real RGBA PNG, matching its .png name.
        with Image.open(source) as img:
            img.convert('RGBA').save(str(final_mask_path), format='PNG')
        logger.info(f"[WEBHOOK] {saved_label} {final_mask_path}")

        if batch_id:
            batch_storage.add_mask(batch_id, mask_index, final_mask_path)
        return make_response(jsonify({'status': 'ok', 'message': 'Image received'}), 200)
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logger.exception(f"[WEBHOOK] Error: {e}")
        return make_response(jsonify({'status': 'error', 'message': str(e)}), 500)
def _mask_extract_load_base_image(ora_path):
    """Return the ``mergedimage.png`` entry of an ORA (zip) archive as RGBA.

    Errors from ``zipfile``/PIL (missing file, bad archive, missing entry,
    undecodable image) propagate to the caller.
    """
    from PIL import Image

    with zipfile.ZipFile(ora_path, 'r') as zf:
        img_data = zf.read('mergedimage.png')
    return Image.open(io.BytesIO(img_data)).convert('RGBA')


def _mask_extract_load_start_mask(start_mask_path, size):
    """Load a starting mask as RGBA, LANCZOS-resized to *size* when it differs."""
    from PIL import Image

    with Image.open(start_mask_path) as img:
        mask = img.convert('RGBA')
    if mask.size != size:
        mask = mask.resize(size, Image.LANCZOS)
    return mask


@mask_bp.route('/api/mask/extract', methods=['POST'])
def api_mask_extract():
    """Extract one or more masks from the current base image using ComfyUI.

    JSON body:
        subject (required): text describing what to mask; ``' outlined in red'``
            is appended when ``use_polygon`` is set.
        ora_path (required): ORA archive whose ``mergedimage.png`` is the base.
        use_polygon: pass the polygon stored in ``polygon_storage`` for
            ``ora_path`` to the workflow (default False).
        comfy_url: ComfyUI address (default ``COMFYUI_BASE_URL``).
        count: number of masks to generate, clamped to 1..10 (default 1).
        start_mask_path: optional mask image; selects the
            ``image_mask_extraction_with_start.json`` workflow.
        denoise_strength: only used with a start mask (default 0.8).

    One ComfyUI prompt is submitted per mask, each with its own random seed.
    The route then polls until all prompts complete (240 s each) and until
    ``comfy_service.poll_for_batch_completion`` reports the batch's masks
    (120 s), which the webhook route stores via ``batch_storage``.

    Returns:
        On success, JSON with ``batch_id``, ``mask_paths``, ``mask_urls`` and
        ``count``. Otherwise ``{'success': False, 'error': ...}`` with 400 for
        bad input or 500 for workflow/image/ComfyUI failures.
    """
    from urllib.parse import quote

    # silent=True: a non-JSON or malformed body yields None instead of an
    # HTML error page, so the client always gets a JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400
    for field in ('subject', 'ora_path'):
        if field not in data:
            return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400

    subject = data['subject']
    use_polygon = data.get('use_polygon', False)
    ora_path = data['ora_path']
    comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)
    try:
        # Coerce before clamping: a float/str/null count would otherwise crash range().
        count = min(max(int(data.get('count', 1)), 1), 10)
    except (TypeError, ValueError, OverflowError):
        return jsonify({'success': False, 'error': 'count must be an integer'}), 400
    start_mask_path = data.get('start_mask_path', None)
    denoise_strength = data.get('denoise_strength', 0.8)
    logger.info(f"[MASK EXTRACT] Subject: {subject}")
    logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}")
    logger.info(f"[MASK EXTRACT] ORA path: {ora_path}")
    logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}")
    logger.info(f"[MASK EXTRACT] Count: {count}")
    logger.info(f"[MASK EXTRACT] Start mask path: {start_mask_path}")

    if start_mask_path:
        workflow_path = APP_DIR.parent / "image_mask_extraction_with_start.json"
    else:
        workflow_path = APP_DIR.parent / "image_mask_extraction.json"
    if not workflow_path.exists():
        logger.error(f"[MASK EXTRACT] Workflow file not found: {workflow_path}")
        return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(workflow_path, encoding='utf-8') as f:
            workflow_template = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"[MASK EXTRACT] Error loading workflow {workflow_path}: {e}")
        return jsonify({'success': False, 'error': f'Error loading workflow: {str(e)}'}), 500
    logger.info("[MASK EXTRACT] Loaded workflow")

    try:
        base_img = _mask_extract_load_base_image(ora_path)
    except Exception as e:
        logger.error(f"[MASK EXTRACT] Error loading base image: {e}")
        return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500
    logger.info(f"[MASK EXTRACT] Loaded base image: {base_img.size}")

    start_mask_img = None
    if start_mask_path:
        try:
            start_mask_img = _mask_extract_load_start_mask(start_mask_path, base_img.size)
        except Exception as e:
            logger.error(f"[MASK EXTRACT] Error loading start mask: {e}")
            return jsonify({'success': False, 'error': f'Error loading start mask: {str(e)}'}), 500
        logger.info(f"[MASK EXTRACT] Loaded start mask: {start_mask_img.size}")

    polygon_points = None
    polygon_color = '#FF0000'
    polygon_width = 2
    if use_polygon:
        subject = subject + ' outlined in red'
        poly_data = polygon_storage.get(ora_path)
        if poly_data:
            polygon_points = poly_data.get('points', [])
            polygon_color = poly_data.get('color', '#FF0000')
            polygon_width = poly_data.get('width', 2)
        if not polygon_points or len(polygon_points) < 3:
            logger.warning("[MASK EXTRACT] Use polygon requested but no valid polygon points stored")
            return jsonify({'success': False, 'error': 'No valid polygon points. Please draw polygon first.'}), 400

    batch_id = generate_batch_id()
    batch_storage.create_batch(batch_id, count)
    # NOTE(review): hard-coded host/port; ComfyUI can only reach this when it
    # runs on the same machine and the editor listens on port 5001.
    webhook_url = "http://localhost:5001/api/webhook/comfyui"
    seeds = [random.randint(0, 2**31 - 1) for _ in range(count)]

    prompt_ids = []
    for i, seed in enumerate(seeds):
        # Arguments shared by both workflow builders.
        workflow_kwargs = {
            'base_image': base_img,
            'subject': subject,
            'webhook_url': webhook_url,
            'seed': seed,
            'batch_id': batch_id,
            'mask_index': i,
            'polygon_points': polygon_points,
            'polygon_color': polygon_color,
            'polygon_width': polygon_width,
            'workflow_template': workflow_template,
        }
        if start_mask_img is not None:
            workflow = comfy_service.prepare_mask_workflow_with_start(
                start_mask_image=start_mask_img,
                denoise_strength=denoise_strength,
                **workflow_kwargs,
            )
        else:
            workflow = comfy_service.prepare_mask_workflow(**workflow_kwargs)
        logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}")
        try:
            prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
        except Exception as e:
            # NOTE(review): prompts submitted earlier in this loop are not
            # cancelled and the batch is not removed from batch_storage.
            logger.error(f"[MASK EXTRACT] Error submitting workflow {i}: {e}")
            return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500
        prompt_ids.append(prompt_id)
        logger.info(f"[MASK EXTRACT] Prompt {i} submitted with ID: {prompt_id}, seed: {seed}")

    all_completed = True
    for prompt_id in prompt_ids:
        if not comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=240):
            all_completed = False
            logger.warning(f"[MASK EXTRACT] Prompt {prompt_id} did not complete in time")
    if not all_completed:
        logger.error("[MASK EXTRACT] Timeout waiting for some ComfyUI workflows to complete")
        return jsonify({'success': False, 'error': 'Mask extraction timed out'}), 500

    logger.info("[MASK EXTRACT] All workflows completed, waiting for webhooks...")
    mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=120.0)
    if not mask_paths:
        logger.error("[MASK EXTRACT] No masks received via webhook")
        return jsonify({'success': False, 'error': 'No masks received from ComfyUI'}), 500
    logger.info(f"[MASK EXTRACT] Received {len(mask_paths)} masks")
    return jsonify({
        'success': True,
        'batch_id': batch_id,
        'mask_paths': [str(p) for p in mask_paths],
        # Percent-encode the path so spaces, '&', '#', etc. survive the query string.
        'mask_urls': [f'/api/file/mask?path={quote(str(p))}' for p in mask_paths],
        'count': len(mask_paths),
    })