import torch


def normalize_poses(extrinsics, padding=0.1, return_stats=False):
    """
    Normalize camera positions to the unit cube, processing each batch separately.

    Args:
        extrinsics: Camera extrinsic matrices with shape (B, S, 3, 4)
        padding: Fraction of the [0, 1] range reserved on each side, so the
            estimated position bounds map into [padding, 1 - padding]
        return_stats: Whether to also return normalization statistics

    Returns:
        normalized_extrinsics: Normalized extrinsic matrices with shape (B, S, 3, 4)
        stats (only when return_stats=True): Dictionary with per-batch scale
            factors and translation (center) vectors
    """
    B, S, _, _ = extrinsics.shape
    device = extrinsics.device

    # Work on a copy so the caller's tensor is never modified in place
    extrinsics = extrinsics.clone()

    # Check input validity and handle NaN/Inf values
    for i in range(B):
        if torch.isnan(extrinsics[i]).any() or torch.isinf(extrinsics[i]).any():
            print(f"Warning: batch {i} has NaN/Inf in extrinsics")
            extrinsics[i] = torch.nan_to_num(
                extrinsics[i], nan=0.0, posinf=1e6, neginf=-1e6
            )

    normalized_extrinsics = extrinsics.clone()

    # Store normalization parameters if needed
    if return_stats:
        stats = {
            'scale_factors': torch.zeros(B, device=device),
            'translation_vectors': torch.zeros(B, 3, device=device),
        }

    for b in range(B):
        # Extract camera positions for this batch
        positions = extrinsics[b, :, :3, 3]  # (S, 3)

        # Filter valid positions to ignore outliers
        valid_mask = torch.isfinite(positions).all(dim=1)  # (S,)

        if valid_mask.sum() == 0:
            # No valid positions, use default values
            print(f"Warning: Batch {b} has no valid camera positions")
            normalized_extrinsics[b, :, :3, 3] = 0.5  # Place at center
            if return_stats:
                stats['scale_factors'][b] = 1.0
                stats['translation_vectors'][b] = 0.0
            continue

        valid_positions = positions[valid_mask]

        # Calculate bounds using percentiles for robustness
        if valid_positions.shape[0] > 10:
            # Use 5% and 95% percentiles instead of min/max
            min_pos = torch.quantile(valid_positions, 0.05, dim=0)
            max_pos = torch.quantile(valid_positions, 0.95, dim=0)
        else:
            # Too few samples, use min/max
            min_pos = torch.min(valid_positions, dim=0)[0]
            max_pos = torch.max(valid_positions, dim=0)[0]

        # Calculate the scale factor considering all dimensions
        pos_range = max_pos - min_pos

        # Add a small epsilon to prevent dimension collapse
        eps = torch.maximum(
            torch.tensor(1e-6, device=device),
            torch.abs(max_pos) * 1e-6
        )
        pos_range = torch.maximum(pos_range, eps)

        # Use the maximum range as scale factor for uniform scaling
        scale_factor = torch.max(pos_range)
        scale_factor = torch.clamp(scale_factor, min=1e-6, max=1e6)

        # Calculate the center point for centering
        center = (min_pos + max_pos) / 2.0

        # Normalize: center first, then scale so the bounds land in
        # [padding, 1 - padding]
        actual_scale = scale_factor / (1 - 2 * padding)
        normalized_positions = (positions - center) / actual_scale + 0.5

        # Ensure all values are within the valid range
        normalized_positions = torch.clamp(normalized_positions, 0.0, 1.0)

        # Handle invalid positions by setting them to the scene center
        invalid_mask = ~torch.isfinite(positions).all(dim=1)
        if invalid_mask.any():
            normalized_positions[invalid_mask] = 0.5

        normalized_extrinsics[b, :, :3, 3] = normalized_positions

        if return_stats:
            stats['scale_factors'][b] = actual_scale
            stats['translation_vectors'][b] = center

    # Final validation
    assert torch.isfinite(normalized_extrinsics).all(), "Output contains non-finite values"

    if return_stats:
        return normalized_extrinsics, stats
    return normalized_extrinsics
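
# Minimal usage sketch for normalize_poses. This demo helper and its random
# inputs are illustrative assumptions, not part of any dataset pipeline: it
# builds (B, S, 3, 4) extrinsics with identity rotations and random camera
# centers, normalizes them, and inspects the returned statistics.
def _demo_normalize_poses():
    torch.manual_seed(0)
    B, S = 2, 8
    # Rotation part: identity matrices; translation part: random camera
    # centers spread over a few units.
    rotations = torch.eye(3).expand(B, S, 3, 3).clone()
    translations = torch.randn(B, S, 3, 1) * 5.0
    extrinsics = torch.cat([rotations, translations], dim=-1)  # (B, S, 3, 4)

    normalized, stats = normalize_poses(extrinsics, padding=0.1, return_stats=True)
    # With the default padding, in-bounds positions land inside [0.1, 0.9]
    print("normalized translation range:",
          normalized[..., :3, 3].min().item(),
          normalized[..., :3, 3].max().item())
    print("per-batch scale factors:", stats['scale_factors'])
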
def normalize_depth(depth, eps=1e-6, min_percentile=1, max_percentile=99):
    """
    Normalize depth values to the [0, 1] range using percentile-based scaling.

    Args:
        depth: Input depth tensor with shape (B, S, H, W)
        eps: Small epsilon value to prevent division by zero
        min_percentile: Lower percentile for robust min calculation (default: 1)
        max_percentile: Upper percentile for robust max calculation (default: 99)

    Returns:
        normalized_depth: Depth tensor normalized to [0, 1] with the same
            shape (B, S, H, W)
    """
    B, S, H, W = depth.shape
    depth = depth.flatten(0, 1)  # [B*S, H, W]

    # Handle invalid values
    depth = torch.nan_to_num(depth, nan=0.0, posinf=1e6, neginf=0.0)

    normalized_list = []
    for i in range(depth.shape[0]):
        depth_img = depth[i]  # [H, W]
        depth_flat = depth_img.flatten()

        # Ignore zero pixels (missing depth) when estimating the range
        non_zero_mask = depth_flat > 0
        if non_zero_mask.sum() > 0:
            values_to_use = depth_flat[non_zero_mask]
        else:
            values_to_use = depth_flat

        # Only use percentiles when there are enough values
        if values_to_use.numel() > 100:  # Enough samples for a stable estimate
            depth_min = torch.quantile(values_to_use, min_percentile / 100.0)
            depth_max = torch.quantile(values_to_use, max_percentile / 100.0)
        else:
            # Too few samples, fall back to min/max
            depth_min = values_to_use.min()
            depth_max = values_to_use.max()

        # Handle the degenerate case where max equals min
        if depth_max == depth_min:
            depth_max = depth_min + 1.0

        # Use a relative epsilon
        scale = torch.abs(depth_max - depth_min)
        eps_val = max(eps, scale.item() * eps)

        # Perform normalization
        depth_norm_img = (depth_img - depth_min) / (depth_max - depth_min + eps_val)

        # Ensure the output is within the [0, 1] range
        depth_norm_img = torch.clamp(depth_norm_img, 0.0, 1.0)

        normalized_list.append(depth_norm_img)

    # Recombine all normalized images
    depth_norm = torch.stack(normalized_list)
    return depth_norm.reshape(B, S, H, W)
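
# Minimal usage sketch for normalize_depth. The synthetic depth maps below
# are assumptions for illustration only: a (B, S, H, W) batch with a few
# zero / NaN pixels, checked to stay inside [0, 1] after normalization.
def _demo_normalize_depth():
    torch.manual_seed(0)
    B, S, H, W = 2, 4, 32, 32
    depth = torch.rand(B, S, H, W) * 10.0 + 0.5   # depths roughly in [0.5, 10.5]
    depth[:, :, 0, 0] = 0.0                       # simulate missing depth
    depth[:, :, 0, 1] = float('nan')              # simulate a corrupted pixel

    depth_norm = normalize_depth(depth)
    print("normalized depth range:",
          depth_norm.min().item(), depth_norm.max().item())
    assert depth_norm.shape == depth.shape


if __name__ == "__main__":
    _demo_normalize_poses()
    _demo_normalize_depth()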