"""SAM3 rough mask generation routes for ORA Editor."""

import io
import json
import logging
import random
import secrets
import string
import time
import zipfile
from pathlib import Path
from urllib.parse import quote

from flask import Blueprint, request, jsonify, make_response

from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
from ora_editor.services.comfyui import ComfyUIService, batch_storage

logger = logging.getLogger(__name__)

sam_bp = Blueprint('sam', __name__)
comfy_service = ComfyUIService(COMFYUI_BASE_URL)

# URL handed to the SAM workflow so ComfyUI can post the generated mask back.
# NOTE(review): host/port are hard-coded to a local server on port 5001 —
# confirm this matches the address the app is actually served on.
SAM_WEBHOOK_URL = "http://localhost:5001/api/webhook/comfyui"

# Batch IDs are 8 characters drawn from [a-z0-9].
_BATCH_ID_ALPHABET = string.ascii_lowercase + string.digits
_BATCH_ID_LENGTH = 8


def generate_batch_id() -> str:
    """Generate a random batch ID.

    Returns an 8-character string of lowercase ASCII letters and digits. The
    ID is passed to the workflow alongside the webhook URL and used to collect
    the result, so it is drawn from ``secrets`` (not ``random``) to make it
    hard to guess.
    """
    return ''.join(secrets.choice(_BATCH_ID_ALPHABET) for _ in range(_BATCH_ID_LENGTH))


def _load_workflow_template(workflow_path: Path) -> dict:
    """Read and parse the SAM workflow JSON template at *workflow_path*."""
    with open(workflow_path, encoding='utf-8') as f:
        return json.load(f)


def _load_base_image(ora_path: str):
    """Return the ORA archive's ``mergedimage.png`` converted to an RGBA PIL image.

    Raises whatever ``zipfile`` / Pillow raise for a missing archive, a missing
    ``mergedimage.png`` entry, or undecodable image data.
    """
    # Imported lazily (as the original code did) so that a missing Pillow
    # install surfaces as a per-request error instead of breaking module
    # import. ``from PIL import Image`` also guarantees the ``PIL.Image``
    # submodule is loaded; ``__import__('PIL').Image`` only worked if some
    # other module had already imported it.
    from PIL import Image

    with zipfile.ZipFile(ora_path, 'r') as zf:
        img_data = zf.read('mergedimage.png')
    return Image.open(io.BytesIO(img_data)).convert('RGBA')


@sam_bp.route('/api/sam/generate', methods=['POST'])
def api_sam_generate():
    """Generate a rough mask using SAM3 with include/exclude points.

    Expects a JSON object body with:
      - ``ora_path`` (required): path to the ORA archive whose
        ``mergedimage.png`` is segmented.
      - ``include_points`` (required): points the mask should cover; passed
        through unchanged to ``prepare_sam_workflow``.
      - ``exclude_points`` (optional, default ``[]``): points to keep out.
      - ``comfy_url`` (optional, default ``COMFYUI_BASE_URL``): ComfyUI
        host to submit the workflow to.

    The request blocks while polling ComfyUI for completion (timeout 120 s)
    and then while waiting for the mask to arrive via webhook (timeout 60 s).

    Returns:
        JSON ``{'success': True, 'mask_path': ..., 'mask_url': ...}`` on
        success; otherwise ``{'success': False, 'error': ...}`` with HTTP 400
        for a malformed request or HTTP 500 for any processing failure.
    """
    # silent=True: a missing, non-JSON or malformed body yields None rather
    # than Flask's HTML 415/400 page, so the client always gets this route's
    # JSON error shape. A non-object body (list, string, ...) is rejected too,
    # since the membership checks below assume a dict.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400

    for field in ('ora_path', 'include_points'):
        if field not in data:
            return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400

    ora_path = data['ora_path']
    include_points = data['include_points']
    exclude_points = data.get('exclude_points', [])
    comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)

    logger.info("[SAM] ORA path: %s", ora_path)
    logger.info("[SAM] Include points: %s", include_points)
    logger.info("[SAM] Exclude points: %s", exclude_points)

    workflow_path = APP_DIR.parent / "mask_rough_cut.json"
    if not workflow_path.exists():
        logger.error("[SAM] Workflow file not found: %s", workflow_path)
        return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500
    try:
        workflow_template = _load_workflow_template(workflow_path)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
        logger.exception("[SAM] Error reading workflow file %s", workflow_path)
        return jsonify({'success': False, 'error': f'Error reading workflow file: {e}'}), 500

    # NOTE(review): ora_path comes straight from the client and is opened
    # as-is; if this endpoint is reachable by untrusted users, restrict it to
    # the project/workspace directory.
    try:
        base_img = _load_base_image(ora_path)
    except Exception as e:  # request boundary: report any load failure as JSON
        logger.exception("[SAM] Error loading base image: %s", e)
        return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500
    logger.info("[SAM] Loaded base image: %s", base_img.size)

    batch_id = generate_batch_id()
    workflow = comfy_service.prepare_sam_workflow(
        base_image=base_img,
        include_points=include_points,
        exclude_points=exclude_points,
        webhook_url=SAM_WEBHOOK_URL,
        batch_id=batch_id,
        workflow_template=workflow_template,
    )

    # Register the batch *before* submitting: the webhook may post the result
    # as soon as ComfyUI runs the output node, and registering afterwards left
    # a window where the upload could arrive for a batch that did not exist yet.
    # NOTE(review): if submission fails the registered batch is left behind;
    # no removal API for batch_storage is visible from this module.
    batch_storage.create_batch(batch_id, 1)

    logger.info("[SAM] Workflow prepared, sending to ComfyUI at http://%s", comfy_url)
    try:
        prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
    except Exception as e:  # request boundary: any submit failure -> JSON 500
        logger.exception("[SAM] Error submitting workflow: %s", e)
        return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500
    logger.info("[SAM] Prompt submitted with ID: %s", prompt_id)

    if not comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=120):
        logger.error("[SAM] Timeout waiting for workflow completion")
        return jsonify({'success': False, 'error': 'SAM generation timed out'}), 500

    logger.info("[SAM] Workflow completed, waiting for webhook...")
    mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=60.0)
    if not mask_paths:
        logger.error("[SAM] No mask received via webhook")
        return jsonify({'success': False, 'error': 'No mask received from SAM'}), 500

    mask_path = mask_paths[0]
    logger.info("[SAM] Mask received: %s", mask_path)

    # Percent-encode the path for the query string (spaces, '&', '#', '%', ...
    # would otherwise corrupt the URL); '/' is left as-is so ordinary POSIX
    # paths produce the same URL as before.
    encoded_path = quote(str(mask_path), safe='/')
    return jsonify({
        'success': True,
        'mask_path': str(mask_path),
        'mask_url': f'/api/file/mask?path={encoded_path}',
    })