#!/usr/bin/env python3 """Extract a mask from an image using ComfyUI workflow.""" import base64 import json import os import sys import time import urllib.error import urllib.request import uuid import random from urllib.parse import urlencode def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool: """Check if ComfyUI server is running and accessible.""" try: req = urllib.request.Request( f"http://{server_address}/system_stats", method="GET", ) with urllib.request.urlopen(req, timeout=timeout) as response: return response.status == 200 except Exception: return False def encode_image_base64(image_path: str) -> str: """Encode an image file as base64 string.""" with open(image_path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict: """Queue a prompt to ComfyUI server.""" client_id = str(uuid.uuid4()) p = {"prompt": prompt, "client_id": client_id} data = json.dumps(p).encode("utf-8") req = urllib.request.Request( f"http://{server_address}/prompt", data=data, headers={"Content-Type": "application/json"}, ) try: with urllib.request.urlopen(req) as response: return json.loads(response.read()) except urllib.error.HTTPError as e: error_body = e.read().decode("utf-8") print(f"HTTP Error {e.code}: {error_body}") raise def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict: """Get the history/status of a prompt from ComfyUI.""" req = urllib.request.Request( f"http://{server_address}/history/{prompt_id}", method="GET", ) try: with urllib.request.urlopen(req) as response: return json.loads(response.read()) except urllib.error.HTTPError as e: error_body = e.read().decode("utf-8") print(f"HTTP Error {e.code}: {error_body}") raise def download_image( filename: str, subfolder: str, folder_type: str, server_address: str = "127.0.0.1:8188", ) -> bytes: """Download an image from ComfyUI.""" params = {"filename": filename, "type": folder_type} if subfolder: params["subfolder"] = subfolder url = f"http://{server_address}/view?{urlencode(params)}" req = urllib.request.Request(url, method="GET") with urllib.request.urlopen(req) as response: return response.read() def wait_for_prompt_completion( prompt_id: str, server_address: str = "127.0.0.1:8188", timeout: int = 240 ) -> dict | None: """Wait for a prompt to complete and return the output info.""" start_time = time.time() while time.time() - start_time < timeout: history = get_history(prompt_id, server_address) if prompt_id in history: prompt_history = history[prompt_id] if "outputs" in prompt_history and prompt_history["outputs"]: return prompt_history["outputs"] time.sleep(0.5) return None def extract_mask( subject: str, input_image: str, output_path: str, server_address: str = "127.0.0.1:8188", ) -> str: """Extract mask from image for given subject. Args: subject: The subject to extract mask for (e.g., "the stump", "the door") input_image: Path to the input image file output_path: Path where the output mask should be saved server_address: ComfyUI server address Returns: Path to the saved output mask """ script_dir = os.path.dirname(os.path.abspath(__file__)) workflow_path = os.path.join(script_dir, "image_mask_extraction.json") with open(workflow_path, "r") as f: workflow = json.load(f) prompt_text = f"Create a black and white alpha mask of {subject}, leaving everything else black" print(f"Encoding input image...") base64_image = encode_image_base64(input_image) workflow["1:68"]["inputs"]["prompt"] = prompt_text workflow["87"]["inputs"]["image"] = base64_image workflow["50"]["inputs"]["seed"] = random.randint(1, 100000000) unique_id = str(uuid.uuid4())[:8] filename_prefix = f"masks/mask_{unique_id}" workflow["82"]["inputs"]["filename_prefix"] = filename_prefix print(f"Queuing mask extraction for: {subject}") print(f"Input image: {input_image}") print(f"Prompt: {prompt_text}") response = queue_prompt(workflow, server_address) prompt_id = response["prompt_id"] print(f"Prompt ID: {prompt_id}") print("Waiting for generation (up to 4 minutes)...") outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=240) if not outputs: raise RuntimeError("Timeout: Workflow did not complete in 4 minutes") output_filename = None output_subfolder = "" output_type = "output" for node_id, node_output in outputs.items(): if "images" in node_output: for image_info in node_output["images"]: output_filename = image_info["filename"] output_subfolder = image_info.get("subfolder", "") output_type = image_info.get("type", "output") break if output_filename: break if not output_filename: raise RuntimeError("No output image found in workflow results") print(f"Downloading generated mask: {output_filename}") image_data = download_image( output_filename, output_subfolder, output_type, server_address ) output_dir_path = os.path.dirname(os.path.abspath(output_path)) os.makedirs(output_dir_path, exist_ok=True) with open(output_path, "wb") as f: f.write(image_data) print(f"Saved mask: {output_path}") return output_path def main(): import argparse parser = argparse.ArgumentParser( description="Extract mask from image using ComfyUI" ) parser.add_argument( "subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')" ) parser.add_argument("input_image", help="Path to input image file") parser.add_argument("output_path", help="Path where output mask should be saved") parser.add_argument( "--server", default="127.0.0.1:8188", help="ComfyUI server address (default: 127.0.0.1:8188)", ) parser.add_argument( "--dry-run", action="store_true", help="Test mode: validate inputs and server connection without generating", ) args = parser.parse_args() if not os.path.exists(args.input_image): print(f"Error: Input image not found: {args.input_image}") sys.exit(1) print(f"Subject: {args.subject}") print(f"Input: {args.input_image}") print(f"Output: {args.output_path}") print(f"Server: {args.server}") if args.dry_run: print("\n[Dry Run Mode - Checking server connection...]") if check_server(args.server): print("āœ“ ComfyUI server is running and accessible") print("\nāœ“ Dry run successful! All checks passed.") sys.exit(0) else: print(f"āœ— ComfyUI server is not accessible at {args.server}") print(" Please ensure ComfyUI is running before extracting masks.") sys.exit(1) print("\nChecking ComfyUI server...") if not check_server(args.server): print(f"Error: ComfyUI server is not running at {args.server}") print("Please start ComfyUI first or check the server address.") print(f"\nTo test without generating, use: --dry-run") sys.exit(1) print("āœ“ ComfyUI server is running") try: output = extract_mask( args.subject, args.input_image, args.output_path, args.server ) print(f"\nMask extraction complete! Output: {output}") except Exception as e: print(f"Error: {e}") sys.exit(1) if __name__ == "__main__": main()