Files
ComfyUI-compass-paths/compass_image_loader.py

185 lines
6.2 KiB
Python

from PIL import Image
import folder_paths
import os
import torch
import numpy as np
# Lower-case subfolder names recognised as compass directions.
VALID_DIRECTIONS = {"n", "ne", "e", "se", "s", "sw", "w", "nw"}
# Lower-case subfolder names recognised as per-direction modalities.
VALID_MODALITIES = {"image", "depth", "openpose"}
# Lower-case file extensions treated as loadable images. ".tif" is the common
# short spelling of ".tiff" and is accepted alongside it.
SUPPORTED_EXTENSIONS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".gif",
    ".tif",
    ".tiff",
}
def _discover_directories():
    """Return input-relative paths of directories that look like compass roots.

    A directory qualifies when at least one of its immediate subdirectories is
    named (compared case-insensitively) after an entry of ``VALID_DIRECTIONS``
    or ``VALID_MODALITIES``. The input directory itself is never reported.

    Symlinked directories are followed. However, a child that resolves to a
    directory already on the current descent path is pruned, so a symlink
    loop cannot cause infinite recursion.

    NOTE(review): matching here is case-insensitive, but the paths later built
    from the combo options use lower-case names. On a case-sensitive
    filesystem an upper-case "N" folder may be listed but not loadable.

    Returns:
        A sorted list of paths relative to the ComfyUI input directory, or an
        empty list when the input directory does not exist.
    """
    base_dir = folder_paths.get_input_directory()
    if not os.path.exists(base_dir):
        return []

    # Real paths of every directory on the branch leading to each walked
    # root. os.walk(followlinks=True) does not track visited directories.
    ancestry = {base_dir: (os.path.realpath(base_dir),)}
    candidates = set()
    for root, subdirs, _ in os.walk(base_dir, followlinks=True):
        chain = ancestry.pop(root, ())
        subdirs_lower = {s.lower() for s in subdirs}
        rel = os.path.relpath(root, base_dir)
        if rel != "." and (
            VALID_DIRECTIONS & subdirs_lower or VALID_MODALITIES & subdirs_lower
        ):
            candidates.add(rel)

        # Top-down walk: pruning ``subdirs`` in place stops descent into any
        # child that loops back to a directory already on this branch.
        kept = []
        for name in subdirs:
            real = os.path.realpath(os.path.join(root, name))
            if real in chain:
                continue
            kept.append(name)
            ancestry[os.path.join(root, name)] = chain + (real,)
        subdirs[:] = kept
    return sorted(candidates)
def _resolve_target_dir(base_dir, directory, direction):
    """Join *base_dir* with the optional *directory* and *direction* parts.

    ``directory`` is appended only when it is truthy and, if it is a string,
    not blank. It is used verbatim, with no whitespace stripping.
    ``direction`` is appended, after stripping, only when it contains
    non-whitespace text.
    """
    use_directory = bool(directory) and (
        not isinstance(directory, str) or bool(directory.strip())
    )
    resolved = os.path.join(base_dir, directory) if use_directory else base_dir

    cleaned_direction = direction.strip() if direction else ""
    if cleaned_direction:
        resolved = os.path.join(resolved, cleaned_direction)
    return resolved
def _list_image_files(target_dir):
    """Return the names of supported image files directly inside *target_dir*.

    Only entries for which ``os.path.isfile`` is true and whose lower-cased
    extension is in ``SUPPORTED_EXTENSIONS`` are kept. Names are returned in
    plain lexicographic order, which is the order frame indices refer to.
    An unreadable or missing directory yields an empty list.
    """
    try:
        entries = os.listdir(target_dir)
    except OSError:
        return []

    images = []
    for name in sorted(entries):
        if not os.path.isfile(os.path.join(target_dir, name)):
            continue
        if os.path.splitext(name)[1].lower() in SUPPORTED_EXTENSIONS:
            images.append(name)
    return images
def _resize_image(image, target_w, target_h):
    """Resize a PIL image according to the node's width/height inputs.

    * width and height both 0: the image is returned unchanged.
    * only width > 0: scale to that width, keeping the aspect ratio.
    * only height > 0: scale to that height, keeping the aspect ratio.
    * both non-zero: scale so the image covers the target box, then
      center-crop to exactly ``target_w x target_h``.

    Args:
        image: A PIL image.
        target_w: Requested width in pixels; 0 means unconstrained.
        target_h: Requested height in pixels; 0 means unconstrained.

    Returns:
        ``(image, width, height)``: the resulting image and its size.
    """
    orig_w, orig_h = image.size
    if target_w == 0 and target_h == 0:
        return image, orig_w, orig_h

    if target_w > 0 and target_h == 0:
        fw = target_w
        fh = max(1, int(orig_h * (target_w / orig_w)))
        return image.resize((fw, fh), Image.Resampling.LANCZOS), fw, fh

    if target_h > 0 and target_w == 0:
        fh = target_h
        fw = max(1, int(orig_w * (target_h / orig_h)))
        return image.resize((fw, fh), Image.Resampling.LANCZOS), fw, fh

    # Cover-and-crop. Float error can make ``orig * (target / orig)`` fall just
    # below ``target`` (e.g. 98 * (2 / 98) == 1.9999999999999998), and
    # truncation would then make the covering image smaller than the crop
    # box. Pillow would pad the overflow with black, so clamp to the target.
    scale = max(target_w / orig_w, target_h / orig_h)
    new_w = max(1, target_w, int(orig_w * scale))
    new_h = max(1, target_h, int(orig_h * scale))
    resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return resized.crop((left, top, left + target_w, top + target_h)), target_w, target_h
