"""
Fixed MatAnyone Tensor Utilities
Ensures all tensor operations remain in tensor format
"""
import functools
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    FIXED VERSION: Pad a tensor so its spatial dimensions are divisible by d.

    Args:
        in_tensor: Input tensor of shape (..., H, W)
        d: Divisor value

    Returns:
        padded_tensor: Padded tensor
        pad_info: Padding amounts as (left, right, top, bottom)
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(
            f"Expected torch.Tensor, got {type(in_tensor)} - "
            "this is the source of F.pad() errors!"
        )

    # Get spatial dimensions
    h, w = in_tensor.shape[-2:]

    # Round each spatial dimension up to the nearest multiple of d
    new_h = ((h + d - 1) // d) * d
    new_w = ((w + d - 1) // d) * d
    pad_h = new_h - h
    pad_w = new_w - w

    # Split the padding evenly between the two sides of each dimension
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # PyTorch padding format for the last two dims: (left, right, top, bottom)
    pad_array = (pad_left, pad_right, pad_top, pad_bottom)

    # Note: 'reflect' mode requires a 3D or 4D input and padding amounts
    # smaller than the corresponding input dimension.
    out = F.pad(in_tensor, pad_array, mode='reflect')
    return out, pad_array
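
# Illustrative usage (a minimal sketch; the shapes and divisor below are
# arbitrary example values, not anything MatAnyone requires):
#   >>> x = torch.randn(1, 3, 100, 150)
#   >>> padded, pads = pad_divide_by(x, 16)
#   >>> padded.shape
#   torch.Size([1, 3, 112, 160])
#   >>> pads
#   (5, 5, 6, 6)
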
def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Remove padding added by pad_divide_by.

    Args:
        padded_tensor: Padded tensor of shape (..., H, W)
        pad_info: Padding amounts as (left, right, top, bottom)

    Returns:
        unpadded_tensor: Tensor cropped back to its original size
    """
    if not isinstance(padded_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}")

    pad_left, pad_right, pad_top, pad_bottom = pad_info

    # Current spatial dimensions
    h, w = padded_tensor.shape[-2:]

    # Crop boundaries; a zero pad on a side means "keep up to the edge"
    top = pad_top
    bottom = h - pad_bottom if pad_bottom > 0 else h
    left = pad_left
    right = w - pad_right if pad_right > 0 else w

    # Crop the tensor back to its pre-padding size
    unpadded = padded_tensor[..., top:bottom, left:right]
    return unpadded
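
# Illustrative round-trip with pad_divide_by (continuing the sketch above):
#   >>> restored = unpad_tensor(padded, pads)
#   >>> restored.shape == x.shape
#   True
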
def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Convert the input to a float tensor if needed and move it to a device.

    Args:
        input_data: Input data (torch.Tensor or np.ndarray)
        device: Target device (optional)

    Returns:
        torch.Tensor: Converted tensor
    """
    if isinstance(input_data, np.ndarray):
        tensor = torch.from_numpy(input_data).float()
    elif isinstance(input_data, torch.Tensor):
        tensor = input_data.float()
    else:
        raise TypeError(f"Unsupported input type: {type(input_data)}")

    if device is not None:
        tensor = tensor.to(device)
    return tensor
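
# Illustrative usage (numpy input; the dtype and device are example choices):
#   >>> arr = np.zeros((3, 64, 64), dtype=np.uint8)
#   >>> t = ensure_tensor(arr, device=torch.device('cpu'))
#   >>> t.dtype
#   torch.float32
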
def normalize_tensor(tensor: torch.Tensor, target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Normalize a tensor to a target range.

    Args:
        tensor: Input tensor
        target_range: Target (min, max) range

    Returns:
        torch.Tensor: Normalized tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    min_val, max_val = target_range

    # Normalize to [0, 1] first
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    if tensor_max > tensor_min:
        normalized = (tensor - tensor_min) / (tensor_max - tensor_min)
    else:
        # Constant tensor: map everything to zero before scaling
        normalized = tensor - tensor_min

    # Scale to the target range
    scaled = normalized * (max_val - min_val) + min_val
    return scaled
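
# Illustrative usage (the target range is an arbitrary example):
#   >>> t = torch.tensor([0.0, 5.0, 10.0])
#   >>> normalize_tensor(t, target_range=(-1.0, 1.0))
#   tensor([-1.,  0.,  1.])
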
def resize_tensor(tensor: torch.Tensor,
                  size: Tuple[int, int],
                  mode: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """
    Resize a tensor while keeping it in tensor format.

    Args:
        tensor: Input tensor of shape (C, H, W) or (B, C, H, W)
        size: Target (height, width)
        mode: Interpolation mode
        align_corners: Align-corners flag (only used by linear modes)

    Returns:
        torch.Tensor: Resized tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    original_dims = tensor.ndim

    # F.interpolate expects a batch dimension; add one if needed
    if tensor.ndim == 3:
        tensor = tensor.unsqueeze(0)

    # align_corners is only accepted for linear/bilinear/bicubic/trilinear;
    # passing it with e.g. mode='nearest' raises an error, so drop it otherwise
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        resized = F.interpolate(tensor, size=size, mode=mode, align_corners=align_corners)
    else:
        resized = F.interpolate(tensor, size=size, mode=mode)

    # Remove the batch dimension if we added it
    if original_dims == 3:
        resized = resized.squeeze(0)
    return resized
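
# Illustrative usage (shapes are arbitrary examples):
#   >>> img = torch.randn(3, 100, 150)
#   >>> resize_tensor(img, (50, 75)).shape
#   torch.Size([3, 50, 75])
#   >>> resize_tensor(img.unsqueeze(0), (50, 75), mode='nearest').shape
#   torch.Size([1, 3, 50, 75])
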
def safe_tensor_operation(func):
    """
    Decorator that ensures tensor operations receive tensor inputs.

    Raises TypeError for any positional argument that is array-like
    (has a .shape attribute) but is not a torch.Tensor, e.g. np.ndarray.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor):
                raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}")
        return func(*args, **kwargs)
    return wrapper
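
# Illustrative usage: the decorator rejects array-like non-tensor arguments.
# (The function name 'double' is a hypothetical example, not part of this API.)
#   >>> @safe_tensor_operation
#   ... def double(t):
#   ...     return t * 2
#   >>> double(np.ones((2, 2)))   # raises TypeError
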
@safe_tensor_operation
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Safely convert a tensor to a numpy array.

    Args:
        tensor: Input tensor

    Returns:
        np.ndarray: Numpy array
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    if tensor.is_cuda:
        tensor = tensor.cpu()
    return tensor.numpy()
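
# Illustrative usage (detach/CPU handling happens automatically):
#   >>> t = torch.ones(2, 2, requires_grad=True)
#   >>> tensor_to_numpy(t).shape
#   (2, 2)
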
def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: Optional[int] = None) -> bool:
    """
    Validate that tensor shapes are compatible.

    Args:
        tensors: Input tensors to validate
        expected_dims: Expected number of dimensions (optional)

    Returns:
        bool: True if valid (raises ValueError otherwise)
    """
    if not tensors:
        return True

    if expected_dims is not None:
        for tensor in tensors:
            if tensor.ndim != expected_dims:
                raise ValueError(f"Expected {expected_dims}D tensor, got {tensor.ndim}D")

    # Spatial dimensions (the last two dims) must match across all tensors
    reference_shape = tensors[0].shape[-2:]
    for tensor in tensors[1:]:
        if tensor.shape[-2:] != reference_shape:
            raise ValueError(
                f"Spatial dimensions mismatch: {reference_shape} vs {tensor.shape[-2:]}"
            )
    return True
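

if __name__ == "__main__":
    # Minimal smoke test (a sketch; shapes and divisor are arbitrary choices,
    # not values mandated by MatAnyone).
    x = torch.randn(1, 3, 100, 150)
    padded, pads = pad_divide_by(x, 16)
    assert unpad_tensor(padded, pads).shape == x.shape
    validate_tensor_shapes(x, torch.randn(1, 1, 100, 150), expected_dims=4)
    print("tensor_utils smoke test passed")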