import torch


def normalize_poses(extrinsics, padding=0.1, return_stats=False):
    """
    Normalize camera positions to the unit cube [0, 1]^3, processing each
    batch element independently.

    Args:
        extrinsics: Camera extrinsic matrices with shape (B, S, 3, 4)
        padding: Fraction of the [0, 1] range kept as a margin on each side,
            so normalized positions land in [padding, 1 - padding]
        return_stats: Whether to also return normalization statistics

    Returns:
        normalized_extrinsics: Extrinsic matrices with normalized translations
        (optional) stats: Dictionary containing per-batch scale and translation
    """
    B, S, _, _ = extrinsics.shape
    device = extrinsics.device
    
    # Work on a copy so the caller's tensor is not modified in place
    extrinsics = extrinsics.clone()

    # Record which camera translations are finite BEFORE sanitizing, so the
    # per-batch filtering below still knows which entries were originally invalid
    finite_pos_mask = torch.isfinite(extrinsics[:, :, :3, 3]).all(dim=-1)  # (B, S)

    # Replace NaN/Inf values so downstream arithmetic stays finite
    for i in range(B):
        if torch.isnan(extrinsics[i]).any() or torch.isinf(extrinsics[i]).any():
            print(f"Warning: batch {i} has NaN/Inf in extrinsics")
            extrinsics[i] = torch.nan_to_num(
                extrinsics[i], nan=0.0, posinf=1e6, neginf=-1e6
            )
    
    normalized_extrinsics = extrinsics.clone()
    
    # Store normalization parameters if needed
    if return_stats:
        stats = {
            'scale_factors': torch.zeros(B, device=device),
            'translation_vectors': torch.zeros(B, 3, device=device)
        }
    
    for b in range(B):
        # Extract camera positions for this batch
        positions = extrinsics[b, :, :3, 3]  # (S, 3)
        
        # Mask of positions that were finite in the original input
        valid_mask = finite_pos_mask[b]  # (S,)
        
        if valid_mask.sum() == 0:
            # No valid positions, use default values
            print(f"Warning: Batch {b} has no valid camera positions")
            normalized_extrinsics[b, :, :3, 3] = 0.5  # Place at center
            if return_stats:
                stats['scale_factors'][b] = 1.0
                stats['translation_vectors'][b] = 0.0
            continue
        
        valid_positions = positions[valid_mask]
        
        # Calculate bounds using percentiles for robustness
        if valid_positions.shape[0] > 10:
            # Use 5% and 95% percentiles instead of min/max
            min_pos = torch.quantile(valid_positions, 0.05, dim=0)
            max_pos = torch.quantile(valid_positions, 0.95, dim=0)
        else:
            # Too few samples, use min/max
            min_pos = torch.min(valid_positions, dim=0)[0]
            max_pos = torch.max(valid_positions, dim=0)[0]
        
        # Calculate scale factor considering all dimensions
        pos_range = max_pos - min_pos
        
        # Add small epsilon to prevent dimension collapse
        eps = torch.maximum(
            torch.tensor(1e-6, device=device),
            torch.abs(max_pos) * 1e-6
        )
        pos_range = torch.maximum(pos_range, eps)
        
        # Use maximum range as scale factor for uniform scaling
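        # (a single isotropic factor preserves the scene's aspect ratio;
        # per-axis scaling would distort relative camera geometry)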
        scale_factor = torch.max(pos_range)
        scale_factor = torch.clamp(scale_factor, min=1e-6, max=1e6)
        
        # Calculate center point for centering
        center = (min_pos + max_pos) / 2.0
        
        # Normalize: center first, then scale with padding
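        # Dividing by scale_factor / (1 - 2*padding) maps the percentile span
        # to a width of (1 - 2*padding); adding 0.5 re-centers it, so valid
        # positions land in roughly [padding, 1 - padding] before clamping.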
        actual_scale = scale_factor / (1 - 2 * padding)
        normalized_positions = (positions - center) / actual_scale + 0.5
        
        # Ensure all values are within valid range
        normalized_positions = torch.clamp(normalized_positions, 0.0, 1.0)
        
        # Positions that were invalid in the input are placed at the cube center
        invalid_mask = ~valid_mask
        if invalid_mask.any():
            normalized_positions[invalid_mask] = 0.5
        
        normalized_extrinsics[b, :, :3, 3] = normalized_positions
        
        if return_stats:
            stats['scale_factors'][b] = actual_scale
            stats['translation_vectors'][b] = center
    
    # Final validation
    assert torch.isfinite(normalized_extrinsics).all(), "Output contains non-finite values"
    
    if return_stats:
        return normalized_extrinsics, stats
    return normalized_extrinsics

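# Illustrative sketch (not part of the original module): the `stats` returned
# by normalize_poses(..., return_stats=True) are sufficient to undo the
# normalization. Note that clamped outliers cannot be recovered exactly.
def denormalize_poses(normalized_extrinsics, stats):
    """Invert normalize_poses using the recorded per-batch scale and center."""
    extrinsics = normalized_extrinsics.clone()
    scale = stats['scale_factors'].view(-1, 1, 1)          # (B, 1, 1)
    center = stats['translation_vectors'].view(-1, 1, 3)   # (B, 1, 3)
    positions = extrinsics[:, :, :3, 3]                    # (B, S, 3)
    # The forward map was (p - center) / scale + 0.5, so invert it here
    extrinsics[:, :, :3, 3] = (positions - 0.5) * scale + center
    return extrinsics
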

def normalize_depth(depth, eps=1e-6, min_percentile=1, max_percentile=99):
    """
    Normalize depth values to [0, 1] range using percentile-based scaling.
    
    Args:
        depth: Input depth tensor with shape (B, S, H, W)
        eps: Small epsilon value to prevent division by zero
        min_percentile: Lower percentile for robust min calculation (default: 1)
        max_percentile: Upper percentile for robust max calculation (default: 99)
    
    Returns:
        normalized_depth: Depth tensor normalized to [0, 1] range with same shape (B, S, H, W)
    """
    B, S, H, W = depth.shape
    depth = depth.flatten(0, 1)  # [B*S, H, W]
    
    # Handle invalid values
    depth = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0)
    
    normalized_list = []
    for i in range(depth.shape[0]):
        depth_img = depth[i]  # [H, W]
        depth_flat = depth_img.flatten()
        
        # Zero depth usually marks missing or invalid measurements; exclude
        # zeros from the statistics whenever any nonzero values exist
        non_zero_mask = depth_flat > 0
        if non_zero_mask.sum() > 0:
            values_to_use = depth_flat[non_zero_mask]
        else:
            values_to_use = depth_flat
        
        # Only calculate percentiles when there are enough values
        if values_to_use.numel() > 100:  # Ensure enough samples for percentile calculation
            # Calculate min and max percentiles
            depth_min = torch.quantile(values_to_use, min_percentile/100.0)
            depth_max = torch.quantile(values_to_use, max_percentile/100.0)
        else:
            # If too few samples, use min/max values
            depth_min = values_to_use.min()
            depth_max = values_to_use.max()
        
        # Handle case where max equals min
        if depth_max == depth_min:
            depth_max = depth_min + 1.0
        
        # Use relative epsilon
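        # (tying eps to the depth range keeps the guard proportionate for
        # both small metric depths and very large scene scales)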
        scale = torch.abs(depth_max - depth_min)
        eps_val = max(eps, scale.item() * eps)
        
        # Perform normalization
        depth_norm_img = (depth_img - depth_min) / (depth_max - depth_min + eps_val)
        
        # Ensure output is within [0,1] range
        depth_norm_img = torch.clamp(depth_norm_img, 0.0, 1.0)
        
        normalized_list.append(depth_norm_img)
    
    # Recombine all normalized images
    depth_norm = torch.stack(normalized_list)
    
    return depth_norm.reshape(B, S, H, W)
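

# Minimal smoke test (illustrative sketch; the shapes and values below are
# arbitrary assumptions, not part of the original module).
if __name__ == "__main__":
    B, S, H, W = 2, 8, 64, 64
    extrinsics = torch.randn(B, S, 3, 4) * 5.0
    extrinsics[0, 0, 0, 3] = float('nan')  # inject one invalid translation
    norm_ext, stats = normalize_poses(extrinsics, return_stats=True)
    print("pose range:", norm_ext[:, :, :3, 3].min().item(),
          norm_ext[:, :, :3, 3].max().item())  # expected within [0, 1]

    depth = torch.rand(B, S, H, W) * 10.0
    norm_depth = normalize_depth(depth)
    print("depth range:", norm_depth.min().item(), norm_depth.max().item())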