"""
Fixed MatAnyone Tensor Utilities
Ensures all tensor operations remain in tensor format
"""

import functools
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

# Interpolation modes for which F.interpolate accepts an explicit
# ``align_corners`` value; every other mode requires it to be None.
_ALIGN_CORNERS_MODES = frozenset({'linear', 'bilinear', 'bicubic', 'trilinear'})


def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Pad the last two dimensions of a tensor so both are divisible by d.

    The padding is split as evenly as possible between the two sides of each
    axis (the extra pixel, if any, goes to the bottom / right). Reflect
    padding is used when possible; if the required padding is not smaller
    than the corresponding dimension (where reflect padding is undefined),
    replicate padding is used instead.

    Args:
        in_tensor: Input tensor (..., H, W) with any number of leading dims
        d: Divisor value (positive integer)

    Returns:
        padded_tensor: Padded tensor (..., H', W')
        pad_info: Padding information (left, right, top, bottom)

    Raises:
        TypeError: If in_tensor is not a torch.Tensor
        ValueError: If d is not positive or in_tensor has fewer than 2 dims
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)} - this is the source of F.pad() errors!")
    if d <= 0:
        raise ValueError(f"Divisor d must be a positive integer, got {d}")
    if in_tensor.ndim < 2:
        raise ValueError(f"Expected tensor with at least 2 dimensions, got {in_tensor.ndim}D")

    h, w = in_tensor.shape[-2:]

    # Amount needed to reach the next multiple of d (0 if already divisible)
    pad_h = -h % d
    pad_w = -w % d

    # Split padding evenly on both sides; the remainder goes bottom / right
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # PyTorch padding format: (left, right, top, bottom)
    pad_array = (pad_left, pad_right, pad_top, pad_bottom)

    if pad_h == 0 and pad_w == 0:
        # Nothing to pad; return a copy so the result never aliases the input
        return in_tensor.clone(), pad_array

    # Reflect padding requires each pad to be strictly smaller than the
    # dimension it pads. bottom/right are always >= top/left.
    if pad_bottom < h and pad_right < w:
        mode = 'reflect'
    else:
        mode = 'replicate'

    # Non-constant 2D padding only accepts 3D/4D input, so collapse all
    # leading dimensions into a batch axis and restore them afterwards.
    leading_shape = in_tensor.shape[:-2]
    flat = in_tensor.reshape(-1, 1, h, w)
    padded = F.pad(flat, pad_array, mode=mode)
    out = padded.reshape(*leading_shape, h + pad_h, w + pad_w)

    return out, pad_array


