File size: 6,465 Bytes
bbb6939
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Fixed MatAnyone Tensor Utilities
Ensures all tensor operations remain in tensor format
"""

import functools
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F


def pad_divide_by(in_tensor: torch.Tensor, d: int) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Pad the last two (spatial) dimensions so both are divisible by d.

    The padding for each axis is split between its two sides; when the amount
    is odd the extra row/column goes to the bottom/right. Reflection padding is
    used whenever F.pad allows it. If a pad amount is not smaller than the
    dimension being padded (e.g. H=3 padded up to 8), edge replication is used
    instead, because reflection cannot produce that many values.

    Args:
        in_tensor: Input tensor (..., H, W) with at least 2 dimensions
        d: Positive divisor

    Returns:
        padded_tensor: Tensor (..., H', W') with H' and W' multiples of d
        pad_info: Padding applied, as (left, right, top, bottom)

    Raises:
        TypeError: If in_tensor is not a torch.Tensor
        ValueError: If d is not positive or in_tensor has fewer than 2 dimensions
    """
    if not isinstance(in_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(in_tensor)} - this is the source of F.pad() errors!")
    if d <= 0:
        raise ValueError(f"Divisor d must be a positive integer, got {d}")
    if in_tensor.ndim < 2:
        raise ValueError(f"Expected a tensor with at least 2 dimensions (..., H, W), got {in_tensor.ndim}D")

    h, w = in_tensor.shape[-2:]

    # Round each spatial size up to the next multiple of d
    new_h = ((h + d - 1) // d) * d
    new_w = ((w + d - 1) // d) * d
    pad_h = new_h - h
    pad_w = new_w - w

    # Split padding between both sides (extra pixel goes bottom/right)
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # F.pad order for the last two dims: (left, right, top, bottom)
    pad_array = (pad_left, pad_right, pad_top, pad_bottom)

    # Non-constant 2-D padding in F.pad only accepts 3-D or 4-D input, so fold
    # all leading dims into one: (..., H, W) -> (N, H, W). Reflect/replicate pad
    # each (H, W) plane independently, so this does not change the values.
    leading_shape = in_tensor.shape[:-2]
    flat = in_tensor.reshape(-1, h, w)

    # Reflection requires every pad amount to be smaller than the padded dim.
    can_reflect = max(pad_top, pad_bottom) < h and max(pad_left, pad_right) < w
    mode = 'reflect' if can_reflect else 'replicate'

    out = F.pad(flat, pad_array, mode=mode)
    out = out.reshape(*leading_shape, new_h, new_w)

    return out, pad_array


def unpad_tensor(padded_tensor: torch.Tensor, pad_info: Tuple[int, int, int, int]) -> torch.Tensor:
    """
    Crop away spatial padding described by pad_info.

    Args:
        padded_tensor: Tensor whose last two dims are (H, W) including padding
        pad_info: Amounts to strip, as (left, right, top, bottom)

    Returns:
        unpadded_tensor: View of padded_tensor with the padding sliced off
    """
    if not isinstance(padded_tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(padded_tensor)}")

    left, right, top, bottom = pad_info
    height, width = padded_tensor.shape[-2:]

    # A zero trailing pad means "keep up to the end" (a stop of -0 would be empty)
    if bottom > 0:
        row_stop = height - bottom
    else:
        row_stop = height
    if right > 0:
        col_stop = width - right
    else:
        col_stop = width

    rows = slice(top, row_stop)
    cols = slice(left, col_stop)
    return padded_tensor[..., rows, cols]


def ensure_tensor(input_data: Union[torch.Tensor, np.ndarray],
                  device: Union[torch.device, str, None] = None) -> torch.Tensor:
    """
    Convert input to a float32 tensor if needed and optionally move it to a device.

    Args:
        input_data: Input data (torch.Tensor or numpy.ndarray)
        device: Target device (torch.device or device string); None leaves the
            tensor on its current device

    Returns:
        torch.Tensor: float32 tensor. A C-contiguous float32 ndarray is wrapped
        without copying, so the tensor shares memory with it.

    Raises:
        TypeError: If input_data is neither a torch.Tensor nor a numpy.ndarray
    """
    if isinstance(input_data, np.ndarray):
        # Cast in numpy first: handles dtypes torch.from_numpy may reject
        # (e.g. uint16 on older PyTorch) and is a no-op for float32 arrays.
        array = input_data.astype(np.float32, copy=False)
        # torch.from_numpy raises on negative strides, which flipped views such
        # as img[..., ::-1] (BGR <-> RGB) have; copy those into a fresh buffer.
        if any(stride < 0 for stride in array.strides):
            array = np.ascontiguousarray(array)
        tensor = torch.from_numpy(array)
    elif isinstance(input_data, torch.Tensor):
        tensor = input_data.float()
    else:
        raise TypeError(f"Unsupported input type: {type(input_data)}")

    if device is not None:
        tensor = tensor.to(device)

    return tensor


def normalize_tensor(tensor: torch.Tensor, target_range: Tuple[float, float] = (0.0, 1.0)) -> torch.Tensor:
    """
    Min-max rescale a tensor so its values span target_range.

    The global minimum maps to target_range[0] and the global maximum to
    target_range[1]. If all elements are equal, every output element is
    target_range[0].

    Args:
        tensor: Input tensor
        target_range: Desired (low, high) output bounds

    Returns:
        torch.Tensor: Rescaled tensor

    Raises:
        TypeError: If tensor is not a torch.Tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    low, high = target_range
    lowest, highest = tensor.min(), tensor.max()

    # Shift so the minimum sits at zero; only divide by the spread when it is
    # non-zero, so a constant tensor does not turn into 0/0 NaNs.
    shifted = tensor - lowest
    unit = shifted / (highest - lowest) if highest > lowest else shifted

    return unit * (high - low) + low


def resize_tensor(tensor: torch.Tensor,
                  size: Tuple[int, int],
                  mode: str = 'bilinear',
                  align_corners: bool = False) -> torch.Tensor:
    """
    Resize the spatial dims of a tensor while keeping it a tensor.

    Args:
        tensor: Input tensor (H, W), (C, H, W) or (B, C, H, W)
        size: Target (height, width)
        mode: Interpolation mode passed to F.interpolate
        align_corners: Align corners flag; only forwarded for modes that accept
            it (linear, bilinear, bicubic, trilinear) and ignored otherwise

    Returns:
        torch.Tensor: Resized tensor with the same number of dims as the input

    Raises:
        TypeError: If tensor is not a torch.Tensor
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")

    original_dims = tensor.ndim

    # F.interpolate resizes the trailing dims of a (B, C, ...) tensor, so add
    # the missing batch/channel dims for lower-rank inputs.
    if original_dims == 2:
        tensor = tensor.unsqueeze(0).unsqueeze(0)
    elif original_dims == 3:
        tensor = tensor.unsqueeze(0)

    # F.interpolate raises ValueError if align_corners is passed at all (even
    # False) with non-interpolating modes such as 'nearest' or 'area'.
    interp_kwargs = {}
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        interp_kwargs['align_corners'] = align_corners

    resized = F.interpolate(tensor, size=size, mode=mode, **interp_kwargs)

    # Remove the dims that were added above
    if original_dims == 2:
        resized = resized.squeeze(0).squeeze(0)
    elif original_dims == 3:
        resized = resized.squeeze(0)

    return resized


def safe_tensor_operation(func):
    """
    Decorator that rejects array-like, non-tensor arguments before calling func.

    Any positional or keyword argument that has a ``shape`` attribute but is
    not a torch.Tensor (e.g. a numpy array) raises TypeError. Arguments
    without ``shape`` (ints, strings, None, ...) pass through unchecked.
    The wrapper keeps func's name and docstring via functools.wraps.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Check positional arguments
        for i, arg in enumerate(args):
            if hasattr(arg, 'shape') and not isinstance(arg, torch.Tensor):
                raise TypeError(f"Argument {i} must be torch.Tensor, got {type(arg)}")

        # Keyword arguments must obey the same rule
        for name, value in kwargs.items():
            if hasattr(value, 'shape') and not isinstance(value, torch.Tensor):
                raise TypeError(f"Argument '{name}' must be torch.Tensor, got {type(value)}")

        return func(*args, **kwargs)

    return wrapper


@safe_tensor_operation
def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """
    Convert a tensor to a numpy array.

    The tensor is detached from autograd and brought to host memory from any
    device (CUDA, MPS, ...). A CPU tensor whose dtype needs no change is not
    copied, so the returned array shares its memory. bfloat16 has no numpy
    equivalent and is upcast to float32 first.

    Args:
        tensor: Input tensor

    Returns:
        np.ndarray: Numpy array
    """
    # .detach() and .cpu() are cheap no-ops when not needed, so apply them
    # unconditionally; checking only .is_cuda missed other accelerators.
    tensor = tensor.detach().cpu()

    # Tensor.numpy() cannot represent bfloat16
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()

    return tensor.numpy()


def validate_tensor_shapes(*tensors: torch.Tensor, expected_dims: int = None) -> bool:
    """
    Check that tensors agree on rank (optionally) and on their last two dims.

    Args:
        tensors: Tensors to compare; the first one is the spatial reference
        expected_dims: If given, every tensor must have exactly this many dims

    Returns:
        bool: Always True when validation passes (including no tensors)

    Raises:
        ValueError: On a rank mismatch or differing trailing (H, W) shapes
    """
    if len(tensors) == 0:
        return True

    # Rank check runs over all tensors before any spatial comparison
    if expected_dims is not None:
        offender = next((t for t in tensors if t.ndim != expected_dims), None)
        if offender is not None:
            raise ValueError(f"Expected {expected_dims}D tensor, got {offender.ndim}D")

    ref_hw = tensors[0].shape[-2:]
    for other in tensors[1:]:
        other_hw = other.shape[-2:]
        if other_hw != ref_hw:
            raise ValueError(f"Spatial dimensions mismatch: {ref_hw} vs {other_hw}")

    return True