Add multi-mask extraction with count selector and navigation
- Add count selector (1-10) for generating multiple mask variations - Each mask gets a unique random seed - Add left/right arrow navigation in mask preview modal when multiple masks exist - Batch storage system for tracking multiple concurrent extractions - Webhook handler now uses batch_id:mask_index for routing responses
This commit is contained in:
@@ -18,13 +18,61 @@ from PIL import Image, ImageDraw
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BatchStorage:
|
||||
"""Thread-safe storage for batch mask extraction results."""
|
||||
|
||||
def __init__(self):
|
||||
self._batches: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def create_batch(self, batch_id: str, count: int) -> None:
|
||||
"""Initialize a batch with expected count."""
|
||||
with self._lock:
|
||||
self._batches[batch_id] = {
|
||||
'expected': count,
|
||||
'masks': {},
|
||||
'errors': []
|
||||
}
|
||||
|
||||
def add_mask(self, batch_id: str, index: int, path: Path) -> None:
|
||||
"""Add a completed mask to the batch."""
|
||||
with self._lock:
|
||||
if batch_id in self._batches:
|
||||
self._batches[batch_id]['masks'][index] = str(path)
|
||||
|
||||
def add_error(self, batch_id: str, error: str) -> None:
|
||||
"""Add an error to the batch."""
|
||||
with self._lock:
|
||||
if batch_id in self._batches:
|
||||
self._batches[batch_id]['errors'].append(error)
|
||||
|
||||
def get_batch(self, batch_id: str) -> dict | None:
|
||||
"""Get batch status."""
|
||||
with self._lock:
|
||||
return self._batches.get(batch_id)
|
||||
|
||||
def is_complete(self, batch_id: str) -> bool:
|
||||
"""Check if batch has all expected masks."""
|
||||
with self._lock:
|
||||
batch = self._batches.get(batch_id)
|
||||
if not batch:
|
||||
return False
|
||||
return len(batch['masks']) >= batch['expected']
|
||||
|
||||
def clear_batch(self, batch_id: str) -> None:
|
||||
"""Remove a batch from storage."""
|
||||
with self._lock:
|
||||
self._batches.pop(batch_id, None)
|
||||
|
||||
|
||||
batch_storage = BatchStorage()
|
||||
|
||||
|
||||
class ComfyUIService:
|
||||
"""Service for interacting with ComfyUI API."""
|
||||
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
self._webhook_response: dict | None = None
|
||||
self._webhook_ready = threading.Event()
|
||||
|
||||
def submit_workflow(self, workflow: dict, comfy_url: str | None = None) -> str:
|
||||
"""Submit a workflow to ComfyUI and return the prompt_id."""
|
||||
@@ -83,79 +131,34 @@ class ComfyUIService:
|
||||
|
||||
return False
|
||||
|
||||
def wait_for_webhook(self, timeout: float = 60.0) -> dict | None:
|
||||
"""Wait for webhook callback from ComfyUI."""
|
||||
if self._webhook_ready.is_set() and self._webhook_response is not None:
|
||||
return self._webhook_response
|
||||
def poll_for_batch_completion(self, batch_id: str, timeout: float = 300.0) -> list[str]:
|
||||
"""Poll until all masks in a batch are received via webhook."""
|
||||
start_time = time.time()
|
||||
|
||||
self._webhook_ready.clear()
|
||||
self._webhook_response = None
|
||||
while time.time() - start_time < timeout:
|
||||
batch = batch_storage.get_batch(batch_id)
|
||||
if batch and batch_storage.is_complete(batch_id):
|
||||
masks = batch['masks']
|
||||
return [masks[i] for i in sorted(masks.keys())]
|
||||
time.sleep(1)
|
||||
|
||||
webhook_received = self._webhook_ready.wait(timeout=timeout)
|
||||
|
||||
if webhook_received:
|
||||
return self._webhook_response
|
||||
return None
|
||||
|
||||
def handle_webhook(self, request_files, request_form, request_data, temp_dir: Path) -> dict:
|
||||
"""Handle incoming webhook from ComfyUI."""
|
||||
self._webhook_response = None
|
||||
self._webhook_ready.clear()
|
||||
|
||||
try:
|
||||
img_file = None
|
||||
if 'file' in request_files:
|
||||
img_file = request_files['file']
|
||||
elif 'image' in request_files:
|
||||
img_file = request_files['image']
|
||||
elif request_files:
|
||||
img_file = list(request_files.values())[0]
|
||||
|
||||
if img_file:
|
||||
timestamp = str(int(time.time()))
|
||||
final_mask_path = temp_dir / f"mask_{timestamp}.png"
|
||||
|
||||
img = Image.open(img_file).convert('RGBA')
|
||||
img.save(str(final_mask_path), format='PNG')
|
||||
|
||||
logger.info(f"[WEBHOOK] Image saved to {final_mask_path}")
|
||||
self._webhook_response = {'success': True, 'path': final_mask_path}
|
||||
|
||||
elif request_data:
|
||||
timestamp = str(int(time.time()))
|
||||
final_mask_path = temp_dir / f"mask_{timestamp}.png"
|
||||
|
||||
with open(final_mask_path, 'wb') as f:
|
||||
f.write(request_data)
|
||||
|
||||
logger.info(f"[WEBHOOK] Raw data saved to {final_mask_path}")
|
||||
self._webhook_response = {'success': True, 'path': final_mask_path}
|
||||
|
||||
else:
|
||||
logger.error("[WEBHOOK] No image data in request")
|
||||
self._webhook_response = {'success': False, 'error': 'No image data received'}
|
||||
|
||||
self._webhook_ready.set()
|
||||
return self._webhook_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[WEBHOOK] Error: {e}")
|
||||
self._webhook_response = {'success': False, 'error': str(e)}
|
||||
self._webhook_ready.set()
|
||||
return self._webhook_response
|
||||
return []
|
||||
|
||||
def prepare_mask_workflow(
|
||||
self,
|
||||
base_image: Image.Image,
|
||||
subject: str,
|
||||
webhook_url: str,
|
||||
seed: int,
|
||||
batch_id: str = None,
|
||||
mask_index: int = 0,
|
||||
polygon_points: list | None = None,
|
||||
polygon_color: str = '#FF0000',
|
||||
polygon_width: int = 2,
|
||||
workflow_template: dict | None = None
|
||||
) -> dict:
|
||||
"""Prepare the mask extraction workflow."""
|
||||
workflow = workflow_template.copy() if workflow_template else {}
|
||||
workflow = json.loads(json.dumps(workflow_template)) if workflow_template else {}
|
||||
|
||||
img = base_image.copy()
|
||||
|
||||
@@ -182,6 +185,10 @@ class ComfyUIService:
|
||||
workflow["96"]["inputs"]["webhook_url"] = webhook_url
|
||||
|
||||
if "50" in workflow and 'inputs' in workflow["50"]:
|
||||
workflow["50"]["inputs"]["seed"] = random.randint(0, 2**31-1)
|
||||
workflow["50"]["inputs"]["seed"] = seed
|
||||
|
||||
if "96" in workflow and 'inputs' in workflow["96"]:
|
||||
metadata = f"{batch_id}:{mask_index}" if batch_id else str(mask_index)
|
||||
workflow["96"]["inputs"]["external_uid"] = metadata
|
||||
|
||||
return workflow
|
||||
|
||||
Reference in New Issue
Block a user