|
|
""" |
|
|
Fixed MatAnyone Tensor Utilities |
|
|
Ensures all tensor operations remain in tensor format |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from typing import Tuple, Union |
|
|
|
|
|
|
|
|
def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]: |
|
|
""" |
|
|
FIXED VERSION: Pad tensor to be divisible by d |
|
|
|
|
|
Args: |
|
|
in_tensor: Input tensor (..., H, W) |
|
|
d: Divisor value |
|
|
|
|
|
Returns: |
|
|
padded_tensor: Padded tensor |
|
|
pad_info: Padding information (left, right, top, bottom) |
|
|
""" |
|
|
if not isinstance(in_tensor, torch.Tensor): |
|
|
raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)} - this is the source of F.pad() errors!") |
|
|
|
|
|
|
|
|
h, w = in_tensor.shape[-2:] |
|
|
|
|
|
|
|
|
new_h = ((h + d - 1) // d) * d |
|
|
new_w = ((w + d - 1) // d) * d |
|
|
|
|
|
pad_h = new_h - h |
|
|
pad_w = new_w - w |
|
|
|
|
|
|
|
|
pad_top = pad_h // 2 |
|
|
pad_bottom = pad_h - pad_top |
|
|
pad_left = pad_w // 2 |
|
|
pad_right = pad_w - pad_left |
|
|
|
|
|
|
|
|
pad_array = (pad_left, pad_right, pad_top, pad_bottom) |
|
|
|
|
|
|
|
|
out = F.pad(in_tensor, pad_array, mode='reflect') |
|
|
|
|
|
return out, pad_array |
|
|
|
|
|
|
|
|
def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor: |
|
|
""" |
|
|
Remove padding from tensor |
|
|
|
|
|
Args: |
|
|
padded_tensor: Padded tensor |
|
|
pad_info: Padding information (left, right, top, bottom) |
|
|
|
|
|
Returns: |
|
|
unpadded_tensor: Original size tensor |
|
|
""" |
|
|
if not isinstance(padded_tensor, torch.Tensor): |
|
|
raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}") |
|
|
|
|
|
pad_left, pad_right, pad_top, pad_bottom = pad_info |
|
|
|
|
|
|
|
|
h, w = padded_tensor.shape[-2:] |
|
|
|
|
|
|
|
|
top = pad_top |
|
|
bottom = h - pad_bottom if pad_bottom > 0 else h |
|
|
left = pad_left |
|
|
right = w - pad_right if pad_right > 0 else w |
|
|
|
|
|
|
|
|
unpadded = padded_tensor[..., top:bottom, left:right] |
|
|
|
|
|
return unpadded |
|
|
|
|
|
|
|
|
def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray], device: torch.device = None) -> torch.Tensor: |
|
|
""" |
|
|
Convert input to tensor if needed and move to device |
|
|
|
|
|
Args: |
|
|
input_data: Input data (tensor or numpy array) |
|
|
device: Target device |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Converted tensor |
|
|
""" |
|
|
if isinstance(input_data, np.ndarray): |
|
|
tensor = torch.from_numpy(input_data).float() |
|
|
elif isinstance(input_data, torch.Tensor): |
|
|
tensor = input_data.float() |
|
|
else: |
|
|
raise TypeError(f"Unsupported input type: {type(input_data)}") |
|
|
|
|
|
if device is not None: |
|
|
tensor = tensor.to(device) |
|
|
|
|
|
return tensor |
|
|
|
|
|
|
|
|
def normalize_tensor(tensor: torch.Tensor, target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor: |
|
|
""" |
|
|
Normalize tensor to target range |
|
|
|
|
|
Args: |
|
|
tensor: Input tensor |
|
|
target_range: Target (min, max) range |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Normalized tensor |
|
|
""" |
|
|
if not isinstance(tensor, torch.Tensor): |
|
|
raise TypeError(f"Expected torch.Tensor, got {type(tensor)}") |
|
|
|
|
|
min_val, max_val = target_range |
|
|
|
|
|
|
|
|
tensor_min = tensor.min() |
|
|
tensor_max = tensor.max() |
|
|
|
|
|
if tensor_max > tensor_min: |
|
|
normalized = (tensor - tensor_min) / (tensor_max - tensor_min) |
|
|
else: |
|
|
normalized = tensor - tensor_min |
|
|
|
|
|
|
|
|
scaled = normalized * (max_val - min_val) + min_val |
|
|
|
|
|
return scaled |
|
|
|
|
|
|
|
|
def resize_tensor(tensor: torch.Tensor, |
|
|
size: Tuple[int, int], |
|
|
mode: str = 'bilinear', |
|
|
align_corners: bool = False) -> torch.Tensor: |
|
|
""" |
|
|
Resize tensor while maintaining tensor format |
|
|
|
|
|
Args: |
|
|
tensor: Input tensor (C, H, W) or (B, C, H, W) |
|
|
size: Target (height, width) |
|
|
mode: Interpolation mode |
|
|
align_corners: Align corners flag |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Resized tensor |
|
|
""" |
|
|
if not isinstance(tensor, torch.Tensor): |
|
|
raise TypeError(f"Expected torch.Tensor, got {type(tensor)}") |
|
|
|
|
|
original_dims = tensor.ndim |
|
|
|
|
|
|
|
|
if tensor.ndim == 3: |
|
|
tensor = tensor.unsqueeze(0) |
|
|
|
|
|
|
|
|
resized = F.interpolate(tensor, size=size, mode=mode, align_corners=align_corners) |
|
|
|
|
|
|
|
|
if original_dims == 3: |
|
|
resized = resized.squeeze(0) |
|
|
|
|
|
return resized |
|
|
|
|
|
|
|
|
def safe_tensor_operation(func): |
|
|
""" |
|
|
Decorator to ensure tensor operations receive tensor inputs |
|
|
""" |
|
|
def wrapper(*args, **kwargs): |
|
|
|
|
|
for i, arg in enumerate(args): |
|
|
if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor): |
|
|
raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}") |
|
|
|
|
|
return func(*args, **kwargs) |
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
@safe_tensor_operation |
|
|
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: |
|
|
""" |
|
|
Safely convert tensor to numpy array |
|
|
|
|
|
Args: |
|
|
tensor: Input tensor |
|
|
|
|
|
Returns: |
|
|
np.ndarray: Numpy array |
|
|
""" |
|
|
if tensor.requires_grad: |
|
|
tensor = tensor.detach() |
|
|
|
|
|
if tensor.is_cuda: |
|
|
tensor = tensor.cpu() |
|
|
|
|
|
return tensor.numpy() |
|
|
|
|
|
|
|
|
def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: int = None) -> bool: |
|
|
""" |
|
|
Validate tensor shapes are compatible |
|
|
|
|
|
Args: |
|
|
tensors: Input tensors to validate |
|
|
expected_dims: Expected number of dimensions |
|
|
|
|
|
Returns: |
|
|
bool: True if valid |
|
|
""" |
|
|
if not tensors: |
|
|
return True |
|
|
|
|
|
if expected_dims is not None: |
|
|
for tensor in tensors: |
|
|
if tensor.ndim != expected_dims: |
|
|
raise ValueError(f"Expected {expected_dims}D tensor, got {tensor.ndim}D") |
|
|
|
|
|
|
|
|
reference_shape = tensors[0].shape[-2:] |
|
|
for tensor in tensors[1:]: |
|
|
if tensor.shape[-2:] != reference_shape: |
|
|
raise ValueError(f"Spatial dimensions mismatch: {reference_shape} vs {tensor.shape[-2:]}") |
|
|
|
|
|
return True |