import torch
from PIL import Image
from torchvision import transforms
import glob
import os
from pathlib import Path

from src.utils.video_utils import video_to_image_frames

IMAGE_EXTS = ['*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff', '*.webp']
VIDEO_EXTS = ['.mp4', '.avi', '.mov', '.webm', '.gif']

def load_and_preprocess_images(image_file_paths, preprocessing_mode="crop", output_size=518):
    """
    Transform raw image files into model-ready tensor batches with standardized dimensions.

    This utility handles the complete pipeline from file paths to batched tensors,
    ensuring compatibility with neural network input requirements while preserving
    image quality.

    Args:
        image_file_paths (list): Collection of file system paths pointing to image files
        preprocessing_mode (str, optional): Image transformation strategy:
            - "crop" (default): resize width to output_size, center-crop height if oversized
            - "pad": scale the largest dimension to output_size, pad the smaller one to square
        output_size (int, optional): Target dimension for model input (default: 518)

    Returns:
        torch.Tensor: Processed image batch with shape (1, N, 3, H, W) ready for model inference

    Raises:
        ValueError: When input validation fails (empty list or invalid mode)

    Implementation details:
        - Alpha channel handling: RGBA images are composited onto white backgrounds
        - Dimension normalization: all outputs are divisible by 14 for patch-based processing
        - Batch consistency: different-sized images are padded to uniform dimensions
        - Quality preservation: bicubic resampling maintains visual fidelity
    """
    # Input validation and parameter setup
    if len(image_file_paths) == 0:
        raise ValueError("At least 1 image is required")
    if preprocessing_mode not in ["crop", "pad"]:
        raise ValueError("preprocessing_mode must be either 'crop' or 'pad'")

    processed_image_list = []
    image_dimension_set = set()
    tensor_converter = transforms.ToTensor()
    model_target_size = output_size

    # Individual image processing pipeline
    for image_file_path in image_file_paths:
        # File system to memory conversion
        loaded_image = Image.open(image_file_path)

        # Transparency handling for RGBA images
        if loaded_image.mode == "RGBA":
            # Generate white canvas matching image dimensions
            white_background = Image.new("RGBA", loaded_image.size, (255, 255, 255, 255))
            # Blend transparent pixels with white background
            loaded_image = Image.alpha_composite(white_background, loaded_image)

        # Format standardization to RGB
        loaded_image = loaded_image.convert("RGB")
        original_width, original_height = loaded_image.size

        # Dimension calculation based on preprocessing strategy
        if preprocessing_mode == "pad":
            # Proportional scaling to fit largest dimension within target
            if original_width >= original_height:
                scaled_width = model_target_size
                scaled_height = round(original_height * (scaled_width / original_width) / 14) * 14  # Patch compatibility
            else:
                scaled_height = model_target_size
                scaled_width = round(original_width * (scaled_height / original_height) / 14) * 14  # Patch compatibility
        else:  # preprocessing_mode == "crop"
            # Width normalization with proportional height adjustment
            scaled_width = model_target_size
            scaled_height = round(original_height * (scaled_width / original_width) / 14) * 14
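
        # Worked example (hypothetical 1080x1920 portrait input in "crop" mode):
        # scaled_width = 518, scaled_height = round(1920 * (518 / 1080) / 14) * 14 = 924,
        # so the center crop below trims 924 - 518 = 406 rows (203 from each side).
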
        # High-quality image resizing
        loaded_image = loaded_image.resize((scaled_width, scaled_height), Image.Resampling.BICUBIC)
        image_tensor = tensor_converter(loaded_image)  # Normalize to [0, 1] range

        # Height trimming for crop mode (center-based)
        if preprocessing_mode == "crop" and scaled_height > model_target_size:
            crop_start_y = (scaled_height - model_target_size) // 2
            image_tensor = image_tensor[:, crop_start_y : crop_start_y + model_target_size, :]

        # Square padding for pad mode (centered)
        if preprocessing_mode == "pad":
            height_padding_needed = model_target_size - image_tensor.shape[1]
            width_padding_needed = model_target_size - image_tensor.shape[2]
            if height_padding_needed > 0 or width_padding_needed > 0:
                padding_top = height_padding_needed // 2
                padding_bottom = height_padding_needed - padding_top
                padding_left = width_padding_needed // 2
                padding_right = width_padding_needed - padding_left
                # White padding application (value=1.0 for normalized images)
                image_tensor = torch.nn.functional.pad(
                    image_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=1.0
                )
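
        # Worked example (hypothetical 1920x1080 landscape input in "pad" mode): width is the
        # larger side, so scaled_width = 518 and scaled_height = round(1080 * 518/1920 / 14) * 14 = 294.
        # The height then needs 518 - 294 = 224 rows of white padding: 112 on top, 112 on the bottom.
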
        image_dimension_set.add((image_tensor.shape[1], image_tensor.shape[2]))
        processed_image_list.append(image_tensor)

    # Cross-image dimension harmonization
    if len(image_dimension_set) > 1:
        print(f"Warning: Found images with different shapes: {image_dimension_set}")

        # Calculate maximum dimensions across the batch
        maximum_height = max(dimension[0] for dimension in image_dimension_set)
        maximum_width = max(dimension[1] for dimension in image_dimension_set)

        # Uniform padding to achieve batch consistency
        uniformly_sized_images = []
        for image_tensor in processed_image_list:
            height_padding_needed = maximum_height - image_tensor.shape[1]
            width_padding_needed = maximum_width - image_tensor.shape[2]
            if height_padding_needed > 0 or width_padding_needed > 0:
                padding_top = height_padding_needed // 2
                padding_bottom = height_padding_needed - padding_top
                padding_left = width_padding_needed // 2
                padding_right = width_padding_needed - padding_left
                image_tensor = torch.nn.functional.pad(
                    image_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode="constant", value=1.0
                )
            uniformly_sized_images.append(image_tensor)
        processed_image_list = uniformly_sized_images

    # Batch tensor construction: torch.stack creates a new leading dimension, yielding
    # (N, 3, H, W) even for a single image, so one unsqueeze gives (1, N, 3, H, W)
    batched_images = torch.stack(processed_image_list)
    return batched_images.unsqueeze(0)
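
# Example usage (hypothetical file names; shapes assume the default output_size=518):
#   batch = load_and_preprocess_images(["photo_1.jpg", "photo_2.jpg"], preprocessing_mode="pad")
#   print(batch.shape)  # torch.Size([1, 2, 3, 518, 518]) -- "pad" mode always yields squares
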
def _handle_alpha_channel(img_data):
    """Process RGBA images by blending with white background."""
    if img_data.mode == "RGBA":
        white_bg = Image.new("RGBA", img_data.size, (255, 255, 255, 255))
        img_data = Image.alpha_composite(white_bg, img_data)
    return img_data.convert("RGB")
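
# For instance, a half-transparent red pixel (255, 0, 0, 128) composited onto the white
# canvas above comes out as roughly (255, 127, 127) after the RGB conversion.
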
def _calculate_resize_dims(orig_w, orig_h, max_dim, resize_strategy, patch_size=14):
    """Calculate new dimensions based on resize strategy, rounded to the patch size."""
    if resize_strategy == "pad":
        if orig_w >= orig_h:
            new_w = max_dim
            new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
        else:
            new_h = max_dim
            new_w = round(orig_w * (new_h / orig_h) / patch_size) * patch_size
    else:  # crop strategy
        new_w = max_dim
        new_h = round(orig_h * (new_w / orig_w) / patch_size) * patch_size
    return new_w, new_h
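
# Sanity check (hypothetical inputs, following the rounding formula above):
#   _calculate_resize_dims(1920, 1080, 518, "crop") -> (518, 294)
#   _calculate_resize_dims(1080, 1920, 518, "pad")  -> (294, 518)
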
def _apply_padding(tensor_img, target_h, target_w=None):
    """Pad a (C, H, W) tensor symmetrically up to (target_h, target_w).

    target_w defaults to target_h, which preserves the original square-padding
    behavior for callers that pass a single dimension.
    """
    if target_w is None:
        target_w = target_h
    h_pad = target_h - tensor_img.shape[1]
    w_pad = target_w - tensor_img.shape[2]
    if h_pad > 0 or w_pad > 0:
        pad_top, pad_bottom = h_pad // 2, h_pad - h_pad // 2
        pad_left, pad_right = w_pad // 2, w_pad - w_pad // 2
        # value=1.0 pads with white for images normalized to [0, 1]
        return torch.nn.functional.pad(
            tensor_img, (pad_left, pad_right, pad_top, pad_bottom),
            mode="constant", value=1.0
        )
    return tensor_img
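
# For example, a (3, 294, 518) tensor padded with target_h=518 becomes (3, 518, 518):
# the 224 missing rows are split 112 on top and 112 on the bottom (white, value=1.0).
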
def prepare_images_to_tensor(file_paths, resize_strategy="crop", target_size=518):
    """
    Process image files into a uniform tensor batch for model input.

    Args:
        file_paths (list): Paths to image files
        resize_strategy (str): "crop" or "pad" processing mode
        target_size (int): Target size for processing

    Returns:
        torch.Tensor: Processed image batch (1, N, 3, H, W)
    """
    if not file_paths:
        raise ValueError("At least 1 image is required")
    if resize_strategy not in ["crop", "pad"]:
        raise ValueError("Strategy must be 'crop' or 'pad'")

    tensor_list = []
    dimension_set = set()
    converter = transforms.ToTensor()

    # Process each image file
    for file_path in file_paths:
        img_data = Image.open(file_path)
        img_data = _handle_alpha_channel(img_data)
        orig_w, orig_h = img_data.size
        new_w, new_h = _calculate_resize_dims(orig_w, orig_h, target_size, resize_strategy)

        # Resize and convert to tensor
        img_data = img_data.resize((new_w, new_h), Image.Resampling.BICUBIC)
        tensor_img = converter(img_data)

        # Apply center crop for crop strategy
        if resize_strategy == "crop" and new_h > target_size:
            crop_start = (new_h - target_size) // 2
            tensor_img = tensor_img[:, crop_start:crop_start + target_size, :]

        # Apply padding for pad strategy
        if resize_strategy == "pad":
            tensor_img = _apply_padding(tensor_img, target_size)

        dimension_set.add((tensor_img.shape[1], tensor_img.shape[2]))
        tensor_list.append(tensor_img)

    # Handle mixed dimensions by padding every image to the batch maximum.
    # Padding to (max_h, max_w) rather than a square keeps all shapes identical,
    # which torch.stack requires; _apply_padding is a no-op for already-full-size images.
    if len(dimension_set) > 1:
        print(f"Warning: Mixed image dimensions found: {dimension_set}")
        max_h = max(dims[0] for dims in dimension_set)
        max_w = max(dims[1] for dims in dimension_set)
        tensor_list = [_apply_padding(img, max_h, max_w) for img in tensor_list]

    # torch.stack yields (N, 3, H, W); add a leading batch dimension -> (1, N, 3, H, W)
    batch_tensor = torch.stack(tensor_list)
    return batch_tensor.unsqueeze(0)
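
# Example usage (hypothetical paths; mirrors load_and_preprocess_images above):
#   tensor_batch = prepare_images_to_tensor(["a.png", "b.png"], resize_strategy="crop")
#   tensor_batch.shape  # (1, 2, 3, H, 518) with H a multiple of 14, at most 518
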
def extract_load_and_preprocess_images(image_folder_or_video_path, fps=1, target_size=518, mode="crop"):
    """Load frames from a video file (sampled at `fps`) or all supported images
    from a folder, then preprocess them into a single batched tensor."""
    image_folder_or_video_path = Path(image_folder_or_video_path)  # accept str or Path
    if image_folder_or_video_path.is_file() and image_folder_or_video_path.suffix.lower() in VIDEO_EXTS:
        # Video input: extract frames at the requested sampling rate
        frame_paths = video_to_image_frames(str(image_folder_or_video_path), fps=fps)
        img_paths = sorted(frame_paths)
    else:
        # Folder input: collect every supported image format
        img_paths = []
        for ext in IMAGE_EXTS:
            img_paths.extend(glob.glob(os.path.join(str(image_folder_or_video_path), ext)))
        img_paths = sorted(img_paths)
    images = prepare_images_to_tensor(img_paths, resize_strategy=mode, target_size=target_size)
    return images
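
# Example usage (hypothetical paths; accepts either a video file or an image folder):
#   frames = extract_load_and_preprocess_images("clips/demo.mp4", fps=2)    # 2 frames/sec
#   photos = extract_load_and_preprocess_images("data/photos", mode="pad")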