108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
"""SAM3 rough mask generation routes for ORA Editor."""
|
|
|
|
import io
|
|
import json
|
|
import random
|
|
import string
|
|
import time
|
|
import zipfile
|
|
from pathlib import Path
|
|
from flask import Blueprint, request, jsonify, make_response
|
|
|
|
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
|
|
from ora_editor.services.comfyui import ComfyUIService, batch_storage
|
|
|
|
import logging
|
|
# Logger named after this module; every message below is tagged "[SAM]".
logger = logging.getLogger(name=__name__)

# Flask blueprint that owns the SAM routes defined in this module.
sam_bp = Blueprint(name='sam', import_name=__name__)

# ComfyUI client built once at import time with the configured default URL.
# Individual requests may still target a different server via ``comfy_url``.
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
|
|
|
|
|
|
def generate_batch_id() -> str:
    """Return a random 8-character batch identifier.

    Characters are drawn with replacement from lowercase ASCII letters and
    digits via the ``random`` module. Collisions are unlikely but not
    impossible, and the value is not cryptographically secure.
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = random.choices(alphabet, k=8)
    return "".join(picked)
|
|
|
|
|
|
def _json_error(message, status=500):
    """Return the ``({'success': False, 'error': message}, status)`` response used by the SAM route."""
    return jsonify({'success': False, 'error': message}), status


@sam_bp.route('/api/sam/generate', methods=['POST'])
def api_sam_generate():
    """Generate a rough mask using SAM3 with include/exclude points.

    Expects a JSON object body with:
        ora_path (required): path to an ORA (zip) archive; its
            ``mergedimage.png`` member is the image sent to SAM.
        include_points (required): forwarded to the workflow as include points.
        exclude_points (optional, default ``[]``): forwarded as exclude points.
        comfy_url (optional, default ``COMFYUI_BASE_URL``): ComfyUI address
            used for submission and completion polling.

    Returns:
        200 with ``{'success': True, 'mask_path': ..., 'mask_url': ...}``.
        400 with ``{'success': False, 'error': ...}`` when the body is not a
        JSON object or a required field is missing.
        500 with ``{'success': False, 'error': ...}`` for every other failure:
        missing or unreadable workflow file, unreadable image, submission
        error, timeout, or no mask delivered via the webhook.
    """
    from urllib.parse import quote

    # silent=True: a missing or malformed body yields None instead of raising,
    # so the client always gets this route's JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _json_error('Request body must be a JSON object', 400)

    for field in ('ora_path', 'include_points'):
        if field not in data:
            return _json_error(f'Missing {field} parameter', 400)

    ora_path = data['ora_path']
    include_points = data['include_points']
    exclude_points = data.get('exclude_points', [])
    # NOTE(review): ora_path (local file read) and comfy_url (outbound requests)
    # come straight from the client and are not validated. Acceptable for a
    # local-only tool; restrict them if this endpoint is ever exposed.
    comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)

    logger.info("[SAM] ORA path: %s", ora_path)
    logger.info("[SAM] Include points: %s", include_points)
    logger.info("[SAM] Exclude points: %s", exclude_points)

    workflow_path = APP_DIR.parent / "mask_rough_cut.json"
    if not workflow_path.exists():
        logger.error("[SAM] Workflow file not found: %s", workflow_path)
        return _json_error(f'Workflow file not found: {workflow_path}')

    try:
        with open(workflow_path, encoding='utf-8') as f:
            workflow_template = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
        logger.exception("[SAM] Error reading workflow file %s", workflow_path)
        return _json_error(f'Error reading workflow file: {e}')

    try:
        # Import the submodule explicitly: ``__import__('PIL').Image`` only
        # works if something else already imported PIL.Image.
        from PIL import Image

        with zipfile.ZipFile(ora_path, 'r') as zf:
            img_data = zf.read('mergedimage.png')
        base_img = Image.open(io.BytesIO(img_data)).convert('RGBA')
    except Exception as e:  # bad path, not a zip, missing member, undecodable image, ...
        logger.exception("[SAM] Error loading base image: %s", e)
        return _json_error(f'Error loading image: {e}')
    logger.info("[SAM] Loaded base image: %s", base_img.size)

    batch_id = generate_batch_id()
    # NOTE(review): hard-coded. This assumes the app listens on localhost:5001
    # and that ComfyUI can reach it at that address; confirm against server config.
    webhook_url = "http://localhost:5001/api/webhook/comfyui"

    workflow = comfy_service.prepare_sam_workflow(
        base_image=base_img,
        include_points=include_points,
        exclude_points=exclude_points,
        webhook_url=webhook_url,
        batch_id=batch_id,
        workflow_template=workflow_template,
    )
    logger.info("[SAM] Workflow prepared, sending to ComfyUI at http://%s", comfy_url)

    # Register the batch BEFORE submitting, so that a webhook arriving quickly
    # (e.g. when ComfyUI serves cached outputs) finds a batch to attach to.
    batch_storage.create_batch(batch_id, 1)

    try:
        prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
    except Exception as e:
        logger.exception("[SAM] Error submitting workflow: %s", e)
        return _json_error(f'Failed to connect to ComfyUI: {e}')
    logger.info("[SAM] Prompt submitted with ID: %s", prompt_id)

    if not comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=120):
        logger.error("[SAM] Timeout waiting for workflow completion")
        return _json_error('SAM generation timed out')

    logger.info("[SAM] Workflow completed, waiting for webhook...")
    mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=60.0)
    if not mask_paths:
        logger.error("[SAM] No mask received via webhook")
        return _json_error('No mask received from SAM')

    mask_path = str(mask_paths[0])
    logger.info("[SAM] Mask received: %s", mask_path)

    return jsonify({
        'success': True,
        'mask_path': mask_path,
        # Percent-encode so paths containing '&', '#', '+', spaces, etc. stay a
        # single, intact ``path`` query parameter.
        'mask_url': f'/api/file/mask?path={quote(mask_path, safe="/")}',
    })
|