"""
Fixed MatAnyone Tensor Utilities
Ensures all tensor operations remain in tensor format
"""

import functools
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F


def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    FIXED VERSION: Pad a tensor so its spatial dimensions are divisible by d.

    Args:
        in_tensor: Input tensor (..., H, W)
        d: Divisor value

    Returns:
        padded_tensor: Padded tensor
        pad_info: Padding information (left, right, top, bottom)
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(
            f"Expected torch.Tensor, got {type(in_tensor)} - this is the source of F.pad() errors!"
        )

    # Get spatial dimensions
    h, w = in_tensor.shape[-2:]

    # Calculate required padding (round H and W up to the next multiple of d)
    new_h = ((h + d - 1) // d) * d
    new_w = ((w + d - 1) // d) * d
    pad_h = new_h - h
    pad_w = new_w - w

    # Split padding evenly on both sides
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # PyTorch padding format: (left, right, top, bottom)
    pad_array = (pad_left, pad_right, pad_top, pad_bottom)

    # CRITICAL: the isinstance check above guarantees F.pad receives a tensor
    out = F.pad(in_tensor, pad_array, mode='reflect')
    return out, pad_array
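
# Illustrative usage sketch (hypothetical frame size): padding a 270x480 frame
# to multiples of 16.
#
#     frame = torch.rand(1, 3, 270, 480)
#     padded, pads = pad_divide_by(frame, 16)
#     # padded.shape == (1, 3, 272, 480); pads == (0, 0, 1, 1)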


def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Remove padding from a tensor.

    Args:
        padded_tensor: Padded tensor
        pad_info: Padding information (left, right, top, bottom)

    Returns:
        unpadded_tensor: Original size tensor
    """
    if not isinstance(padded_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}")

    pad_left, pad_right, pad_top, pad_bottom = pad_info

    # Get current dimensions
    h, w = padded_tensor.shape[-2:]

    # Calculate crop boundaries
    top = pad_top
    bottom = h - pad_bottom if pad_bottom > 0 else h
    left = pad_left
    right = w - pad_right if pad_right > 0 else w

    # Crop tensor
    return padded_tensor[..., top:bottom, left:right]
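
# Illustrative round-trip sketch (assumed shapes): unpad_tensor inverts
# pad_divide_by exactly when given the returned pad_info.
#
#     padded, pads = pad_divide_by(torch.rand(1, 3, 270, 480), 16)
#     restored = unpad_tensor(padded, pads)
#     # restored.shape == (1, 3, 270, 480)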


def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray], device: torch.device = None) -> torch.Tensor:
    """
    Convert input to tensor if needed and move to device.

    Args:
        input_data: Input data (tensor or numpy array)
        device: Target device

    Returns:
        torch.Tensor: Converted tensor
    """
    if isinstance(input_data, np.ndarray):
        tensor = torch.from_numpy(input_data).float()
    elif isinstance(input_data, torch.Tensor):
        tensor = input_data.float()
    else:
        raise TypeError(f"Unsupported input type: {type(input_data)}")

    if device is not None:
        tensor = tensor.to(device)
    return tensor
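
# Illustrative usage sketch (hypothetical mask array): ensure_tensor accepts
# either a numpy array or a tensor and returns a float tensor on the device.
#
#     mask_np = np.zeros((480, 854), dtype=np.uint8)
#     mask = ensure_tensor(mask_np, device=torch.device("cpu"))
#     # mask.dtype == torch.float32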


def normalize_tensor(tensor: torch.Tensor, target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Normalize tensor to target range.

    Args:
        tensor: Input tensor
        target_range: Target (min, max) range

    Returns:
        torch.Tensor: Normalized tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    min_val, max_val = target_range

    # Normalize to [0, 1] first
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    if tensor_max > tensor_min:
        normalized = (tensor - tensor_min) / (tensor_max - tensor_min)
    else:
        normalized = tensor - tensor_min

    # Scale to target range
    return normalized * (max_val - min_val) + min_val
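
# Illustrative usage sketch: rescaling raw logits into [0, 1] for compositing,
# or into [-1, 1] if a model expects zero-centered inputs (assumed use case).
#
#     logits = torch.randn(1, 1, 64, 64)
#     alpha = normalize_tensor(logits)                  # values in [0, 1]
#     centered = normalize_tensor(logits, (-1.0, 1.0))  # values in [-1, 1]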


def resize_tensor(tensor: torch.Tensor,
                  size: Tuple[int, int],
                  mode: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """
    Resize tensor while maintaining tensor format.

    Args:
        tensor: Input tensor (C, H, W) or (B, C, H, W)
        size: Target (height, width)
        mode: Interpolation mode
        align_corners: Align corners flag (ignored for modes that do not support it)

    Returns:
        torch.Tensor: Resized tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    original_dims = tensor.ndim

    # Add batch dimension if needed
    if tensor.ndim == 3:
        tensor = tensor.unsqueeze(0)

    # align_corners is only valid for interpolating modes; passing it with
    # 'nearest' or 'area' makes F.interpolate raise a ValueError
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        resized = F.interpolate(tensor, size=size, mode=mode, align_corners=align_corners)
    else:
        resized = F.interpolate(tensor, size=size, mode=mode)

    # Remove batch dimension if it was added
    if original_dims == 3:
        resized = resized.squeeze(0)
    return resized
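
# Illustrative usage sketch: resizing a 3D (C, H, W) alpha matte; the batch
# dimension is added and removed internally.
#
#     alpha = torch.rand(1, 480, 854)
#     small = resize_tensor(alpha, (270, 480))                 # bilinear by default
#     hard = resize_tensor(alpha, (270, 480), mode='nearest')  # no align_corners passed
#     # small.shape == hard.shape == (1, 270, 480)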


def safe_tensor_operation(func):
    """
    Decorator to ensure tensor operations receive tensor inputs.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Check that all array-like positional args are tensors
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor):
                raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}")
        return func(*args, **kwargs)
    return wrapper
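
# Illustrative usage sketch (hypothetical function): the decorator rejects
# array-like positional arguments that are not tensors before they reach torch ops.
#
#     @safe_tensor_operation
#     def blend(fg, alpha):
#         return fg * alpha
#
#     blend(torch.rand(3, 4, 4), np.ones((4, 4)))  # raises TypeError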


@safe_tensor_operation
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Safely convert tensor to numpy array.

    Args:
        tensor: Input tensor

    Returns:
        np.ndarray: Numpy array
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    if tensor.is_cuda:
        tensor = tensor.cpu()
    return tensor.numpy()
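
# Illustrative usage sketch: converting a network output that may still carry
# gradients and live on the GPU (assumed situation).
#
#     out = torch.rand(3, 64, 64, requires_grad=True)
#     arr = tensor_to_numpy(out)   # detached, moved to CPU, then converted
#     # arr.shape == (3, 64, 64)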


def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: int = None) -> bool:
    """
    Validate tensor shapes are compatible.

    Args:
        tensors: Input tensors to validate
        expected_dims: Expected number of dimensions

    Returns:
        bool: True if valid
    """
    if not tensors:
        return True

    if expected_dims is not None:
        for tensor in tensors:
            if tensor.ndim != expected_dims:
                raise ValueError(f"Expected {expected_dims}D tensor, got {tensor.ndim}D")

    # Check spatial dimensions match (last 2 dims)
    reference_shape = tensors[0].shape[-2:]
    for tensor in tensors[1:]:
        if tensor.shape[-2:] != reference_shape:
            raise ValueError(f"Spatial dimensions mismatch: {reference_shape} vs {tensor.shape[-2:]}")
    return True
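

# Minimal self-check sketch (illustrative; uses arbitrary example sizes):
# exercises the pad/unpad round trip and shape validation.
if __name__ == "__main__":
    frame = torch.rand(1, 3, 270, 480)
    padded, pads = pad_divide_by(frame, 16)
    restored = unpad_tensor(padded, pads)
    assert restored.shape == frame.shape, "pad/unpad round trip changed the shape"

    mask = torch.rand(1, 1, 270, 480)
    validate_tensor_shapes(frame, mask, expected_dims=4)
    print("tensor utility self-check passed")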