#!/usr/bin/env python3 """Extract a mask from an image using ComfyUI workflow.""" import base64 import json import os import sys import time import urllib.error import urllib.request import uuid from urllib.parse import urlencode def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool: """Check if ComfyUI server is running and accessible.""" try: req = urllib.request.Request( f"http://{server_address}/system_stats", method="GET", ) with urllib.request.urlopen(req, timeout=timeout) as response: return response.status == 200 except Exception: return False def encode_image_base64(image_path: str) -> str: """Encode an image file as base64 string.""" with open(image_path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict: """Queue a prompt to ComfyUI server.""" client_id = str(uuid.uuid4()) p = {"prompt": prompt, "client_id": client_id} data = json.dumps(p).encode("utf-8") req = urllib.request.Request( f"http://{server_address}/prompt", data=data, headers={"Content-Type": "application/json"}, ) try: with urllib.request.urlopen(req) as response: return json.loads(response.read()) except urllib.error.HTTPError as e: error_body = e.read().decode("utf-8") print(f"HTTP Error {e.code}: {error_body}") raise def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict: """Get the history/status of a prompt from ComfyUI.""" req = urllib.request.Request( f"http://{server_address}/history/{prompt_id}", method="GET", ) try: with urllib.request.urlopen(req) as response: return json.loads(response.read()) except urllib.error.HTTPError as e: error_body = e.read().decode("utf-8") print(f"HTTP Error {e.code}: {error_body}") raise def download_image( filename: str, subfolder: str, folder_type: str, server_address: str = "127.0.0.1:8188", ) -> bytes: """Download an image from ComfyUI.""" params = {"filename": filename, "type": folder_type} if subfolder: params["subfolder"] = subfolder url = f"http://{server_address}/view?{urlencode(params)}" req = urllib.request.Request(url, method="GET") with urllib.request.urlopen(req) as response: return response.read() def wait_for_prompt_completion( prompt_id: str, server_address: str = "127.0.0.1:8188", timeout: int = 240 ) -> dict | None: """Wait for a prompt to complete and return the output info.""" start_time = time.time() while time.time() - start_time < timeout: history = get_history(prompt_id, server_address) if prompt_id in history: prompt_history = history[prompt_id] if "outputs" in prompt_history and prompt_history["outputs"]: return prompt_history["outputs"] time.sleep(0.5) return None def extract_mask( subject: str, input_image: str, output_path: str, server_address: str = "127.0.0.1:8188", ) -> str: """Extract mask from image for given subject. Args: subject: The subject to extract mask for (e.g., "the stump", "the door") input_image: Path to the input image file output_path: Path where the output mask should be saved server_address: ComfyUI server address Returns: Path to the saved output mask """ script_dir = os.path.dirname(os.path.abspath(__file__)) workflow_path = os.path.join(script_dir, "image_mask_extraction.json") with open(workflow_path, "r") as f: workflow = json.load(f) prompt_text = f"Create a black and white alpha mask of {subject}, leaving everything else black" print(f"Encoding input image...") base64_image = encode_image_base64(input_image) workflow["1:68"]["inputs"]["prompt"] = prompt_text workflow["87"]["inputs"]["image"] = base64_image unique_id = str(uuid.uuid4())[:8] filename_prefix = f"masks/mask_{unique_id}" workflow["82"]["inputs"]["filename_prefix"] = filename_prefix print(f"Queuing mask extraction for: {subject}") print(f"Input image: {input_image}") print(f"Prompt: {prompt_text}") response = queue_prompt(workflow, server_address) prompt_id = response["prompt_id"] print(f"Prompt ID: {prompt_id}") print("Waiting for generation (up to 4 minutes)...") outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=240) if not outputs: raise RuntimeError("Timeout: Workflow did not complete in 4 minutes") output_filename = None output_subfolder = "" output_type = "output" for node_id, node_output in outputs.items(): if "images" in node_output: for image_info in node_output["images"]: output_filename = image_info["filename"] output_subfolder = image_info.get("subfolder", "") output_type = image_info.get("type", "output") break if output_filename: break if not output_filename: raise RuntimeError("No output image found in workflow results") print(f"Downloading generated mask: {output_filename}") image_data = download_image( output_filename, output_subfolder, output_type, server_address ) output_dir_path = os.path.dirname(os.path.abspath(output_path)) os.makedirs(output_dir_path, exist_ok=True) with open(output_path, "wb") as f: f.write(image_data) print(f"Saved mask: {output_path}") return output_path def main(): import argparse parser = argparse.ArgumentParser( description="Extract mask from image using ComfyUI" ) parser.add_argument( "subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')" ) parser.add_argument("input_image", help="Path to input image file") parser.add_argument("output_path", help="Path where output mask should be saved") parser.add_argument( "--server", default="127.0.0.1:8188", help="ComfyUI server address (default: 127.0.0.1:8188)", ) parser.add_argument( "--dry-run", action="store_true", help="Test mode: validate inputs and server connection without generating", ) args = parser.parse_args() if not os.path.exists(args.input_image): print(f"Error: Input image not found: {args.input_image}") sys.exit(1) print(f"Subject: {args.subject}") print(f"Input: {args.input_image}") print(f"Output: {args.output_path}") print(f"Server: {args.server}") if args.dry_run: print("\n[Dry Run Mode - Checking server connection...]") if check_server(args.server): print("āœ“ ComfyUI server is running and accessible") print("\nāœ“ Dry run successful! All checks passed.") sys.exit(0) else: print(f"āœ— ComfyUI server is not accessible at {args.server}") print(" Please ensure ComfyUI is running before extracting masks.") sys.exit(1) print("\nChecking ComfyUI server...") if not check_server(args.server): print(f"Error: ComfyUI server is not running at {args.server}") print("Please start ComfyUI first or check the server address.") print(f"\nTo test without generating, use: --dry-run") sys.exit(1) print("āœ“ ComfyUI server is running") try: output = extract_mask( args.subject, args.input_image, args.output_path, args.server ) print(f"\nMask extraction complete! Output: {output}") except Exception as e: print(f"Error: {e}") sys.exit(1) if __name__ == "__main__": main()