class CompassImageLoader:
    """ComfyUI node that loads images from a compass-style folder layout.

    Images are read from ``<input>/<directory>/<direction>/<modality>/``.
    ``direction`` may be blank, in which case ``<directory>/<modality>`` is
    used. Either every supported image in that folder is returned as a batch,
    or a single image selected by its index in the sorted file list.
    """

    CATEGORY = "image/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare node inputs; ``directory`` choices are discovered on disk."""
        directories = _discover_directories()
        return {
            "required": {
                # Placeholder keeps the combo non-empty when nothing qualifies.
                "directory": (directories if directories else ["(none found)"],),
                "direction": (["", "n", "ne", "e", "se", "s", "sw", "w", "nw"],),
                "modality": (["image", "depth", "openpose"],),
                "frame": ("STRING", {"default": ""}),
                "width": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "height": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE", "STRING", "INT", "INT", "INT")
    RETURN_NAMES = ("IMAGE", "path", "width", "height", "frame_count")
    FUNCTION = "load_images"

    def load_images(self, directory, direction, modality, frame=None, width=0, height=0):
        """Load the selected compass images as a float32 batch tensor.

        Args:
            directory: Path relative to the ComfyUI input directory.
            direction: Compass subfolder name, or blank to skip that level.
            modality: Modality subfolder name (e.g. ``"image"``).
            frame: Blank/``None`` loads every image. Otherwise a 0-based
                integer (given as a string) indexing the sorted file list.
            width: Target width passed to ``_resize_image`` (0 = unconstrained).
            height: Target height passed to ``_resize_image`` (0 = unconstrained).

        Returns:
            ``(images, path, width, height, frame_count)``.
            ``images`` is an ``(N, H, W, 3)`` float32 tensor scaled to
            ``[0, 1]``. ``path`` is the modality folder (all frames) or the
            single file loaded. ``frame_count`` is ``N``.

        Raises:
            RuntimeError: the folder is missing or has no images, ``frame`` is
                not an integer or is out of range, or the loaded images end
                up with different sizes and cannot be batched.
        """
        base_dir = folder_paths.get_input_directory()
        target_dir = _resolve_target_dir(base_dir, directory, direction)
        modality_path = os.path.join(target_dir, modality)
        if not os.path.isdir(modality_path):
            raise RuntimeError(f"Compass directory not found: {modality_path}")
        files = _list_image_files(modality_path)
        if not files:
            raise RuntimeError(f"No images found in: {modality_path}")

        # Frame selection: blank means "all frames"; otherwise a 0-based index.
        if frame is None or str(frame).strip() == "":
            selected_files = files
            output_path = modality_path
        else:
            try:
                index = int(str(frame).strip())
            except (ValueError, TypeError):
                raise RuntimeError(
                    f"Invalid frame number: '{frame}'. Must be an integer."
                ) from None
            if index < 0 or index >= len(files):
                raise RuntimeError(
                    f"Frame index {index} out of bounds. "
                    f"Found {len(files)} images in {modality_path}."
                )
            selected_files = [files[index]]
            output_path = os.path.join(modality_path, files[index])

        # Load and process images.
        tensors = []
        final_w, final_h = 0, 0
        for filename in selected_files:
            filepath = os.path.join(modality_path, filename)
            # The context manager releases the file handle. Pillow keeps it
            # open after loading for multi-frame formats such as GIF/TIFF.
            with Image.open(filepath) as opened:
                image = opened.convert("RGB")
            image, w, h = _resize_image(image, width, height)
            if tensors and (w, h) != (final_w, final_h):
                # Fail early with a clear message instead of inside torch.cat.
                raise RuntimeError(
                    f"Cannot batch images of different sizes: '{filename}' is "
                    f"{w}x{h} but earlier frames are {final_w}x{final_h}. "
                    "Set both width and height to resize them to a common size."
                )
            final_w, final_h = w, h
            np_arr = np.array(image).astype(np.float32) / 255.0
            tensors.append(torch.from_numpy(np_arr)[None,])

        image_batch = tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=0)
        return (image_batch, output_path, final_w, final_h, len(selected_files))

    @classmethod
    def IS_CHANGED(cls, directory, direction, modality, frame=None, width=0, height=0):
        """Return a fingerprint of the inputs and the files they select.

        The digest always covers every input value. In all-frames mode it also
        covers each file's name, mtime and size. In single-frame mode it
        covers the full contents of the selected file. An empty string is
        returned when the modality folder does not exist.
        """
        import hashlib

        base_dir = folder_paths.get_input_directory()
        target_dir = _resolve_target_dir(base_dir, directory, direction)
        modality_path = os.path.join(target_dir, modality)
        if not os.path.isdir(modality_path):
            return ""
        files = _list_image_files(modality_path)

        m = hashlib.sha256()
        m.update(f"{directory}|{direction}|{modality}|{frame}|{width}|{height}".encode())
        if frame is None or str(frame).strip() == "":
            for f in files:
                fp = os.path.join(modality_path, f)
                try:
                    st = os.stat(fp)
                except OSError:
                    continue  # Vanished/unreadable: load_images reports it.
                m.update(f"{f}:{st.st_mtime}:{st.st_size}".encode())
        else:
            try:
                index = int(str(frame).strip())
            except (ValueError, TypeError):
                index = -1
            # Mirror load_images: only non-negative, in-range indices select a file.
            if 0 <= index < len(files):
                fp = os.path.join(modality_path, files[index])
                try:
                    with open(fp, "rb") as fh:
                        # Hash the whole file; a prefix misses later edits.
                        for chunk in iter(lambda: fh.read(65536), b""):
                            m.update(chunk)
                except OSError:
                    pass  # Best effort: load_images surfaces the real error.
        return m.hexdigest()