progress
This commit is contained in:
252
tools/extract_mask.py
Executable file
252
tools/extract_mask.py
Executable file
@@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Extract a mask from an image using ComfyUI workflow."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
def check_server(server_address: str = "127.0.0.1:8188", timeout: int = 5) -> bool:
|
||||
"""Check if ComfyUI server is running and accessible."""
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"http://{server_address}/system_stats",
|
||||
method="GET",
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=timeout) as response:
|
||||
return response.status == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def encode_image_base64(image_path: str) -> str:
|
||||
"""Encode an image file as base64 string."""
|
||||
with open(image_path, "rb") as f:
|
||||
return base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
|
||||
def queue_prompt(prompt: dict, server_address: str = "127.0.0.1:8188") -> dict:
|
||||
"""Queue a prompt to ComfyUI server."""
|
||||
client_id = str(uuid.uuid4())
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
f"http://{server_address}/prompt",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
return json.loads(response.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode("utf-8")
|
||||
print(f"HTTP Error {e.code}: {error_body}")
|
||||
raise
|
||||
|
||||
|
||||
def get_history(prompt_id: str, server_address: str = "127.0.0.1:8188") -> dict:
|
||||
"""Get the history/status of a prompt from ComfyUI."""
|
||||
req = urllib.request.Request(
|
||||
f"http://{server_address}/history/{prompt_id}",
|
||||
method="GET",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
return json.loads(response.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode("utf-8")
|
||||
print(f"HTTP Error {e.code}: {error_body}")
|
||||
raise
|
||||
|
||||
|
||||
def download_image(
|
||||
filename: str,
|
||||
subfolder: str,
|
||||
folder_type: str,
|
||||
server_address: str = "127.0.0.1:8188",
|
||||
) -> bytes:
|
||||
"""Download an image from ComfyUI."""
|
||||
params = {"filename": filename, "type": folder_type}
|
||||
if subfolder:
|
||||
params["subfolder"] = subfolder
|
||||
|
||||
url = f"http://{server_address}/view?{urlencode(params)}"
|
||||
|
||||
req = urllib.request.Request(url, method="GET")
|
||||
with urllib.request.urlopen(req) as response:
|
||||
return response.read()
|
||||
|
||||
|
||||
def wait_for_prompt_completion(
|
||||
prompt_id: str, server_address: str = "127.0.0.1:8188", timeout: int = 240
|
||||
) -> dict | None:
|
||||
"""Wait for a prompt to complete and return the output info."""
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
history = get_history(prompt_id, server_address)
|
||||
|
||||
if prompt_id in history:
|
||||
prompt_history = history[prompt_id]
|
||||
if "outputs" in prompt_history and prompt_history["outputs"]:
|
||||
return prompt_history["outputs"]
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_mask(
|
||||
subject: str,
|
||||
input_image: str,
|
||||
output_path: str,
|
||||
server_address: str = "127.0.0.1:8188",
|
||||
) -> str:
|
||||
"""Extract mask from image for given subject.
|
||||
|
||||
Args:
|
||||
subject: The subject to extract mask for (e.g., "the stump", "the door")
|
||||
input_image: Path to the input image file
|
||||
output_path: Path where the output mask should be saved
|
||||
server_address: ComfyUI server address
|
||||
|
||||
Returns:
|
||||
Path to the saved output mask
|
||||
"""
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
workflow_path = os.path.join(script_dir, "image_mask_extraction.json")
|
||||
with open(workflow_path, "r") as f:
|
||||
workflow = json.load(f)
|
||||
|
||||
prompt_text = f"Create a black and white alpha mask of {subject}"
|
||||
|
||||
print(f"Encoding input image...")
|
||||
base64_image = encode_image_base64(input_image)
|
||||
|
||||
workflow["1:68"]["inputs"]["prompt"] = prompt_text
|
||||
workflow["87"]["inputs"]["image"] = base64_image
|
||||
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
filename_prefix = f"masks/mask_{unique_id}"
|
||||
workflow["82"]["inputs"]["filename_prefix"] = filename_prefix
|
||||
|
||||
print(f"Queuing mask extraction for: {subject}")
|
||||
print(f"Input image: {input_image}")
|
||||
print(f"Prompt: {prompt_text}")
|
||||
|
||||
response = queue_prompt(workflow, server_address)
|
||||
prompt_id = response["prompt_id"]
|
||||
print(f"Prompt ID: {prompt_id}")
|
||||
|
||||
print("Waiting for generation (up to 4 minutes)...")
|
||||
outputs = wait_for_prompt_completion(prompt_id, server_address, timeout=240)
|
||||
|
||||
if not outputs:
|
||||
raise RuntimeError("Timeout: Workflow did not complete in 4 minutes")
|
||||
|
||||
output_filename = None
|
||||
output_subfolder = ""
|
||||
output_type = "output"
|
||||
|
||||
for node_id, node_output in outputs.items():
|
||||
if "images" in node_output:
|
||||
for image_info in node_output["images"]:
|
||||
output_filename = image_info["filename"]
|
||||
output_subfolder = image_info.get("subfolder", "")
|
||||
output_type = image_info.get("type", "output")
|
||||
break
|
||||
if output_filename:
|
||||
break
|
||||
|
||||
if not output_filename:
|
||||
raise RuntimeError("No output image found in workflow results")
|
||||
|
||||
print(f"Downloading generated mask: {output_filename}")
|
||||
|
||||
image_data = download_image(
|
||||
output_filename, output_subfolder, output_type, server_address
|
||||
)
|
||||
|
||||
output_dir_path = os.path.dirname(os.path.abspath(output_path))
|
||||
os.makedirs(output_dir_path, exist_ok=True)
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
|
||||
print(f"Saved mask: {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract mask from image using ComfyUI"
|
||||
)
|
||||
parser.add_argument(
|
||||
"subject", help="Subject to extract mask for (e.g., 'the stump', 'the door')"
|
||||
)
|
||||
parser.add_argument("input_image", help="Path to input image file")
|
||||
parser.add_argument("output_path", help="Path where output mask should be saved")
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
default="127.0.0.1:8188",
|
||||
help="ComfyUI server address (default: 127.0.0.1:8188)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Test mode: validate inputs and server connection without generating",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.input_image):
|
||||
print(f"Error: Input image not found: {args.input_image}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Subject: {args.subject}")
|
||||
print(f"Input: {args.input_image}")
|
||||
print(f"Output: {args.output_path}")
|
||||
print(f"Server: {args.server}")
|
||||
|
||||
if args.dry_run:
|
||||
print("\n[Dry Run Mode - Checking server connection...]")
|
||||
if check_server(args.server):
|
||||
print("✓ ComfyUI server is running and accessible")
|
||||
print("\n✓ Dry run successful! All checks passed.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"✗ ComfyUI server is not accessible at {args.server}")
|
||||
print(" Please ensure ComfyUI is running before extracting masks.")
|
||||
sys.exit(1)
|
||||
|
||||
print("\nChecking ComfyUI server...")
|
||||
if not check_server(args.server):
|
||||
print(f"Error: ComfyUI server is not running at {args.server}")
|
||||
print("Please start ComfyUI first or check the server address.")
|
||||
print(f"\nTo test without generating, use: --dry-run")
|
||||
sys.exit(1)
|
||||
|
||||
print("✓ ComfyUI server is running")
|
||||
|
||||
try:
|
||||
output = extract_mask(
|
||||
args.subject, args.input_image, args.output_path, args.server
|
||||
)
|
||||
print(f"\nMask extraction complete! Output: {output}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
403
tools/image_mask_extraction.json
Normal file
403
tools/image_mask_extraction.json
Normal file
File diff suppressed because one or more lines are too long
@@ -257,8 +257,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--min-area",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Minimum contour area to include in multiple mode (default: 100)",
|
||||
default=150,
|
||||
help="Minimum contour area to include (default: 150)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -279,13 +279,14 @@ def main():
|
||||
print("Error: No contours found in mask", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
contours = [c for c in contours if cv2.contourArea(c) >= args.min_area]
|
||||
|
||||
if not contours:
|
||||
print("Error: No contours meet minimum area requirement", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if args.mode == "multiple":
|
||||
contours = sorted(contours, key=cv2.contourArea, reverse=True)
|
||||
contours = [c for c in contours if cv2.contourArea(c) >= args.min_area]
|
||||
|
||||
if not contours:
|
||||
print("Error: No contours meet minimum area requirement", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
output_base = args.output if args.output else args.image.with_suffix("")
|
||||
output_dir = output_base.parent
|
||||
|
||||
Reference in New Issue
Block a user