Files
ai-game-2/tools/extract_mask.py
2026-03-09 15:55:31 -07:00

255 lines
7.9 KiB
Python
Executable File

#!/usr/bin/env python3
"""Extract a mask from an image using ComfyUI workflow."""
import base64
import json
import os
import sys
import time
import urllib.error
import urllib.request
import uuid
import random
from urllib.parse import urlencode
def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool:
    """Probe the ComfyUI server and report whether it answered successfully.

    Issues a GET to ``http://<server_address>/system_stats`` and treats an
    HTTP 200 reply as "up". Any exception raised while building or sending
    the request (connection refused, timeout, malformed address, ...) is
    deliberately swallowed and reported as ``False`` so callers get a plain
    yes/no answer.

    Args:
        server_address: ``host:port`` of the server to probe.
        timeout: Seconds passed to ``urlopen`` as its timeout.

    Returns:
        ``True`` if the server replied with status 200, otherwise ``False``.
    """
    try:
        probe = urllib.request.Request(
            f"http://{server_address}/system_stats",
            method="GET",
        )
        with urllib.request.urlopen(probe, timeout=timeout) as reply:
            reachable = reply.status == 200
    except Exception:
        # Best-effort health check: every failure mode means "not reachable".
        reachable = False
    return reachable
def encode_image_base64(image_path: str) -> str:
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode in one go, base64-encoded, and the
    resulting ASCII bytes are decoded to a ``str``.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict:
