Files
ai-game-2/tools/extract_mask.py
2026-03-09 09:22:21 -07:00

253 lines
7.8 KiB
Python
Executable File

#!/usr/bin/env python3
"""Extract a mask from an image using ComfyUI workflow."""
import base64
import json
import os
import sys
import time
import urllib.error
import urllib.request
import uuid
from urllib.parse import urlencode
def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool:
    """Report whether a ComfyUI server answers at ``server_address``.

    Issues a GET against the ``/system_stats`` endpoint and treats an
    HTTP 200 reply as "server is up". This is a best-effort probe: any
    exception raised while building or sending the request is swallowed
    and reported as ``False``.

    Args:
        server_address: ``host:port`` of the server, without a scheme.
        timeout: Seconds to wait for the reply.

    Returns:
        True if the endpoint replied with status 200, otherwise False.
    """
    try:
        stats_url = f"http://{server_address}/system_stats"
        probe = urllib.request.Request(stats_url, method="GET")
        with urllib.request.urlopen(probe, timeout=timeout) as reply:
            reachable = reply.status == 200
    except Exception:
        # Deliberately broad: every failure mode just means "not reachable".
        return False
    return reachable
def encode_image_base64(image_path: str) -> str:
    """Return the contents of ``image_path`` as base64 text.

    The file is read in binary mode in one go, so the whole file is held
    in memory while it is encoded.

    Args:
        image_path: Path of the file to read.

    Returns:
        The base64 encoding of the file's bytes, as a ``str``.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict:
"""Queue a prompt to ComfyUI server."""
client_id = str(uuid.uuid4())
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(
f"http://{server_address}/prompt",
data=data,
headers={"Content-Type": "application/json"},
)
try:
with urllib.request.urlopen(req) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
error_body = e.read().decode("utf-8")
print(f"HTTP Error {e.code}: {error_body}")
raise
def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict:
"""Get the history/status of a prompt from ComfyUI."""
req = urllib.request.Request(
f"http://{server_address}/history/{prompt_id}",
method="GET",
)
try:
with urllib.request.urlopen(req) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
error_body = e.read().decode("utf-8")
print(f"HTTP Error {e.code}: {error_body}")
raise
def download_image(
    filename: str,
    subfolder: str,
    folder_type: str,
    server_address: str = "127.0.0.1:8188",
) -> bytes:
    """Fetch a file's raw bytes from the server's ``/view`` endpoint.

    The query string always carries ``filename`` and ``type``;
    ``subfolder`` is appended only when it is non-empty.

    Args:
        filename: Name of the file to fetch.
        subfolder: Subfolder holding the file; omitted from the query if empty.
        folder_type: Value sent as the ``type`` query parameter.
        server_address: ``host:port`` of the server, without a scheme.

    Returns:
        The response body as bytes.
    """
    query = {
        "filename": filename,
        "type": folder_type,
        **({"subfolder": subfolder} if subfolder else {}),
    }
    view_url = f"http://{server_address}/view?" + urlencode(query)
    request = urllib.request.Request(view_url, method="GET")
    with urllib.request.urlopen(request) as reply:
        payload = reply.read()
    return payload
def wait_for_prompt_completion(
prompt_id: str, server_address: str = "127.0.0.1:8188", timeout: int = 240
) -> dict | None:
"""Wait for a prompt to complete and return the output info."""
start_time = time.time()
while time.time() - start_time < timeout:
history = get_history(prompt_id, server_address)
if prompt_id in history:
prompt_history = history[prompt_id]
if "outputs" in prompt_history and prompt_history["outputs"]:
return prompt_history["outputs"]
time.sleep(0.5)
return None
def _first_output_image(outputs: dict) -> dict | None:
    """Return the first usable image record from a workflow's node outputs.

    Only the first image of each node is considered. A node whose first
    image has an empty filename is skipped.
    """
    for node_output in outputs.values():
        images = node_output.get("images")
        if images and images[0]["filename"]:
            return images[0]
    return None


def extract_mask(
    subject: str,
    input_image: str,
    output_path: str,
    server_address: str = "127.0.0.1:8188",
    timeout: int = 240,
) -> str:
    """Extract mask from image for given subject.

    Loads ``image_mask_extraction.json`` from this script's directory,
    fills in the prompt text, the base64-encoded input image and a unique
    output filename prefix, queues the workflow, waits for it to finish,
    then downloads the first produced image and writes it to
    ``output_path``.

    Args:
        subject: The subject to extract mask for (e.g., "the stump", "the door")
        input_image: Path to the input image file
        output_path: Path where the output mask should be saved; missing
            parent directories are created
        server_address: ComfyUI server address
        timeout: Maximum number of seconds to wait for the workflow

    Returns:
        Path to the saved output mask

    Raises:
        RuntimeError: The workflow produced no outputs within ``timeout``
            seconds, or its outputs contained no image.
        KeyError: The workflow file lacks one of the expected node ids, or
            the queue reply has no ``"prompt_id"``.
        OSError: The workflow file or input image cannot be read, or the
            output cannot be written.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    workflow_path = os.path.join(script_dir, "image_mask_extraction.json")
    # Explicit UTF-8: JSON must not depend on the platform's locale encoding.
    with open(workflow_path, "r", encoding="utf-8") as f:
        workflow = json.load(f)

    prompt_text = f"Create a black and white alpha mask of {subject}, leaving everything else black"
    print("Encoding input image...")
    base64_image = encode_image_base64(input_image)

    # Node ids below are specific to the bundled workflow file.
    # NOTE(review): presumably "1:68" is the prompt node, "87" the image
    # input and "82" the save node; confirm against the JSON.
    workflow["1:68"]["inputs"]["prompt"] = prompt_text
    workflow["87"]["inputs"]["image"] = base64_image
    unique_id = str(uuid.uuid4())[:8]
    filename_prefix = f"masks/mask_{unique_id}"
    workflow["82"]["inputs"]["filename_prefix"] = filename_prefix

    print(f"Queuing mask extraction for: {subject}")
    print(f"Input image: {input_image}")
    print(f"Prompt: {prompt_text}")
    response = queue_prompt(workflow, server_address)
    prompt_id = response["prompt_id"]
    print(f"Prompt ID: {prompt_id}")

    # ":g" drops the trailing ".0", so the default 240 s prints "4 minutes".
    minutes = f"{timeout / 60:g}"
    print(f"Waiting for generation (up to {minutes} minutes)...")
    outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=timeout)
    if not outputs:
        raise RuntimeError(f"Timeout: Workflow did not complete in {minutes} minutes")

    image_info = _first_output_image(outputs)
    if image_info is None:
        raise RuntimeError("No output image found in workflow results")
    output_filename = image_info["filename"]
    output_subfolder = image_info.get("subfolder", "")
    output_type = image_info.get("type", "output")

    print(f"Downloading generated mask: {output_filename}")
    image_data = download_image(
        output_filename, output_subfolder, output_type, server_address
    )

    output_dir_path = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(output_dir_path, exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(image_data)
    print(f"Saved mask: {output_path}")
    return output_path
def main():
    """Command-line entry point for mask extraction.

    Parses the arguments, checks that the input image exists, echoes the
    settings, and verifies the ComfyUI server is reachable. With
    ``--dry-run`` it stops after the server check (exit code 0 on
    success, 1 on failure). Otherwise it runs the extraction and exits
    with code 1 if anything raises.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="Extract mask from image using ComfyUI"
    )
    cli.add_argument(
        "subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')"
    )
    cli.add_argument("input_image", help="Path to input image file")
    cli.add_argument("output_path", help="Path where output mask should be saved")
    cli.add_argument(
        "--server",
        default="127.0.0.1:8188",
        help="ComfyUI server address (default: 127.0.0.1:8188)",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Test mode: validate inputs and server connection without generating",
    )
    opts = cli.parse_args()

    if not os.path.exists(opts.input_image):
        print(f"Error: Input image not found: {opts.input_image}")
        sys.exit(1)

    # Echo the effective settings, one "Label: value" line each.
    settings = (
        ("Subject", opts.subject),
        ("Input", opts.input_image),
        ("Output", opts.output_path),
        ("Server", opts.server),
    )
    for label, value in settings:
        print(f"{label}: {value}")

    if opts.dry_run:
        print("\n[Dry Run Mode - Checking server connection...]")
        if not check_server(opts.server):
            print(f"✗ ComfyUI server is not accessible at {opts.server}")
            print(" Please ensure ComfyUI is running before extracting masks.")
            sys.exit(1)
        print("✓ ComfyUI server is running and accessible")
        print("\n✓ Dry run successful! All checks passed.")
        sys.exit(0)

    print("\nChecking ComfyUI server...")
    if not check_server(opts.server):
        print(f"Error: ComfyUI server is not running at {opts.server}")
        print("Please start ComfyUI first or check the server address.")
        print("\nTo test without generating, use: --dry-run")
        sys.exit(1)
    print("✓ ComfyUI server is running")

    try:
        saved_path = extract_mask(
            opts.subject, opts.input_image, opts.output_path, opts.server
        )
        print(f"\nMask extraction complete! Output: {saved_path}")
    except Exception as e:
        # Top-level boundary: report any failure and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()