255 lines
7.9 KiB
Python
Executable File
255 lines
7.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Extract a mask from an image using ComfyUI workflow."""
|
|
|
|
import base64
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import urllib.error
|
|
import urllib.request
|
|
import uuid
|
|
import random
|
|
from urllib.parse import urlencode
|
|
|
|
|
|
def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool:
    """Probe the ComfyUI ``/system_stats`` endpoint.

    Args:
        server_address: ``host:port`` of the ComfyUI server.
        timeout: Socket timeout in seconds for the probe request.

    Returns:
        True if the endpoint answered with HTTP 200. Any failure (unreachable
        host, timeout, HTTP error, malformed address) yields False instead of
        raising, so this can be used as a simple health check.
    """
    try:
        probe = urllib.request.Request(
            f"http://{server_address}/system_stats", method="GET"
        )
        with urllib.request.urlopen(probe, timeout=timeout) as reply:
            status_code = reply.status
    except Exception:
        return False
    return status_code == 200
|
|
|
|
|
|
def encode_image_base64(image_path: str) -> str:
    """Return the raw bytes of the file at *image_path* as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
|
|
|
|
def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict:
|
|
"""Queue a prompt to ComfyUI server."""
|
|
client_id = str(uuid.uuid4())
|
|
p = {"prompt": prompt, "client_id": client_id}
|
|
data = json.dumps(p).encode("utf-8")
|
|
req = urllib.request.Request(
|
|
f"http://{server_address}/prompt",
|
|
data=data,
|
|
headers={"Content-Type": "application/json"},
|
|
)
|
|
try:
|
|
with urllib.request.urlopen(req) as response:
|
|
return json.loads(response.read())
|
|
except urllib.error.HTTPError as e:
|
|
error_body = e.read().decode("utf-8")
|
|
print(f"HTTP Error {e.code}: {error_body}")
|
|
raise
|
|
|
|
|
|
def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict:
|
|
"""Get the history/status of a prompt from ComfyUI."""
|
|
req = urllib.request.Request(
|
|
f"http://{server_address}/history/{prompt_id}",
|
|
method="GET",
|
|
)
|
|
try:
|
|
with urllib.request.urlopen(req) as response:
|
|
return json.loads(response.read())
|
|
except urllib.error.HTTPError as e:
|
|
error_body = e.read().decode("utf-8")
|
|
print(f"HTTP Error {e.code}: {error_body}")
|
|
raise
|
|
|
|
|
|
def download_image(
|
|
filename: str,
|
|
subfolder: str,
|
|
folder_type: str,
|
|
server_address: str = "127.0.0.1:8188",
|
|
) -> bytes:
|
|
"""Download an image from ComfyUI."""
|
|
params = {"filename": filename, "type": folder_type}
|
|
if subfolder:
|
|
params["subfolder"] = subfolder
|
|
|
|
url = f"http://{server_address}/view?{urlencode(params)}"
|
|
|
|
req = urllib.request.Request(url, method="GET")
|
|
with urllib.request.urlopen(req) as response:
|
|
return response.read()
|
|
|
|
|
|
def wait_for_prompt_completion(
    prompt_id: str,
    server_address: str = "127.0.0.1:8188",
    timeout: int = 240,
    poll_interval: float = 0.5,
) -> dict | None:
    """Poll the prompt history until *prompt_id* reports non-empty outputs.

    Args:
        prompt_id: Id returned when the prompt was queued.
        server_address: ``host:port`` of the ComfyUI server.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The ``outputs`` mapping of the prompt's history entry once it is
        non-empty, or ``None`` if *timeout* elapses first.

    Raises:
        Whatever ``get_history`` raises; errors are not retried.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        history = get_history(prompt_id, server_address)
        outputs = history.get(prompt_id, {}).get("outputs")
        if outputs:
            return outputs
        # Clamp the sleep so we never wait past the deadline.
        remaining = deadline - time.monotonic()
        time.sleep(max(0.0, min(poll_interval, remaining)))
    return None
|
|
|
|
|
|
def _prepare_workflow(subject: str, input_image: str) -> tuple[dict, str]:
    """Load the mask-extraction workflow template and inject this run's inputs.

    Returns:
        ``(workflow, prompt_text)``: the populated workflow dict and the text
        prompt written into it.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    workflow_path = os.path.join(script_dir, "image_mask_extraction.json")
    # JSON is UTF-8; don't depend on the platform's default encoding.
    with open(workflow_path, "r", encoding="utf-8") as f:
        workflow = json.load(f)

    prompt_text = f"Create a black and white alpha mask of {subject}, leaving everything else black"

    print("Encoding input image...")
    base64_image = encode_image_base64(input_image)

    # Node ids are tied to image_mask_extraction.json.
    workflow["1:68"]["inputs"]["prompt"] = prompt_text
    workflow["87"]["inputs"]["image"] = base64_image
    workflow["50"]["inputs"]["seed"] = random.randint(1, 100000000)

    # Random per-run prefix for the saved mask file name.
    unique_id = str(uuid.uuid4())[:8]
    workflow["82"]["inputs"]["filename_prefix"] = f"masks/mask_{unique_id}"
    return workflow, prompt_text