def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Remove padding from tensor

    Args:
        padded_tensor: Padded tensor (..., H, W)
        pad_info: Padding information (left, right, top, bottom)

    Returns:
        unpadded_tensor: Tensor cropped back to its pre-padding spatial size

    Raises:
        TypeError: If padded_tensor is not a torch.Tensor
    """
    if not isinstance(padded_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}")

    pad_left, pad_right, pad_top, pad_bottom = pad_info
    h, w = padded_tensor.shape[-2:]

    # Explicit end indices (rather than negative slicing) so that a pad of 0
    # keeps the full extent instead of producing an empty slice.
    return padded_tensor[..., pad_top:h - pad_bottom, pad_left:w - pad_right]


def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray],
                  device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
    """
    Convert input to a float tensor if needed and move to device

    Args:
        input_data: Input data (tensor or numpy array)
        device: Target device; if None the tensor is left where it is

    Returns:
        torch.Tensor: Converted float32 tensor

    Raises:
        TypeError: If input_data is neither a numpy array nor a tensor
    """
    if isinstance(input_data, np.ndarray):
        tensor = torch.from_numpy(input_data).float()
    elif isinstance(input_data, torch.Tensor):
        tensor = input_data.float()
    else:
        raise TypeError(f"Unsupported input type: {type(input_data)}")

    if device is not None:
        tensor = tensor.to(device)

    return tensor


def normalize_tensor(tensor: torch.Tensor,
                     target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Normalize tensor to target range

    The tensor is min-max scaled to [0, 1] and then mapped linearly onto
    target_range. A constant tensor maps entirely to the lower bound.

    Args:
        tensor: Input tensor
        target_range: Target (min, max) range

    Returns:
        torch.Tensor: Normalized tensor

    Raises:
        TypeError: If tensor is not a torch.Tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    min_val, max_val = target_range

    tensor_min = tensor.min()
    tensor_max = tensor.max()

    if tensor_max > tensor_min:
        normalized = (tensor - tensor_min) / (tensor_max - tensor_min)
    else:
        # Constant input: avoid dividing by zero, everything becomes 0
        normalized = tensor - tensor_min

    # Scale [0, 1] to the target range
    return normalized * (max_val - min_val) + min_val


def resize_tensor(tensor: torch.Tensor,
                  size: Tuple[int, int],
                  mode: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """
    Resize tensor while maintaining tensor format

    Args:
        tensor: Input tensor (H, W), (C, H, W) or (B, C, H, W)
        size: Target (height, width)
        mode: Interpolation mode
        align_corners: Align corners flag; only forwarded for the
            linear / bilinear / bicubic / trilinear modes, because
            F.interpolate rejects it for every other mode

    Returns:
        torch.Tensor: Resized tensor with the same number of dims as the input

    Raises:
        TypeError: If tensor is not a torch.Tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    # F.interpolate with a 2-element size needs (B, C, H, W); add the
    # missing leading dimensions for 2D / 3D inputs and drop them afterwards.
    added_dims = 4 - tensor.ndim if tensor.ndim in (2, 3) else 0
    for _ in range(added_dims):
        tensor = tensor.unsqueeze(0)

    # Passing align_corners (even False) with e.g. 'nearest' raises ValueError
    effective_align = align_corners if mode in _ALIGN_CORNERS_MODES else None

    resized = F.interpolate(tensor, size=size, mode=mode, align_corners=effective_align)

    for _ in range(added_dims):
        resized = resized.squeeze(0)

    return resized


def safe_tensor_operation(func):
    """
    Decorator to ensure tensor operations receive tensor inputs

    Any positional or keyword argument that exposes a ``shape`` attribute
    (i.e. looks like an array) but is not a torch.Tensor is rejected with a
    TypeError before the wrapped function runs. Arguments without a ``shape``
    attribute are passed through unchecked.
    """
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor):
                raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}")
        for name, value in kwargs.items():
            if hasattr(value, 'shape') and not isinstance(value, torch.Tensor):
                raise TypeError(f"Argument '{name}' must be torch.Tensor, got {type(value)}")
        return func(*args, **kwargs)
    return wrapper


@safe_tensor_operation
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Safely convert tensor to numpy array

    Args:
        tensor: Input tensor (may require grad and/or live on any device)

    Returns:
        np.ndarray: Numpy array

    Raises:
        TypeError: If tensor is not a torch.Tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    # detach() drops autograd tracking; cpu() is a no-op for CPU tensors and
    # handles every accelerator backend, not just CUDA.
    return tensor.detach().cpu().numpy()


def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: Optional[int] = None) -> bool:
    """
    Validate tensor shapes are compatible

    Args:
        tensors: Input tensors to validate
        expected_dims: Expected number of dimensions (skipped if None)

    Returns:
        bool: True if valid (also True when no tensors are given)

    Raises:
        ValueError: If a tensor has the wrong number of dimensions or the
            last two dimensions differ from those of the first tensor
    """
    if not tensors:
        return True

    if expected_dims is not None:
        for tensor in tensors:
            if tensor.ndim != expected_dims:
                raise ValueError(f"Expected {expected_dims}D tensor, got {tensor.ndim}D")

    # Check spatial dimensions match (last 2 dims)
    reference_shape = tensors[0].shape[-2:]
    for tensor in tensors[1:]:
        if tensor.shape[-2:] != reference_shape:
            raise ValueError(f"Spatial dimensions mismatch: {reference_shape} vs {tensor.shape[-2:]}")

    return True