"""Queue a prompt to ComfyUI server."""
client_id = str(uuid.uuid4())
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(
f"http://{server_address}/prompt",
data=data,
headers={"Content-Type": "application/json"},
)
try:
with urllib.request.urlopen(req) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
error_body = e.read().decode("utf-8")
print(f"HTTP Error {e.code}: {error_body}")
raise
def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict:
"""Get the history/status of a prompt from ComfyUI."""
req = urllib.request.Request(
f"http://{server_address}/history/{prompt_id}",
method="GET",
)
try:
with urllib.request.urlopen(req) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
error_body = e.read().decode("utf-8")
print(f"HTTP Error {e.code}: {error_body}")
raise
def download_image(
filename: str,
subfolder: str,
folder_type: str,
server_address: str = "127.0.0.1:8188",
) -> bytes:
"""Download an image from ComfyUI."""
params = {"filename": filename, "type": folder_type}
if subfolder:
params["subfolder"] = subfolder
url = f"http://{server_address}/view?{urlencode(params)}"
req = urllib.request.Request(url, method="GET")
with urllib.request.urlopen(req) as response:
return response.read()
def wait_for_prompt_completion(
prompt_id: str, server_address: str = "127.0.0.1:8188", timeout: int = 240
) -> dict | None:
"""Wait for a prompt to complete and return the output info."""
start_time = time.time()
while time.time() - start_time < timeout:
history = get_history(prompt_id, server_address)
if prompt_id in history:
prompt_history = history[prompt_id]
if "outputs" in prompt_history and prompt_history["outputs"]:
return prompt_history["outputs"]
time.sleep(0.5)
return None
def extract_mask(
    subject: str,
    input_image: str,
    output_path: str,
    server_address: str = "127.0.0.1:8188",
) -> str:
    """Extract mask from image for given subject.

    Loads the ``image_mask_extraction.json`` workflow that sits next to this
    script, fills in the prompt, the base64-encoded input image, a random
    seed and a unique output filename prefix, queues it on the ComfyUI
    server, waits for the result, and saves the first produced image to
    *output_path*.

    Args:
        subject: The subject to extract mask for (e.g., "the stump", "the door")
        input_image: Path to the input image file
        output_path: Path where the output mask should be saved
        server_address: ComfyUI server address

    Returns:
        Path to the saved output mask

    Raises:
        RuntimeError: The workflow produced no outputs within 4 minutes, or
            its outputs contained no image.
        OSError: The workflow file or input image could not be read, or the
            output could not be written.
        KeyError: The workflow file lacks one of the node ids patched below.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    workflow_path = os.path.join(script_dir, "image_mask_extraction.json")
    # Explicit UTF-8: do not let the platform's locale encoding decide how
    # the JSON workflow is decoded.
    with open(workflow_path, "r", encoding="utf-8") as f:
        workflow = json.load(f)

    prompt_text = f"Create a black and white alpha mask of {subject}, leaving everything else black"

    print("Encoding input image...")
    base64_image = encode_image_base64(input_image)

    # Node ids are specific to image_mask_extraction.json; each assignment
    # overrides one input of the corresponding workflow node.
    workflow["1:68"]["inputs"]["prompt"] = prompt_text
    workflow["87"]["inputs"]["image"] = base64_image
    workflow["50"]["inputs"]["seed"] = random.randint(1, 100000000)
    # Short random suffix so each run gets its own output filename prefix.
    unique_id = str(uuid.uuid4())[:8]
    filename_prefix = f"masks/mask_{unique_id}"
    workflow["82"]["inputs"]["filename_prefix"] = filename_prefix

    print(f"Queuing mask extraction for: {subject}")
    print(f"Input image: {input_image}")
    print(f"Prompt: {prompt_text}")
    response = queue_prompt(workflow, server_address)
    prompt_id = response["prompt_id"]
    print(f"Prompt ID: {prompt_id}")

    print("Waiting for generation (up to 4 minutes)...")
    outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=240)
    if not outputs:
        raise RuntimeError("Timeout: Workflow did not complete in 4 minutes")

    # Take the first image entry with a usable filename from any node output.
    image_info = next(
        (
            info
            for node_output in outputs.values()
            for info in (node_output.get("images") or ())
            if info.get("filename")
        ),
        None,
    )
    if image_info is None:
        raise RuntimeError("No output image found in workflow results")
    output_filename = image_info["filename"]
    output_subfolder = image_info.get("subfolder", "")
    output_type = image_info.get("type", "output")

    print(f"Downloading generated mask: {output_filename}")
    image_data = download_image(
        output_filename, output_subfolder, output_type, server_address
    )

    # abspath first so a bare filename resolves to the current directory
    # instead of an empty dirname that makedirs would reject.
    output_dir_path = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(output_dir_path, exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(image_data)
    print(f"Saved mask: {output_path}")
    return output_path
def main():
    """Command-line entry point: parse arguments and run the mask extraction.

    Exits with status 1 when the input image is missing or is not a regular
    file, when the ComfyUI server is unreachable, or when extraction fails.
    With ``--dry-run`` only the input and the server connection are checked,
    and the process exits with 0 on success.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract mask from image using ComfyUI"
    )
    parser.add_argument(
        "subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')"
    )
    parser.add_argument("input_image", help="Path to input image file")
    parser.add_argument("output_path", help="Path where output mask should be saved")
    parser.add_argument(
        "--server",
        default="127.0.0.1:8188",
        help="ComfyUI server address (default: 127.0.0.1:8188)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Test mode: validate inputs and server connection without generating",
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_image):
        print(f"Error: Input image not found: {args.input_image}")
        sys.exit(1)
    # A directory exists but cannot be read as an image; reject it here so
    # --dry-run does not report success for an unusable input.
    if not os.path.isfile(args.input_image):
        print(f"Error: Input image is not a file: {args.input_image}")
        sys.exit(1)

    print(f"Subject: {args.subject}")
    print(f"Input: {args.input_image}")
    print(f"Output: {args.output_path}")
    print(f"Server: {args.server}")

    if args.dry_run:
        print("\n[Dry Run Mode - Checking server connection...]")
        if check_server(args.server):
            print("✓ ComfyUI server is running and accessible")
            print("\n✓ Dry run successful! All checks passed.")
            sys.exit(0)
        else:
            print(f"✗ ComfyUI server is not accessible at {args.server}")
            print(" Please ensure ComfyUI is running before extracting masks.")
            sys.exit(1)

    print("\nChecking ComfyUI server...")
    if not check_server(args.server):
        print(f"Error: ComfyUI server is not running at {args.server}")
        print("Please start ComfyUI first or check the server address.")
        print("\nTo test without generating, use: --dry-run")
        sys.exit(1)
    print("✓ ComfyUI server is running")

    try:
        output = extract_mask(
            args.subject, args.input_image, args.output_path, args.server
        )
        print(f"\nMask extraction complete! Output: {output}")
    except Exception as e:
        # Top-level boundary: report any failure as one line and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()