| """ | |
| Inference utilities. | |
| """ | |
| import warnings | |
| from typing import Any, Dict, List | |
| import numpy as np | |
| import torch | |
| from mapanything.utils.geometry import ( | |
| depth_edge, | |
| get_rays_in_camera_frame, | |
| normals_edge, | |
| points_to_normals, | |
| quaternion_to_rotation_matrix, | |
| recover_pinhole_intrinsics_from_ray_directions, | |
| rotation_matrix_to_quaternion, | |
| ) | |
| from mapanything.utils.image import rgb | |
| # Hard constraints - exactly what users can provide | |
| ALLOWED_VIEW_KEYS = { | |
| "img", # Required - input images | |
| "data_norm_type", # Required - normalization type of the input images | |
| "depth_z", # Optional - Z depth maps | |
| "ray_directions", # Optional - ray directions in camera frame | |
| "intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions) | |
| "camera_poses", # Optional - camera poses | |
| "is_metric_scale", # Optional - whether inputs are metric scale | |
| "true_shape", # Optional - original image shape | |
| "idx", # Optional - index of the view | |
| "instance", # Optional - instance info of the view | |
| } | |
| REQUIRED_KEYS = {"img", "data_norm_type"} | |
| # Define conflicting keys that cannot be used together | |
| CONFLICTING_KEYS = [ | |
| ("intrinsics", "ray_directions") # Both represent camera projection | |
| ] | |
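

# Illustrative sketch (not part of the original module): a minimal valid view dict built
# from the required keys above. The tensor shapes and the "dummy" normalization name are
# assumptions for illustration only; use the normalization your model expects.
def _example_minimal_view():
    batch_size, height, width = 1, 224, 224
    return {
        "img": torch.rand(batch_size, 3, height, width),  # (B, 3, H, W) input images
        "data_norm_type": ["dummy"] * batch_size,  # per-sample normalization identifier
    }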


def loss_of_one_batch_multi_view(
    batch,
    model,
    criterion,
    device,
    use_amp=False,
    amp_dtype="bf16",
    ret=None,
    ignore_keys=None,
):
    """
    Calculate loss for a batch with multiple views.

    Args:
        batch (list): List of view dictionaries containing input data.
        model (torch.nn.Module): Model to run inference with.
        criterion (callable, optional): Loss function to compute the loss.
        device (torch.device): Device to run the computation on.
        use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False.
        amp_dtype (str, optional): Floating point type to use for automatic mixed precision.
                                   Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16".
        ret (str, optional): If provided, return only the specified key from the result dictionary.
        ignore_keys (set, optional): Set of keys to ignore when moving tensors to device.
                                     Defaults to {"depthmap", "dataset", "label", "instance",
                                     "idx", "true_shape", "rng", "data_norm_type"}.

    Returns:
        dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss.
                     Otherwise, returns the value associated with the ret key.
    """
    # Move necessary tensors to device
    if ignore_keys is None:
        ignore_keys = set(
            [
                "depthmap",
                "dataset",
                "label",
                "instance",
                "idx",
                "true_shape",
                "rng",
                "data_norm_type",
            ]
        )
    for view in batch:
        for name in view.keys():
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    # Determine the mixed precision floating point type
    if use_amp:
        if amp_dtype == "fp16":
            amp_dtype = torch.float16
        elif amp_dtype == "bf16":
            if torch.cuda.is_bf16_supported():
                amp_dtype = torch.bfloat16
            else:
                warnings.warn(
                    "bf16 is not supported on this device. Using fp16 instead."
                )
                amp_dtype = torch.float16
        elif amp_dtype == "fp32":
            amp_dtype = torch.float32
    else:
        amp_dtype = torch.float32

    # Run model and compute loss
    with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
        preds = model(batch)
        with torch.autocast("cuda", enabled=False):
            loss = criterion(batch, preds) if criterion is not None else None

    result = {f"view{i + 1}": view for i, view in enumerate(batch)}
    result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)})
    result["loss"] = loss
    return result[ret] if ret else result
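

# Illustrative sketch (not part of the original module): a typical call, assuming `model`,
# `criterion`, and a list of view dicts `batch` already exist in the format expected above.
def _example_loss_of_one_batch(model, criterion, batch):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = torch.cuda.is_available()  # the autocast path above targets CUDA
    # Full result dict: {"view1": ..., "pred1": ..., ..., "loss": ...}
    result = loss_of_one_batch_multi_view(batch, model, criterion, device, use_amp=use_amp)
    # Or fetch a single entry via `ret`, e.g. just the loss
    loss = loss_of_one_batch_multi_view(
        batch, model, criterion, device, use_amp=use_amp, ret="loss"
    )
    return result, loss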


def validate_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Strict validation of input views.

    Args:
        views: List of view dictionaries

    Returns:
        The validated views (unchanged)

    Raises:
        ValueError: For invalid keys, missing required keys, conflicting inputs,
                    or invalid camera pose constraints
    """
    # Ensure input is not empty
    if not views:
        raise ValueError("At least one view must be provided")

    # Track which views have camera poses
    views_with_poses = []

    # Validate each view
    for view_idx, view in enumerate(views):
        # Check for invalid keys
        provided_keys = set(view.keys())
        invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
        if invalid_keys:
            raise ValueError(
                f"View {view_idx} contains invalid keys: {invalid_keys}. "
                f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
            )

        # Check for missing required keys
        missing_keys = REQUIRED_KEYS - provided_keys
        if missing_keys:
            raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")

        # Check for conflicting keys
        for conflict_set in CONFLICTING_KEYS:
            present_conflicts = [key for key in conflict_set if key in provided_keys]
            if len(present_conflicts) > 1:
                raise ValueError(
                    f"View {view_idx} contains conflicting keys: {present_conflicts}. "
                    f"Only one of {conflict_set} can be provided at a time."
                )

        # Check depth constraint: if depth is provided, intrinsics or ray_directions must also be provided
        if "depth_z" in provided_keys:
            if (
                "intrinsics" not in provided_keys
                and "ray_directions" not in provided_keys
            ):
                raise ValueError(
                    f"View {view_idx} depth constraint violation: If 'depth_z' is provided, "
                    f"then 'intrinsics' or 'ray_directions' must also be provided. "
                    f"Z depth values require camera calibration information to be meaningful."
                )

        # Track views with camera poses
        if "camera_poses" in provided_keys:
            views_with_poses.append(view_idx)

    # Cross-view constraint: if any view has camera_poses, view 0 must have them too
    if views_with_poses and 0 not in views_with_poses:
        raise ValueError(
            f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
            f"but view 0 (reference view) does not. When using camera_poses, the first view "
            f"must also provide camera_poses to serve as the reference frame."
        )

    return views
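

# Illustrative sketch (not part of the original module): the depth constraint in action.
# Providing "depth_z" without "intrinsics" or "ray_directions" raises a ValueError.
# The tensor shapes and the "dummy" normalization name are assumptions for illustration.
def _example_validate_depth_constraint():
    view = {
        "img": torch.rand(1, 3, 224, 224),
        "data_norm_type": ["dummy"],
        "depth_z": torch.rand(1, 224, 224, 1),  # no calibration provided alongside
    }
    try:
        validate_input_views_for_inference([view])
    except ValueError as err:
        return str(err)  # "View 0 depth constraint violation: ..."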


def preprocess_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Pre-process input views to match the expected internal input format.

    The following steps are performed:
    1. Convert intrinsics to ray directions when required. If ray directions are already provided, unit normalize them.
    2. Convert depth_z to depth_along_ray.
    3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans).
    4. Default is_metric_scale to True when not provided.

    Args:
        views: List of view dictionaries

    Returns:
        Preprocessed views with consistent internal format
    """
    processed_views = []
    for view_idx, view in enumerate(views):
        # Copy the view dictionary to avoid modifying the original input
        processed_view = dict(view)

        # Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them.
        if "intrinsics" in view:
            images = view["img"]
            height, width = images.shape[-2:]
            intrinsics = view["intrinsics"]
            _, ray_directions = get_rays_in_camera_frame(
                intrinsics=intrinsics,
                height=height,
                width=width,
                normalize_to_unit_sphere=True,
            )
            processed_view["ray_directions"] = ray_directions
            del processed_view["intrinsics"]
        elif "ray_directions" in view:
            ray_directions = view["ray_directions"]
            ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
            processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)

        # Step 2: Convert depth_z to depth_along_ray
        if "depth_z" in view:
            depth_z = view["depth_z"]
            ray_directions = processed_view["ray_directions"]
            ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
            pts3d_cam = depth_z * ray_directions_unit_plane
            depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)
            processed_view["depth_along_ray"] = depth_along_ray
            del processed_view["depth_z"]

        # Step 3: Convert camera_poses to expected input keys
        if "camera_poses" in view:
            camera_poses = view["camera_poses"]
            if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
                quats, trans = camera_poses
                processed_view["camera_pose_quats"] = quats
                processed_view["camera_pose_trans"] = trans
            elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
                rotation_matrices = camera_poses[:, :3, :3]
                translation_vectors = camera_poses[:, :3, 3]
                quats = rotation_matrix_to_quaternion(rotation_matrices)
                processed_view["camera_pose_quats"] = quats
                processed_view["camera_pose_trans"] = translation_vectors
            else:
                raise ValueError(
                    f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
                    f"or a tensor of (B, 4, 4) transformation matrices."
                )
            del processed_view["camera_poses"]

        # Step 4: Default is_metric_scale to True when not provided
        if "is_metric_scale" not in processed_view:
            # Get batch size from the image tensor
            batch_size = view["img"].shape[0]
            # Default to True for all samples in the batch
            processed_view["is_metric_scale"] = torch.ones(
                batch_size, dtype=torch.bool, device=view["img"].device
            )

        # Rename keys to match expected model input format
        if "ray_directions" in processed_view:
            processed_view["ray_directions_cam"] = processed_view["ray_directions"]
            del processed_view["ray_directions"]

        # Append the processed view to the list
        processed_views.append(processed_view)

    return processed_views
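

# Illustrative sketch (not part of the original module): preprocessing replaces "intrinsics"
# with unit-normalized "ray_directions_cam", converts "depth_z" to "depth_along_ray", and
# splits a (B, 4, 4) "camera_poses" tensor into quaternions and translations. The focal
# length, principal point, and tensor shapes below are assumptions for illustration.
def _example_preprocess_single_view():
    batch_size, height, width = 1, 224, 224
    intrinsics = torch.tensor(
        [[[200.0, 0.0, 112.0], [0.0, 200.0, 112.0], [0.0, 0.0, 1.0]]]
    )  # (B, 3, 3) pinhole intrinsics
    view = {
        "img": torch.rand(batch_size, 3, height, width),
        "data_norm_type": ["dummy"],
        "intrinsics": intrinsics,
        "depth_z": torch.rand(batch_size, height, width, 1),
        "camera_poses": torch.eye(4).unsqueeze(0),  # (B, 4, 4) pose matrix
    }
    (processed,) = preprocess_input_views_for_inference([view])
    assert "ray_directions_cam" in processed and "depth_along_ray" in processed
    assert "camera_pose_quats" in processed and "camera_pose_trans" in processed
    return processed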


def postprocess_model_outputs_for_inference(
    raw_outputs: List[Dict[str, torch.Tensor]],
    input_views: List[Dict[str, Any]],
    apply_mask: bool = True,
    mask_edges: bool = True,
    edge_normal_threshold: float = 5.0,
    edge_depth_threshold: float = 0.03,
    apply_confidence_mask: bool = False,
    confidence_percentile: float = 10,
) -> List[Dict[str, torch.Tensor]]:
    """
    Post-process raw model outputs by copying them and adding essential derived fields.

    This function simplifies the raw model outputs by:
    1. Copying all raw outputs as-is
    2. Adding denormalized images (img_no_norm)
    3. Adding Z depth (depth_z) from camera frame points
    4. Recovering pinhole camera intrinsics from ray directions
    5. Adding camera pose matrices (camera_poses) if pose data is available
    6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)

    Args:
        raw_outputs: List of raw model output dictionaries, one per view
        input_views: List of original input view dictionaries, one per view
        apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
        mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
        edge_normal_threshold: Tolerance for the normal-based edge mask. Defaults to 5.0.
        edge_depth_threshold: Relative tolerance for the depth-based edge mask. Defaults to 0.03.
        apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
        confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.

    Returns:
        List of processed output dictionaries containing:
        - All original raw outputs (after masking dense geometry outputs if requested)
        - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
        - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
        - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
        - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
        - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
    """
    processed_outputs = []
    for view_idx, (raw_output, original_view) in enumerate(
        zip(raw_outputs, input_views)
    ):
        # Start by copying all raw outputs
        processed_output = dict(raw_output)

        # 1. Add denormalized images
        img = original_view["img"]  # Shape: (B, 3, H, W)
        data_norm_type = original_view["data_norm_type"][0]
        img_hwc = rgb(img, data_norm_type)
        # Convert numpy back to torch if needed (rgb returns numpy)
        if isinstance(img_hwc, np.ndarray):
            img_hwc = torch.from_numpy(img_hwc).to(img.device)
        processed_output["img_no_norm"] = img_hwc

        # 2. Add Z depth if we have camera frame points
        if "pts3d_cam" in processed_output:
            processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]

        # 3. Recover pinhole camera intrinsics from ray directions if available
        if "ray_directions" in processed_output:
            intrinsics = recover_pinhole_intrinsics_from_ray_directions(
                processed_output["ray_directions"]
            )
            processed_output["intrinsics"] = intrinsics

        # 4. Add camera pose matrices if both translation and quaternions are available
        if "cam_trans" in processed_output and "cam_quats" in processed_output:
            cam_trans = processed_output["cam_trans"]  # (B, 3)
            cam_quats = processed_output["cam_quats"]  # (B, 4)
            batch_size = cam_trans.shape[0]
            # Convert quaternions to rotation matrices
            rotation_matrices = quaternion_to_rotation_matrix(cam_quats)  # (B, 3, 3)
            # Create 4x4 pose matrices
            pose_matrices = (
                torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
            )
            pose_matrices[:, :3, :3] = rotation_matrices
            pose_matrices[:, :3, 3] = cam_trans
            processed_output["camera_poses"] = pose_matrices  # (B, 4, 4)

        # 5. Apply comprehensive mask to dense geometry outputs if requested
        if apply_mask:
            final_mask = None

            # Start with non-ambiguous mask if available
            if "non_ambiguous_mask" in processed_output:
                non_ambiguous_mask = (
                    processed_output["non_ambiguous_mask"].cpu().numpy()
                )  # (B, H, W)
                final_mask = non_ambiguous_mask

            # Apply confidence mask if requested and available
            if apply_confidence_mask and "conf" in processed_output:
                confidences = processed_output["conf"].cpu()  # (B, H, W)
                # Compute percentile threshold for each batch element
                batch_size = confidences.shape[0]
                percentile_threshold = (
                    torch.quantile(
                        confidences.reshape(batch_size, -1),
                        confidence_percentile / 100.0,
                        dim=1,
                    )
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )  # Shape: (B, 1, 1)
                # Compute mask for each batch element
                conf_mask = confidences > percentile_threshold
                conf_mask = conf_mask.numpy()
                if final_mask is not None:
                    final_mask = final_mask & conf_mask
                else:
                    final_mask = conf_mask

            # Apply edge mask if requested and we have the required data
            if mask_edges and final_mask is not None and "pts3d" in processed_output:
                # Get 3D points for edge computation
                pred_pts3d = processed_output["pts3d"].cpu().numpy()  # (B, H, W, 3)
                batch_size, height, width = final_mask.shape
                edge_masks = []
                for b in range(batch_size):
                    batch_final_mask = final_mask[b]  # (H, W)
                    batch_pts3d = pred_pts3d[b]  # (H, W, 3)
                    if batch_final_mask.any():  # Only compute if we have valid points
                        # Compute normals and normal-based edge mask
                        normals, normals_mask = points_to_normals(
                            batch_pts3d, mask=batch_final_mask
                        )
                        normal_edges = normals_edge(
                            normals, tol=edge_normal_threshold, mask=normals_mask
                        )
                        # Compute depth-based edge mask
                        depth_z = (
                            processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
                        )
                        depth_edges = depth_edge(
                            depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
                        )
                        # Combine both edge types
                        edge_mask = ~(depth_edges & normal_edges)
                        edge_masks.append(edge_mask)
                    else:
                        # No valid points, keep all as invalid
                        edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))

                # Stack batch edge masks and combine with final mask
                edge_mask = np.stack(edge_masks, axis=0)  # (B, H, W)
                final_mask = final_mask & edge_mask

            # Apply final mask to dense geometry outputs if we have a mask
            if final_mask is not None:
                # Convert mask to torch tensor
                final_mask_torch = torch.from_numpy(final_mask).to(
                    processed_output["pts3d"].device
                )
                final_mask_torch = final_mask_torch.unsqueeze(-1)  # (B, H, W, 1)

                # Apply mask to dense geometry outputs (zero out invalid regions)
                dense_geometry_keys = [
                    "pts3d",
                    "pts3d_cam",
                    "depth_along_ray",
                    "depth_z",
                ]
                for key in dense_geometry_keys:
                    if key in processed_output:
                        processed_output[key] = processed_output[key] * final_mask_torch

                # Add mask to processed output
                processed_output["mask"] = final_mask_torch

        processed_outputs.append(processed_output)

    return processed_outputs
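

# Illustrative sketch (not part of the original module): chaining the helpers above for
# inference. Assumes `model` maps a list of preprocessed view dicts to a list of per-view
# prediction dicts (the same contract assumed by loss_of_one_batch_multi_view).
def _example_inference_pipeline(model, views):
    views = validate_input_views_for_inference(views)
    views = preprocess_input_views_for_inference(views)
    with torch.no_grad():
        preds = model(views)
    # Adds img_no_norm, depth_z, intrinsics, camera_poses, and a combined mask per view
    return postprocess_model_outputs_for_inference(
        raw_outputs=preds,
        input_views=views,
        apply_mask=True,
        mask_edges=True,
    )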