def _select_output_image(outputs: dict, preferred_node: str = "82") -> dict | None:
    """Return the first image descriptor in *outputs*, checking *preferred_node* first.

    Remaining nodes are checked in their original order. Returns None when
    no node reported any image.
    """
    ordered = []
    if preferred_node in outputs:
        ordered.append(outputs[preferred_node])
    ordered.extend(out for node_id, out in outputs.items() if node_id != preferred_node)
    for node_output in ordered:
        images = node_output.get("images") or []
        if images:
            return images[0]
    return None


def extract_mask(
    subject: str,
    input_image: str,
    output_path: str,
    server_address: str = "127.0.0.1:8188",
) -> str:
    """Extract mask from image for given subject.

    Args:
        subject: The subject to extract mask for (e.g., "the stump", "the door")
        input_image: Path to the input image file
        output_path: Path where the output mask should be saved
        server_address: ComfyUI server address

    Returns:
        Path to the saved output mask

    Raises:
        RuntimeError: If the workflow does not finish within the timeout or
            produces no image. Request errors from the server helpers and
            file errors propagate unchanged.
    """
    workflow, prompt_text = _prepare_workflow(subject, input_image)

    print(f"Queuing mask extraction for: {subject}")
    print(f"Input image: {input_image}")
    print(f"Prompt: {prompt_text}")

    response = queue_prompt(workflow, server_address)
    prompt_id = response["prompt_id"]
    print(f"Prompt ID: {prompt_id}")

    timeout_s = 240
    minutes = timeout_s // 60
    print(f"Waiting for generation (up to {minutes} minutes)...")
    outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=timeout_s)
    if not outputs:
        raise RuntimeError(f"Timeout: Workflow did not complete in {minutes} minutes")

    # Node "82" is the node given the filename_prefix above, so prefer its
    # image over any other node's (e.g. a preview) when several report images.
    image_info = _select_output_image(outputs, preferred_node="82")
    output_filename = image_info.get("filename") if image_info else None
    if not output_filename:
        raise RuntimeError("No output image found in workflow results")

    print(f"Downloading generated mask: {output_filename}")
    image_data = download_image(
        output_filename,
        image_info.get("subfolder", ""),
        image_info.get("type", "output"),
        server_address,
    )

    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(image_data)

    print(f"Saved mask: {output_path}")
    return output_path
|
|
|
|
|
|
def main():
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(
|
|
description="Extract mask from image using ComfyUI"
|
|
)
|
|
parser.add_argument(
|
|
"subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')"
|
|
)
|
|
parser.add_argument("input_image", help="Path to input image file")
|
|
parser.add_argument("output_path", help="Path where output mask should be saved")
|
|
parser.add_argument(
|
|
"--server",
|
|
default="127.0.0.1:8188",
|
|
help="ComfyUI server address (default: 127.0.0.1:8188)",
|
|
)
|
|
parser.add_argument(
|
|
"--dry-run",
|
|
action="store_true",
|
|
help="Test mode: validate inputs and server connection without generating",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
if not os.path.exists(args.input_image):
|
|
print(f"Error: Input image not found: {args.input_image}")
|
|
sys.exit(1)
|
|
|
|
print(f"Subject: {args.subject}")
|
|
print(f"Input: {args.input_image}")
|
|
print(f"Output: {args.output_path}")
|
|
print(f"Server: {args.server}")
|
|
|
|
if args.dry_run:
|
|
print("\n[Dry Run Mode - Checking server connection...]")
|
|
if check_server(args.server):
|
|
print("✓ ComfyUI server is running and accessible")
|
|
print("\n✓ Dry run successful! All checks passed.")
|
|
sys.exit(0)
|
|
else:
|
|
print(f"✗ ComfyUI server is not accessible at {args.server}")
|
|
print(" Please ensure ComfyUI is running before extracting masks.")
|
|
sys.exit(1)
|
|
|
|
print("\nChecking ComfyUI server...")
|
|
if not check_server(args.server):
|
|
print(f"Error: ComfyUI server is not running at {args.server}")
|
|
print("Please start ComfyUI first or check the server address.")
|
|
print(f"\nTo test without generating, use: --dry-run")
|
|
sys.exit(1)
|
|
|
|
print("✓ ComfyUI server is running")
|
|
|
|
try:
|
|
output = extract_mask(
|
|
args.subject, args.input_image, args.output_path, args.server
|
|
)
|
|
print(f"\nMask extraction complete! Output: {output}")
|
|
except Exception as e:
|
|
print(f"Error: {e}")
|
|
sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|