ora editor
This commit is contained in:
@@ -6,6 +6,7 @@ from .images import images_bp
|
||||
from .polygon import polygon_bp
|
||||
from .mask import mask_bp
|
||||
from .krita import krita_bp
|
||||
from .sam import sam_bp
|
||||
|
||||
__all__ = [
|
||||
'files_bp',
|
||||
@@ -13,5 +14,6 @@ __all__ = [
|
||||
'images_bp',
|
||||
'polygon_bp',
|
||||
'mask_bp',
|
||||
'krita_bp'
|
||||
'krita_bp',
|
||||
'sam_bp'
|
||||
]
|
||||
|
||||
@@ -115,14 +115,20 @@ def api_mask_extract():
|
||||
ora_path = data['ora_path']
|
||||
comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)
|
||||
count = min(max(data.get('count', 1), 1), 10)
|
||||
start_mask_path = data.get('start_mask_path', None)
|
||||
|
||||
logger.info(f"[MASK EXTRACT] Subject: {subject}")
|
||||
logger.info(f"[MASK EXTRACT] Use polygon: {use_polygon}")
|
||||
logger.info(f"[MASK EXTRACT] ORA path: {ora_path}")
|
||||
logger.info(f"[MASK EXTRACT] ComfyUI URL: {comfy_url}")
|
||||
logger.info(f"[MASK EXTRACT] Count: {count}")
|
||||
logger.info(f"[MASK EXTRACT] Start mask path: {start_mask_path}")
|
||||
|
||||
if start_mask_path:
|
||||
workflow_path = APP_DIR.parent / "image_mask_extraction_with_start.json"
|
||||
else:
|
||||
workflow_path = APP_DIR.parent / "image_mask_extraction.json"
|
||||
|
||||
workflow_path = APP_DIR.parent / "image_mask_extraction.json"
|
||||
if not workflow_path.exists():
|
||||
logger.error(f"[MASK EXTRACT] Workflow file not found: {workflow_path}")
|
||||
return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500
|
||||
@@ -142,6 +148,17 @@ def api_mask_extract():
|
||||
logger.error(f"[MASK EXTRACT] Error loading base image: {e}")
|
||||
return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500
|
||||
|
||||
start_mask_img = None
|
||||
if start_mask_path:
|
||||
try:
|
||||
start_mask_img = __import__('PIL').Image.open(start_mask_path).convert('RGBA')
|
||||
if base_img.size != start_mask_img.size:
|
||||
start_mask_img = start_mask_img.resize(base_img.size, __import__('PIL').Image.LANCZOS)
|
||||
logger.info(f"[MASK EXTRACT] Loaded start mask: {start_mask_img.size}")
|
||||
except Exception as e:
|
||||
logger.error(f"[MASK EXTRACT] Error loading start mask: {e}")
|
||||
return jsonify({'success': False, 'error': f'Error loading start mask: {str(e)}'}), 500
|
||||
|
||||
polygon_points = None
|
||||
polygon_color = '#FF0000'
|
||||
polygon_width = 2
|
||||
@@ -166,18 +183,33 @@ def api_mask_extract():
|
||||
prompt_ids = []
|
||||
|
||||
for i in range(count):
|
||||
workflow = comfy_service.prepare_mask_workflow(
|
||||
base_image=base_img,
|
||||
subject=subject,
|
||||
webhook_url=webhook_url,
|
||||
seed=seeds[i],
|
||||
batch_id=batch_id,
|
||||
mask_index=i,
|
||||
polygon_points=polygon_points,
|
||||
polygon_color=polygon_color,
|
||||
polygon_width=polygon_width,
|
||||
workflow_template=workflow_template
|
||||
)
|
||||
if start_mask_img:
|
||||
workflow = comfy_service.prepare_mask_workflow_with_start(
|
||||
base_image=base_img,
|
||||
start_mask_image=start_mask_img,
|
||||
subject=subject,
|
||||
webhook_url=webhook_url,
|
||||
seed=seeds[i],
|
||||
batch_id=batch_id,
|
||||
mask_index=i,
|
||||
polygon_points=polygon_points,
|
||||
polygon_color=polygon_color,
|
||||
polygon_width=polygon_width,
|
||||
workflow_template=workflow_template
|
||||
)
|
||||
else:
|
||||
workflow = comfy_service.prepare_mask_workflow(
|
||||
base_image=base_img,
|
||||
subject=subject,
|
||||
webhook_url=webhook_url,
|
||||
seed=seeds[i],
|
||||
batch_id=batch_id,
|
||||
mask_index=i,
|
||||
polygon_points=polygon_points,
|
||||
polygon_color=polygon_color,
|
||||
polygon_width=polygon_width,
|
||||
workflow_template=workflow_template
|
||||
)
|
||||
|
||||
logger.info(f"[MASK EXTRACT] Workflow {i} prepared, sending to ComfyUI at http://{comfy_url}")
|
||||
|
||||
|
||||
107
tools/ora_editor/routes/sam.py
Normal file
107
tools/ora_editor/routes/sam.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""SAM3 rough mask generation routes for ORA Editor."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, request, jsonify, make_response
|
||||
|
||||
from ora_editor.config import APP_DIR, COMFYUI_BASE_URL, TEMP_DIR
|
||||
from ora_editor.services.comfyui import ComfyUIService, batch_storage
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sam_bp = Blueprint('sam', __name__)
|
||||
comfy_service = ComfyUIService(COMFYUI_BASE_URL)
|
||||
|
||||
|
||||
def generate_batch_id() -> str:
|
||||
"""Generate a unique batch ID."""
|
||||
return ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
|
||||
|
||||
|
||||
@sam_bp.route('/api/sam/generate', methods=['POST'])
|
||||
def api_sam_generate():
|
||||
"""Generate a rough mask using SAM3 with include/exclude points."""
|
||||
data = request.get_json()
|
||||
|
||||
required = ['ora_path', 'include_points']
|
||||
for field in required:
|
||||
if field not in data:
|
||||
return jsonify({'success': False, 'error': f'Missing {field} parameter'}), 400
|
||||
|
||||
ora_path = data['ora_path']
|
||||
include_points = data['include_points']
|
||||
exclude_points = data.get('exclude_points', [])
|
||||
comfy_url = data.get('comfy_url', COMFYUI_BASE_URL)
|
||||
|
||||
logger.info(f"[SAM] ORA path: {ora_path}")
|
||||
logger.info(f"[SAM] Include points: {include_points}")
|
||||
logger.info(f"[SAM] Exclude points: {exclude_points}")
|
||||
|
||||
workflow_path = APP_DIR.parent / "mask_rough_cut.json"
|
||||
if not workflow_path.exists():
|
||||
logger.error(f"[SAM] Workflow file not found: {workflow_path}")
|
||||
return jsonify({'success': False, 'error': f'Workflow file not found: {workflow_path}'}), 500
|
||||
|
||||
with open(workflow_path) as f:
|
||||
workflow_template = json.load(f)
|
||||
|
||||
base_img = None
|
||||
try:
|
||||
with zipfile.ZipFile(ora_path, 'r') as zf:
|
||||
img_data = zf.read('mergedimage.png')
|
||||
base_img = __import__('PIL').Image.open(io.BytesIO(img_data)).convert('RGBA')
|
||||
logger.info(f"[SAM] Loaded base image: {base_img.size}")
|
||||
except Exception as e:
|
||||
logger.error(f"[SAM] Error loading base image: {e}")
|
||||
return jsonify({'success': False, 'error': f'Error loading image: {str(e)}'}), 500
|
||||
|
||||
batch_id = generate_batch_id()
|
||||
webhook_url = f"http://localhost:5001/api/webhook/comfyui"
|
||||
|
||||
workflow = comfy_service.prepare_sam_workflow(
|
||||
base_image=base_img,
|
||||
include_points=include_points,
|
||||
exclude_points=exclude_points,
|
||||
webhook_url=webhook_url,
|
||||
batch_id=batch_id,
|
||||
workflow_template=workflow_template
|
||||
)
|
||||
|
||||
logger.info(f"[SAM] Workflow prepared, sending to ComfyUI at http://{comfy_url}")
|
||||
|
||||
try:
|
||||
prompt_id = comfy_service.submit_workflow(workflow, comfy_url)
|
||||
logger.info(f"[SAM] Prompt submitted with ID: {prompt_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[SAM] Error submitting workflow: {e}")
|
||||
return jsonify({'success': False, 'error': f'Failed to connect to ComfyUI: {str(e)}'}), 500
|
||||
|
||||
batch_storage.create_batch(batch_id, 1)
|
||||
|
||||
completed = comfy_service.poll_for_completion(prompt_id, comfy_url, timeout=120)
|
||||
if not completed:
|
||||
logger.error("[SAM] Timeout waiting for workflow completion")
|
||||
return jsonify({'success': False, 'error': 'SAM generation timed out'}), 500
|
||||
|
||||
logger.info("[SAM] Workflow completed, waiting for webhook...")
|
||||
|
||||
mask_paths = comfy_service.poll_for_batch_completion(batch_id, timeout=60.0)
|
||||
|
||||
if not mask_paths:
|
||||
logger.error("[SAM] No mask received via webhook")
|
||||
return jsonify({'success': False, 'error': 'No mask received from SAM'}), 500
|
||||
|
||||
mask_path = mask_paths[0]
|
||||
logger.info(f"[SAM] Mask received: {mask_path}")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'mask_path': str(mask_path),
|
||||
'mask_url': f'/api/file/mask?path={mask_path}'
|
||||
})
|
||||
Reference in New Issue
Block a user