""" Inference utilities. """ import warnings from typing import Any, Dict, List import numpy as np import torch from mapanything.utils.geometry import ( depth_edge, get_rays_in_camera_frame, normals_edge, points_to_normals, quaternion_to_rotation_matrix, recover_pinhole_intrinsics_from_ray_directions, rotation_matrix_to_quaternion, ) from mapanything.utils.image import rgb # Hard constraints - exactly what users can provide ALLOWED_VIEW_KEYS = { "img", # Required - input images "data_norm_type", # Required - normalization type of the input images "depth_z", # Optional - Z depth maps "ray_directions", # Optional - ray directions in camera frame "intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions) "camera_poses", # Optional - camera poses "is_metric_scale", # Optional - whether inputs are metric scale "true_shape", # Optional - original image shape "idx", # Optional - index of the view "instance", # Optional - instance info of the view } REQUIRED_KEYS = {"img", "data_norm_type"} # Define conflicting keys that cannot be used together CONFLICTING_KEYS = [ ("intrinsics", "ray_directions") # Both represent camera projection ] def loss_of_one_batch_multi_view( batch, model, criterion, device, use_amp=False, amp_dtype="bf16", ret=None, ignore_keys=None, ): """ Calculate loss for a batch with multiple views. Args: batch (list): List of view dictionaries containing input data. model (torch.nn.Module): Model to run inference with. criterion (callable, optional): Loss function to compute the loss. device (torch.device): Device to run the computation on. use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False. amp_dtype (str, optional): Floating point type to use for automatic mixed precision. Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16". ret (str, optional): If provided, return only the specified key from the result dictionary. ignore_keys (set, optional): Set of keys to ignore when moving tensors to device. Defaults to {"dataset", "label", "instance", "idx", "true_shape", "rng", "data_norm_type"}. Returns: dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss. Otherwise, returns the value associated with the ret key. """ # Move necessary tensors to device if ignore_keys is None: ignore_keys = set( [ "depthmap", "dataset", "label", "instance", "idx", "true_shape", "rng", "data_norm_type", ] ) for view in batch: for name in view.keys(): if name in ignore_keys: continue view[name] = view[name].to(device, non_blocking=True) # Determine the mixed precision floating point type if use_amp: if amp_dtype == "fp16": amp_dtype = torch.float16 elif amp_dtype == "bf16": if torch.cuda.is_bf16_supported(): amp_dtype = torch.bfloat16 else: warnings.warn( "bf16 is not supported on this device. Using fp16 instead." ) amp_dtype = torch.float16 elif amp_dtype == "fp32": amp_dtype = torch.float32 else: amp_dtype = torch.float32 # Run model and compute loss with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype): preds = model(batch) with torch.autocast("cuda", enabled=False): loss = criterion(batch, preds) if criterion is not None else None result = {f"view{i + 1}": view for i, view in enumerate(batch)} result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)}) result["loss"] = loss return result[ret] if ret else result def validate_input_views_for_inference( views: List[Dict[str, Any]], ) -> List[Dict[str, Any]]: """ Strict validation and preprocessing of input views. 

def validate_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Strict validation and preprocessing of input views.

    Args:
        views: List of view dictionaries

    Returns:
        Validated and preprocessed views

    Raises:
        ValueError: For invalid keys, missing required keys, conflicting inputs,
                    or invalid camera pose constraints
    """
    # Ensure input is not empty
    if not views:
        raise ValueError("At least one view must be provided")

    # Track which views have camera poses
    views_with_poses = []

    # Validate each view
    for view_idx, view in enumerate(views):
        # Check for invalid keys
        provided_keys = set(view.keys())
        invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
        if invalid_keys:
            raise ValueError(
                f"View {view_idx} contains invalid keys: {invalid_keys}. "
                f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
            )

        # Check for missing required keys
        missing_keys = REQUIRED_KEYS - provided_keys
        if missing_keys:
            raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")

        # Check for conflicting keys
        for conflict_set in CONFLICTING_KEYS:
            present_conflicts = [key for key in conflict_set if key in provided_keys]
            if len(present_conflicts) > 1:
                raise ValueError(
                    f"View {view_idx} contains conflicting keys: {present_conflicts}. "
                    f"Only one of {conflict_set} can be provided at a time."
                )

        # Check depth constraint: if depth is provided, intrinsics or ray_directions must also be provided
        if "depth_z" in provided_keys:
            if (
                "intrinsics" not in provided_keys
                and "ray_directions" not in provided_keys
            ):
                raise ValueError(
                    f"View {view_idx} depth constraint violation: If 'depth_z' is provided, "
                    f"then 'intrinsics' or 'ray_directions' must also be provided. "
                    f"Z depth values require camera calibration information to be meaningful for an image."
                )

        # Track views with camera poses
        if "camera_poses" in provided_keys:
            views_with_poses.append(view_idx)

    # Cross-view constraint: if any view has camera_poses, view 0 must have them too
    if views_with_poses and 0 not in views_with_poses:
        raise ValueError(
            f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
            f"but view 0 (reference view) does not. When using camera_poses, the first view "
            f"must also provide camera_poses to serve as the reference frame."
        )

    return views
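
# Illustrative input sketch (an assumption for documentation purposes, not a
# prescribed format): two views where the first provides calibration, Z depth,
# and a reference pose. Tensor names, shapes, and the normalization string are
# placeholders:
#
#     views = [
#         {
#             "img": img0,                         # (B, 3, H, W) normalized images
#             "data_norm_type": ["dinov2"],        # normalization used for "img"
#             "intrinsics": K0,                    # (B, 3, 3) pinhole intrinsics
#             "depth_z": depth0,                   # (B, H, W, 1) Z depth
#             "camera_poses": torch.eye(4)[None],  # (B, 4, 4) reference pose
#         },
#         {"img": img1, "data_norm_type": ["dinov2"]},
#     ]
#     views = validate_input_views_for_inference(views)
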
if "intrinsics" in view: images = view["img"] height, width = images.shape[-2:] intrinsics = view["intrinsics"] _, ray_directions = get_rays_in_camera_frame( intrinsics=intrinsics, height=height, width=width, normalize_to_unit_sphere=True, ) processed_view["ray_directions"] = ray_directions del processed_view["intrinsics"] elif "ray_directions" in view: ray_directions = view["ray_directions"] ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True) processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8) # Step 2: Convert depth_z to depth_along_ray if "depth_z" in view: depth_z = view["depth_z"] ray_directions = processed_view["ray_directions"] ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3] pts3d_cam = depth_z * ray_directions_unit_plane depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True) processed_view["depth_along_ray"] = depth_along_ray del processed_view["depth_z"] # Step 3: Convert camera_poses to expected input keys if "camera_poses" in view: camera_poses = view["camera_poses"] if isinstance(camera_poses, tuple) and len(camera_poses) == 2: quats, trans = camera_poses processed_view["camera_pose_quats"] = quats processed_view["camera_pose_trans"] = trans elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4): rotation_matrices = camera_poses[:, :3, :3] translation_vectors = camera_poses[:, :3, 3] quats = rotation_matrix_to_quaternion(rotation_matrices) processed_view["camera_pose_quats"] = quats processed_view["camera_pose_trans"] = translation_vectors else: raise ValueError( f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) " f"or a tensor of (B, 4, 4) transformation matrices." ) del processed_view["camera_poses"] # Step 4: Default is_metric_scale to True when not provided if "is_metric_scale" not in processed_view: # Get batch size from the image tensor batch_size = view["img"].shape[0] # Default to True for all samples in the batch processed_view["is_metric_scale"] = torch.ones( batch_size, dtype=torch.bool, device=view["img"].device ) # Rename keys to match expected model input format if "ray_directions" in processed_view: processed_view["ray_directions_cam"] = processed_view["ray_directions"] del processed_view["ray_directions"] # Append the processed view to the list processed_views.append(processed_view) return processed_views def postprocess_model_outputs_for_inference( raw_outputs: List[Dict[str, torch.Tensor]], input_views: List[Dict[str, Any]], apply_mask: bool = True, mask_edges: bool = True, edge_normal_threshold: float = 5.0, edge_depth_threshold: float = 0.03, apply_confidence_mask: bool = False, confidence_percentile: float = 10, ) -> List[Dict[str, torch.Tensor]]: """ Post-process raw model outputs by copying raw outputs and adding essential derived fields. This function simplifies the raw model outputs by: 1. Copying all raw outputs as-is 2. Adding denormalized images (img_no_norm) 3. Adding Z depth (depth_z) from camera frame points 4. Recovering pinhole camera intrinsics from ray directions 5. Adding camera pose matrices (camera_poses) if pose data is available 6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking) Args: raw_outputs: List of raw model output dictionaries, one per view input_views: List of original input view dictionaries, one per view apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True. 

def postprocess_model_outputs_for_inference(
    raw_outputs: List[Dict[str, torch.Tensor]],
    input_views: List[Dict[str, Any]],
    apply_mask: bool = True,
    mask_edges: bool = True,
    edge_normal_threshold: float = 5.0,
    edge_depth_threshold: float = 0.03,
    apply_confidence_mask: bool = False,
    confidence_percentile: float = 10,
) -> List[Dict[str, torch.Tensor]]:
    """
    Post-process raw model outputs by copying raw outputs and adding essential derived fields.

    This function simplifies the raw model outputs by:
    1. Copying all raw outputs as-is
    2. Adding denormalized images (img_no_norm)
    3. Adding Z depth (depth_z) from camera frame points
    4. Recovering pinhole camera intrinsics from ray directions
    5. Adding camera pose matrices (camera_poses) if pose data is available
    6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)

    Args:
        raw_outputs: List of raw model output dictionaries, one per view
        input_views: List of original input view dictionaries, one per view
        apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
        mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
        edge_normal_threshold: Tolerance used for the normal-based edge mask. Defaults to 5.0.
        edge_depth_threshold: Relative tolerance used for the depth-based edge mask. Defaults to 0.03.
        apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
        confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.

    Returns:
        List of processed output dictionaries containing:
        - All original raw outputs (after masking dense geometry outputs if requested)
        - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
        - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
        - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
        - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
        - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
    """
    processed_outputs = []

    for view_idx, (raw_output, original_view) in enumerate(
        zip(raw_outputs, input_views)
    ):
        # Start by copying all raw outputs
        processed_output = dict(raw_output)

        # 1. Add denormalized images
        img = original_view["img"]  # Shape: (B, 3, H, W)
        data_norm_type = original_view["data_norm_type"][0]
        img_hwc = rgb(img, data_norm_type)

        # Convert numpy back to torch if needed (rgb returns numpy)
        if isinstance(img_hwc, np.ndarray):
            img_hwc = torch.from_numpy(img_hwc).to(img.device)

        processed_output["img_no_norm"] = img_hwc

        # 2. Add Z depth if we have camera frame points
        if "pts3d_cam" in processed_output:
            processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]

        # 3. Recover pinhole camera intrinsics from ray directions if available
        if "ray_directions" in processed_output:
            intrinsics = recover_pinhole_intrinsics_from_ray_directions(
                processed_output["ray_directions"]
            )
            processed_output["intrinsics"] = intrinsics

        # 4. Add camera pose matrices if both translation and quaternions are available
        if "cam_trans" in processed_output and "cam_quats" in processed_output:
            cam_trans = processed_output["cam_trans"]  # (B, 3)
            cam_quats = processed_output["cam_quats"]  # (B, 4)
            batch_size = cam_trans.shape[0]

            # Convert quaternions to rotation matrices
            rotation_matrices = quaternion_to_rotation_matrix(cam_quats)  # (B, 3, 3)

            # Create 4x4 pose matrices
            pose_matrices = (
                torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
            )
            pose_matrices[:, :3, :3] = rotation_matrices
            pose_matrices[:, :3, 3] = cam_trans

            processed_output["camera_poses"] = pose_matrices  # (B, 4, 4)
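
        # Note on the masking below: up to three masks are combined with a logical AND -
        # the model's non-ambiguous mask, an optional per-image confidence mask (pixels
        # above the confidence_percentile threshold), and an optional edge mask that
        # discards pixels flagged by both the depth and the normal edge detectors.
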
        # 5. Apply comprehensive mask to dense geometry outputs if requested
        if apply_mask:
            final_mask = None

            # Start with non-ambiguous mask if available
            if "non_ambiguous_mask" in processed_output:
                non_ambiguous_mask = (
                    processed_output["non_ambiguous_mask"].cpu().numpy()
                )  # (B, H, W)
                final_mask = non_ambiguous_mask

            # Apply confidence mask if requested and available
            if apply_confidence_mask and "conf" in processed_output:
                confidences = processed_output["conf"].cpu()  # (B, H, W)

                # Compute percentile threshold for each batch element
                batch_size = confidences.shape[0]
                percentile_threshold = (
                    torch.quantile(
                        confidences.reshape(batch_size, -1),
                        confidence_percentile / 100.0,
                        dim=1,
                    )
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )  # Shape: (B, 1, 1)

                # Compute mask for each batch element
                conf_mask = confidences > percentile_threshold
                conf_mask = conf_mask.numpy()

                if final_mask is not None:
                    final_mask = final_mask & conf_mask
                else:
                    final_mask = conf_mask

            # Apply edge mask if requested and we have the required data
            if mask_edges and final_mask is not None and "pts3d" in processed_output:
                # Get 3D points for edge computation
                pred_pts3d = processed_output["pts3d"].cpu().numpy()  # (B, H, W, 3)
                batch_size, height, width = final_mask.shape
                edge_masks = []

                for b in range(batch_size):
                    batch_final_mask = final_mask[b]  # (H, W)
                    batch_pts3d = pred_pts3d[b]  # (H, W, 3)

                    if batch_final_mask.any():  # Only compute if we have valid points
                        # Compute normals and normal-based edge mask
                        normals, normals_mask = points_to_normals(
                            batch_pts3d, mask=batch_final_mask
                        )
                        normal_edges = normals_edge(
                            normals, tol=edge_normal_threshold, mask=normals_mask
                        )

                        # Compute depth-based edge mask
                        depth_z = (
                            processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
                        )
                        depth_edges = depth_edge(
                            depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
                        )

                        # Combine both edge types
                        edge_mask = ~(depth_edges & normal_edges)
                        edge_masks.append(edge_mask)
                    else:
                        # No valid points, keep all as invalid
                        edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))

                # Stack batch edge masks and combine with final mask
                edge_mask = np.stack(edge_masks, axis=0)  # (B, H, W)
                final_mask = final_mask & edge_mask

            # Apply final mask to dense geometry outputs if we have a mask
            if final_mask is not None:
                # Convert mask to torch tensor
                final_mask_torch = torch.from_numpy(final_mask).to(
                    processed_output["pts3d"].device
                )
                final_mask_torch = final_mask_torch.unsqueeze(-1)  # (B, H, W, 1)

                # Apply mask to dense geometry outputs (zero out invalid regions)
                dense_geometry_keys = [
                    "pts3d",
                    "pts3d_cam",
                    "depth_along_ray",
                    "depth_z",
                ]
                for key in dense_geometry_keys:
                    if key in processed_output:
                        processed_output[key] = processed_output[key] * final_mask_torch

                # Add mask to processed output
                processed_output["mask"] = final_mask_torch

        processed_outputs.append(processed_output)

    return processed_outputs
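
# End-to-end usage sketch (illustrative only; how the model is constructed and
# called is an assumption - substitute the codebase's actual model entry point):
#
#     views = validate_input_views_for_inference(views)
#     views = preprocess_input_views_for_inference(views)
#     with torch.no_grad():
#         raw_outputs = model(views)  # list of per-view prediction dicts
#     outputs = postprocess_model_outputs_for_inference(
#         raw_outputs,
#         views,
#         apply_mask=True,
#         mask_edges=True,
#         apply_confidence_mask=False,
#     )
#     pts3d = outputs[0]["pts3d"]           # masked 3D points for view 0
#     poses = outputs[0]["camera_poses"]    # (B, 4, 4) poses, if the model predicts them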