Files
ai-game-2/tools/ora_editor/routes/sam.py
2026-03-27 23:33:04 -07:00

108 lines
3.8 KiB
Python

"""SAM3 rough mask generation routes for ORA Editor."""
import io
import json
import random
import string
import time
import zipfile
from pathlib import Path
from flask import Blueprint, request, jsonify, make_response
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
from ora_editor.services.comfyui import ComfyUIService, batch_storage
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Flask blueprint on which the SAM routes in this module are registered.
sam_bp = Blueprint('sam', __name__)
# ComfyUI client built from the configured base URL; the route handlers use it
# to prepare, submit and poll workflows.
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
def generate_batch_id() -> str:
    """Return a random 8-character ID made of lowercase letters and digits.

    The ID is drawn with ``random.choices``; it is random, not guaranteed
    to be unique.
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = random.choices(alphabet, k=8)
    return ''.join(picked)
@sam_bp.route('/api/sam/generate', methods=['POST'])
def api_sam_generate():
    """Generate a rough mask using SAM3 with include/exclude points.

    Expects a JSON object in the request body with:
        ora_path: path of the ORA archive; its ``mergedimage.png`` entry is
            used as the base image.
        include_points: points forwarded to the workflow as include points.
        exclude_points: optional, defaults to ``[]``.
        comfy_url: optional, defaults to ``COMFYUI_BASE_URL``.

    The handler blocks while it waits: up to 120 s for the ComfyUI prompt to
    finish, then up to 60 s for the mask to arrive through the webhook.

    Returns:
        On success, JSON ``{'success': True, 'mask_path': ..., 'mask_url': ...}``.
        On failure, JSON ``{'success': False, 'error': ...}`` with status 400
        (malformed request) or 500 (any later failure).
    """
    # silent=True turns a missing/invalid JSON body into None instead of an
    # HTML error page, so every failure of this endpoint is a JSON response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400

    for field in ('ora_path', 'include_points'):
        if field not in data:
            return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400

    ora_path = data['ora_path']
    include_points = data['include_points']
    exclude_points = data.get('exclude_points', [])
    # NOTE(review): ora_path and comfy_url come straight from the client and
    # are used as a filesystem path and a network target; confirm this
    # endpoint is only reachable by trusted callers.
    comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)

    logger.info("[SAM] ORA path: %s", ora_path)
    logger.info("[SAM] Include points: %s", include_points)
    logger.info("[SAM] Exclude points: %s", exclude_points)

    # Load the workflow template; open directly and handle the failure rather
    # than checking for existence first.
    workflow_path = APP_DIR.parent / "mask_rough_cut.json"
    try:
        with open(workflow_path, encoding="utf-8") as f:
            workflow_template = json.load(f)
    except FileNotFoundError:
        logger.error("[SAM] Workflow file not found: %s", workflow_path)
        return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error("[SAM] Error reading workflow file %s: %s", workflow_path, e)
        return jsonify({'success': False, 'error': f'Error reading workflow file: {str(e)}'}), 500

    # Read the flattened image out of the ORA (zip) archive.
    try:
        # Imported here, inside the try, so a missing Pillow install is
        # reported through the same JSON error path as an unreadable archive.
        # `from PIL import Image` loads the submodule explicitly; attribute
        # access on a bare `import PIL` does not.
        from PIL import Image

        with zipfile.ZipFile(ora_path, 'r') as zf:
            img_data = zf.read('mergedimage.png')
        base_img = Image.open(io.BytesIO(img_data)).convert('RGBA')
        logger.info("[SAM] Loaded base image: %s", base_img.size)
    except Exception as e:
        logger.error("[SAM] Error loading base image: %s", e)
        return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500

    batch_id = generate_batch_id()
    webhook_url = "http://localhost:5001/api/webhook/comfyui"
    workflow = comfy_service.prepare_sam_workflow(
        base_image=base_img,
        include_points=include_points,
        exclude_points=exclude_points,
        webhook_url=webhook_url,
        batch_id=batch_id,
        workflow_template=workflow_template,
    )

    logger.info("[SAM] Workflow prepared, sending to ComfyUI at http://%s", comfy_url)
    try:
        prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
        logger.info("[SAM] Prompt submitted with ID: %s", prompt_id)
    except Exception as e:
        logger.error("[SAM] Error submitting workflow: %s", e)
        return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500

    # NOTE(review): the batch is registered only after submission; if the
    # webhook could fire before this line the result might be missed.
    # Confirm against batch_storage / webhook handling.
    batch_storage.create_batch(batch_id, 1)

    completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=120)
    if not completed:
        logger.error("[SAM] Timeout waiting for workflow completion")
        return jsonify({'success': False, 'error': 'SAM generation timed out'}), 500

    logger.info("[SAM] Workflow completed, waiting for webhook...")
    mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=60.0)
    if not mask_paths:
        logger.error("[SAM] No mask received via webhook")
        return jsonify({'success': False, 'error': 'No mask received from SAM'}), 500

    mask_path = mask_paths[0]
    logger.info("[SAM] Mask received: %s", mask_path)

    # Percent-encode the path so spaces, '&', '#', '+' etc. survive as a
    # single query-string value; '/' is left readable.
    from urllib.parse import quote

    encoded_path = quote(str(mask_path), safe='/')
    return jsonify({
        'success': True,
        'mask_path': str(mask_path),
        'mask_url': f'/api/file/mask?path={encoded_path}',
    })