diff --git a/compass_image_loader.py b/compass_image_loader.py index d752fca..54788ac 100644 --- a/compass_image_loader.py +++ b/compass_image_loader.py @@ -7,6 +7,7 @@ import numpy as np VALID_DIRECTIONS = {"n", "ne", "e", "se", "s", "sw", "w", "nw"} VALID_MODALITIES = {"image", "depth", "openpose"} +SUPPORTED_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".tiff"} def _discover_directories(): @@ -16,14 +17,66 @@ def _discover_directories(): candidates = set() for root, subdirs, _ in os.walk(base_dir, followlinks=True): - current_path = root[os.path.dirname(root) + 1:] - sub_dirs_lower = {s.lower() for s in subdirs} - if VALID_DIRECTIONS & sub_dirs_lower or VALID_MODALITIES & sub_dirs_lower: - candidates.add(current_path) + rel = os.path.relpath(root, base_dir) + subdirs_lower = {s.lower() for s in subdirs} + if VALID_DIRECTIONS & subdirs_lower or VALID_MODALITIES & subdirs_lower: + if rel == ".": + continue + candidates.add(rel) return sorted(candidates) +def _resolve_target_dir(base_dir, directory, direction): + if directory and (not isinstance(directory, str) or directory.strip()): + path = os.path.join(base_dir, directory) + else: + path = base_dir + + if direction and direction.strip(): + path = os.path.join(path, direction.strip()) + + return path + + +def _list_image_files(target_dir): + try: + files = [ + f for f in sorted(os.listdir(target_dir)) + if os.path.isfile(os.path.join(target_dir, f)) + and os.path.splitext(f)[1].lower() in SUPPORTED_EXTENSIONS + ] + except OSError: + return [] + return files + + +def _resize_image(image, target_w, target_h): + orig_w, orig_h = image.size + + if target_w == 0 and target_h == 0: + return image, orig_w, orig_h + + if target_w > 0 and target_h == 0: + fw = target_w + fh = max(1, int(orig_h * (target_w / orig_w))) + return image.resize((fw, fh), Image.Resampling.LANCZOS), fw, fh + + if target_h > 0 and target_w == 0: + fh = target_h + fw = max(1, int(orig_w * (target_h / orig_h))) + return image.resize((fw, fh), Image.Resampling.LANCZOS), fw, fh + + scale = max(target_w / orig_w, target_h / orig_h) + new_w = max(1, int(orig_w * scale)) + new_h = max(1, int(orig_h * scale)) + + resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS) + left = (new_w - target_w) // 2 + top = (new_h - target_h) // 2 + return resized.crop((left, top, left + target_w, top + target_h)), target_w, target_h + + class CompassImageLoader: CATEGORY = "image/loaders" @@ -47,111 +100,85 @@ class CompassImageLoader: def load_images(self, directory, direction, modality, frame=None, width=0, height=0): base_dir = folder_paths.get_input_directory() + target_dir = _resolve_target_dir(base_dir, directory, direction) + modality_path = os.path.join(target_dir, modality) - if direction and direction.strip(): - target_dir = os.path.join(base_dir, directory, direction, modality) - else: - target_dir = os.path.join(base_dir, directory, modality) - - if not os.path.isdir(target_dir): - raise RuntimeError(f"Compass directory not found: {target_dir}") - - supported_extensions = {"png", "jpg", "jpeg", "webp", "bmp", "gif", "tiff"} - files = [ - f for f in sorted(os.listdir(target_dir)) - if os.path.isfile(os.path.join(target_dir, f)) and f.split(".")[-1].lower() in supported_extensions - ] + if not os.path.isdir(modality_path): + raise RuntimeError(f"Compass directory not found: {modality_path}") + files = _list_image_files(modality_path) if not files: - raise RuntimeError(f"No images found in: {target_dir}") + raise RuntimeError(f"No images found in: {modality_path}") + # Frame selection if frame is None or str(frame).strip() == "": selected_files = files - output_path = target_dir + output_path = modality_path else: try: index = int(str(frame).strip()) except (ValueError, TypeError): - raise RuntimeError(f"Invalid frame number: '{frame}'. Must be an integer.") + raise RuntimeError( + f"Invalid frame number: '{frame}'. Must be an integer." + ) if index < 0 or index >= len(files): raise RuntimeError( - f"Frame index {index} out of bounds. Found {len(files)} images in {target_dir}." + f"Frame index {index} out of bounds. " + f"Found {len(files)} images in {modality_path}." ) selected_files = [files[index]] - output_path = os.path.join(target_dir, files[index]) + output_path = os.path.join(modality_path, files[index]) + # Load and process images tensors = [] final_w, final_h = 0, 0 for filename in selected_files: - filepath = os.path.join(target_dir, filename) + filepath = os.path.join(modality_path, filename) image = Image.open(filepath).convert("RGB") - orig_w, orig_h = image.size + image, final_w, final_h = _resize_image(image, width, height) - if width == 0 and height == 0: - pass - elif width > 0 and height == 0: - fw = width - fh = int(orig_h * (width / orig_w)) - image = image.resize((fw, fh), Image.Resampling.LANCZOS) - elif height > 0 and width == 0: - fh = height - fw = int(orig_w * (height / orig_h)) - image = image.resize((fw, fh), Image.Resampling.LANCZOS) - else: - scale = max(width / orig_w, height / orig_h) - new_w = int(orig_w * scale) - new_h = int(orig_h * scale) - image = image.resize((new_w, new_h), Image.Resampling.LANCZOS) - left = (new_w - width) // 2 - top = (new_h - height) // 2 - right = left + width - bottom = top + height - image = image.crop((left, top, right, bottom)) + np_arr = np.array(image).astype(np.float32) / 255.0 + tensors.append(torch.from_numpy(np_arr)[None,]) - final_w, final_h = image.size[0], image.size[1] - - np_image = np.array(image).astype(np.float32) / 255.0 - tensor = torch.from_numpy(np_image)[None,] - tensors.append(tensor) - - if len(tensors) == 1: - image_batch = tensors[0] - else: - image_batch = torch.cat(tensors, dim=0) + image_batch = ( + tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim=0) + ) return (image_batch, output_path, final_w, final_h, len(selected_files)) @classmethod def IS_CHANGED(cls, directory, direction, modality, frame=None, width=0, height=0): - base_dir = folder_paths.get_input_directory() - - if direction and direction.strip(): - target_dir = os.path.join(base_dir, directory, direction, modality) - else: - target_dir = os.path.join(base_dir, directory, modality) - import hashlib - m = hashlib.sha256() - m.update(f"{directory}:{direction}:{modality}:{frame}:{width}:{height}".encode("utf-8")) - if not os.path.isdir(target_dir): + base_dir = folder_paths.get_input_directory() + target_dir = _resolve_target_dir(base_dir, directory, direction) + modality_path = os.path.join(target_dir, modality) + + if not os.path.isdir(modality_path): return "" - supported_extensions = {"png", "jpg", "jpeg", "webp", "bmp", "gif", "tiff"} - files = [ - f for f in sorted(os.listdir(target_dir)) - if os.path.isfile(os.path.join(target_dir, f)) and f.split(".")[-1].lower() in supported_extensions - ] + files = _list_image_files(modality_path) + m = hashlib.sha256() + m.update(f"{directory}|{direction}|{modality}|{frame}|{width}|{height}".encode()) if frame is None or str(frame).strip() == "": - m.update(":".join(files).encode("utf-8")) + for f in files: + fp = os.path.join(modality_path, f) + try: + st = os.stat(fp) + m.update(f"{f}:{st.st_mtime}:{st.st_size}".encode()) + except OSError: + pass else: try: index = int(str(frame).strip()) - filepath = os.path.join(target_dir, files[index]) - with open(filepath, "rb") as f: - m.update(f.read(65536)) - except (ValueError, IndexErro + fp = os.path.join(modality_path, files[index]) + with open(fp, "rb") as fh: + m.update(fh.read(65536)) + except (ValueError, IndexError): + pass + + return m.hexdigest()