File size: 19,692 Bytes
9507532
 
 
 
 
37de32d
9507532
37de32d
9507532
 
37de32d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9507532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37de32d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
"""
Inference utilities.

Helpers for running a multi-view model on a batch (optionally computing a loss),
validating and pre-processing user-provided input views into the model's internal
input format, and post-processing raw model outputs into user-facing fields
(denormalized images, Z depth, pinhole intrinsics, 4x4 camera poses and masks).
"""

import warnings
from typing import Any, Dict, List

import numpy as np
import torch

from mapanything.utils.geometry import (
    depth_edge,
    get_rays_in_camera_frame,
    normals_edge,
    points_to_normals,
    quaternion_to_rotation_matrix,
    recover_pinhole_intrinsics_from_ray_directions,
    rotation_matrix_to_quaternion,
)
from mapanything.utils.image import rgb

# Keys every input view must provide.
REQUIRED_KEYS = {
    "img",  # Required - input images
    "data_norm_type",  # Required - normalization type of the input images
}

# Hard constraints - exactly what users can provide: the required keys plus the
# optional keys listed below. Any other key is rejected during validation.
ALLOWED_VIEW_KEYS = REQUIRED_KEYS | {
    "depth_z",  # Optional - Z depth maps
    "ray_directions",  # Optional - ray directions in camera frame
    "intrinsics",  # Optional - pinhole camera intrinsics (conflicts with ray_directions)
    "camera_poses",  # Optional - camera poses
    "is_metric_scale",  # Optional - whether inputs are metric scale
    "true_shape",  # Optional - original image shape
    "idx",  # Optional - index of the view
    "instance",  # Optional - instance info of the view
}

# Groups of keys of which at most one may appear in a single view.
CONFLICTING_KEYS = [
    # intrinsics and ray_directions both describe the camera projection.
    ("intrinsics", "ray_directions"),
]


def loss_of_one_batch_multi_view(
    batch,
    model,
    criterion,
    device,
    use_amp=False,
    amp_dtype="bf16",
    ret=None,
    ignore_keys=None,
):
    """
    Calculate loss for a batch with multiple views.

    Every entry of each view dictionary whose key is not in ``ignore_keys`` and
    that exposes a ``.to`` method (e.g. a torch tensor) is moved to ``device``
    in place; other entries (strings, lists, ...) are left untouched. The model
    is then run on the whole batch under CUDA autocast (enabled only when
    ``use_amp`` is set), and the loss is computed with autocast disabled.

    Args:
        batch (list): List of view dictionaries containing input data. Modified in place.
        model (torch.nn.Module): Model to run inference with; called as ``model(batch)``
            and expected to return one prediction per view.
        criterion (callable, optional): Loss function, called as ``criterion(batch, preds)``.
            If None, the returned loss is None.
        device (torch.device): Device to run the computation on.
        use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False.
        amp_dtype (str or torch.dtype, optional): Floating point type to use for automatic
            mixed precision. Options: ["fp32", "fp16", "bf16"], or a ``torch.dtype`` (or None)
            which is passed to autocast unchanged. "bf16" falls back to fp16 with a warning
            when bf16 is not supported. Ignored when ``use_amp`` is False. Defaults to "bf16".
        ret (str, optional): If provided, return only the specified key from the result dictionary.
        ignore_keys (set, optional): Set of keys to ignore when moving tensors to device.
            Defaults to {"depthmap", "dataset", "label", "instance",
            "idx", "true_shape", "rng", "data_norm_type"}.

    Returns:
        dict or Any: If ret is None, returns a dictionary with keys "view1".."viewN",
                     "pred1".."predN" and "loss". Otherwise, returns the value
                     associated with the ret key.

    Raises:
        ValueError: If ``use_amp`` is True and ``amp_dtype`` is an unrecognized string.
    """
    # Move necessary tensors to device
    if ignore_keys is None:
        ignore_keys = {
            "depthmap",
            "dataset",
            "label",
            "instance",
            "idx",
            "true_shape",
            "rng",
            "data_norm_type",
        }
    for view in batch:
        for name in list(view):
            value = view[name]
            # Only objects that can be moved (tensors and similar) are transferred;
            # plain Python values would otherwise raise AttributeError on `.to`.
            if name in ignore_keys or not hasattr(value, "to"):
                continue
            view[name] = value.to(device, non_blocking=True)

    # Determine the mixed precision floating point type
    if not use_amp:
        amp_dtype = torch.float32
    elif amp_dtype == "fp16":
        amp_dtype = torch.float16
    elif amp_dtype == "bf16":
        if torch.cuda.is_bf16_supported():
            amp_dtype = torch.bfloat16
        else:
            warnings.warn("bf16 is not supported on this device. Using fp16 instead.")
            amp_dtype = torch.float16
    elif amp_dtype == "fp32":
        amp_dtype = torch.float32
    elif amp_dtype is not None and not isinstance(amp_dtype, torch.dtype):
        # A torch.dtype (or None -> autocast default) is passed through as before;
        # any other value is a caller error that autocast would report obscurely.
        raise ValueError(
            f"Unsupported amp_dtype {amp_dtype!r}; expected one of "
            f"'fp32', 'fp16', 'bf16' or a torch.dtype."
        )

    # Run model and compute loss (the loss is always computed outside autocast)
    with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
        preds = model(batch)
        with torch.autocast("cuda", enabled=False):
            loss = criterion(batch, preds) if criterion is not None else None

    result = {f"view{i + 1}": view for i, view in enumerate(batch)}
    result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)})
    result["loss"] = loss

    return result[ret] if ret else result


def validate_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Strictly validate user-provided input views and return them unchanged.

    Per view (checked in this order, failing on the first violation):
      * every key must be in ``ALLOWED_VIEW_KEYS``;
      * every key in ``REQUIRED_KEYS`` must be present;
      * at most one key of each ``CONFLICTING_KEYS`` group may be present;
      * ``depth_z`` requires ``intrinsics`` or ``ray_directions``.
    Across views (checked after all per-view checks pass): if any view provides
    ``camera_poses``, view 0 must provide them as well (it is the reference frame).

    Args:
        views: List of view dictionaries

    Returns:
        The same ``views`` list object, unmodified.

    Raises:
        ValueError: If ``views`` is empty, or any of the constraints above is violated.
    """
    if not views:
        raise ValueError("At least one view must be provided")

    # Indices of views that carry camera_poses, collected in ascending order.
    posed_view_indices = []

    for idx, view_dict in enumerate(views):
        keys = set(view_dict.keys())

        unknown = keys - ALLOWED_VIEW_KEYS
        if unknown:
            raise ValueError(
                f"View {idx} contains invalid keys: {unknown}. "
                f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
            )

        absent = REQUIRED_KEYS - keys
        if absent:
            raise ValueError(f"View {idx} missing required keys: {absent}")

        for group in CONFLICTING_KEYS:
            clashing = [key for key in group if key in keys]
            if len(clashing) > 1:
                raise ValueError(
                    f"View {idx} contains conflicting keys: {clashing}. "
                    f"Only one of {group} can be provided at a time."
                )

        # Z depth is only meaningful together with some camera calibration.
        has_calibration = "intrinsics" in keys or "ray_directions" in keys
        if "depth_z" in keys and not has_calibration:
            raise ValueError(
                f"View {idx} depth constraint violation: If 'depth_z' is provided, "
                f"then 'intrinsics' or 'ray_directions' must also be provided. "
                f"Z Depth values require camera calibration information to be meaningful for an image."
            )

        if "camera_poses" in keys:
            posed_view_indices.append(idx)

    # Indices are appended in ascending order, so view 0 is present iff it is first.
    if posed_view_indices and posed_view_indices[0] != 0:
        raise ValueError(
            f"Camera pose constraint violation: Views {posed_view_indices} have camera_poses, "
            f"but view 0 (reference view) does not. When using camera_poses, the first view "
            f"must also provide camera_poses to serve as the reference frame."
        )

    return views


def preprocess_input_views_for_inference(
    views: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Pre-process input views to match the expected internal input format.

    Each view is shallow-copied, so the caller's dictionaries are not modified.
    The following steps are performed:
    1. Convert intrinsics to unit-norm ray directions for the image's H x W grid.
       If ray directions are provided instead, unit normalize them.
    2. Convert depth_z to depth_along_ray (the norm of depth_z * ray / ray_z).
    3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans).
    4. Default is_metric_scale to True (one bool per batch element) when not provided.
    5. Rename ray_directions to ray_directions_cam.

    Args:
        views: List of view dictionaries, presumably already checked by
            validate_input_views_for_inference. The batch size is read from
            ``img.shape[0]`` and H, W from ``img.shape[-2:]`` (e.g. (B, 3, H, W)).
            ``camera_poses`` may be a ``(quats, trans)`` tuple or a tensor whose
            last two dims are (4, 4) (e.g. (B, 4, 4)).

    Returns:
        Preprocessed views with consistent internal format

    Raises:
        ValueError: If ``camera_poses`` has neither supported format.
    """
    processed_views = []

    for view_idx, view in enumerate(views):
        # Copy the view dictionary to avoid modifying the original input
        processed_view = dict(view)

        # Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them.
        if "intrinsics" in view:
            height, width = view["img"].shape[-2:]
            _, ray_directions = get_rays_in_camera_frame(
                intrinsics=view["intrinsics"],
                height=height,
                width=width,
                normalize_to_unit_sphere=True,
            )
            processed_view["ray_directions"] = ray_directions
            del processed_view["intrinsics"]
        elif "ray_directions" in view:
            ray_directions = view["ray_directions"]
            ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
            # Epsilon guards against division by zero for degenerate (zero) rays.
            processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)

        # Step 2: Convert depth_z to depth_along_ray
        if "depth_z" in view:
            depth_z = view["depth_z"]
            ray_directions = processed_view["ray_directions"]
            # Scale rays so their z component is 1; assumes ray z != 0 (pinhole-like
            # cameras) and that depth_z broadcasts against the rays, e.g. (B, H, W, 1).
            ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
            pts3d_cam = depth_z * ray_directions_unit_plane
            processed_view["depth_along_ray"] = torch.norm(
                pts3d_cam, dim=-1, keepdim=True
            )
            del processed_view["depth_z"]

        # Step 3: Convert camera_poses to expected input keys
        if "camera_poses" in view:
            camera_poses = view["camera_poses"]
            if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
                quats, trans = camera_poses
            elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
                # Index the trailing matrix dims with Ellipsis so the slicing agrees
                # with the shape check above for any number of leading dims.
                quats = rotation_matrix_to_quaternion(camera_poses[..., :3, :3])
                trans = camera_poses[..., :3, 3]
            else:
                raise ValueError(
                    f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
                    f"or a tensor of (B, 4, 4) transformation matrices."
                )
            processed_view["camera_pose_quats"] = quats
            processed_view["camera_pose_trans"] = trans
            del processed_view["camera_poses"]

        # Step 4: Default is_metric_scale to True when not provided
        if "is_metric_scale" not in processed_view:
            batch_size = view["img"].shape[0]
            processed_view["is_metric_scale"] = torch.ones(
                batch_size, dtype=torch.bool, device=view["img"].device
            )

        # Rename keys to match expected model input format
        if "ray_directions" in processed_view:
            processed_view["ray_directions_cam"] = processed_view.pop("ray_directions")

        processed_views.append(processed_view)

    return processed_views


def _pose_matrices_from_quats_trans(
    cam_quats: torch.Tensor, cam_trans: torch.Tensor, device: torch.device
) -> torch.Tensor:
    # Build (B, 4, 4) homogeneous pose matrices on `device` from (B, 4) quaternions and (B, 3) translations.
    batch_size = cam_trans.shape[0]
    rotation_matrices = quaternion_to_rotation_matrix(cam_quats)  # (B, 3, 3)
    pose_matrices = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    pose_matrices[:, :3, :3] = rotation_matrices
    pose_matrices[:, :3, 3] = cam_trans
    return pose_matrices


def _confidence_mask(
    confidences: torch.Tensor, confidence_percentile: float
) -> np.ndarray:
    # Keep pixels whose confidence is strictly above each batch element's percentile threshold.
    confidences = confidences.detach().cpu()  # (B, H, W)
    # torch.quantile only accepts float32/float64 input (e.g. not fp16/bf16 from AMP).
    if confidences.dtype not in (torch.float32, torch.float64):
        confidences = confidences.float()
    batch_size = confidences.shape[0]
    thresholds = torch.quantile(
        confidences.reshape(batch_size, -1),
        confidence_percentile / 100.0,
        dim=1,
    )  # (B,)
    return (confidences > thresholds[:, None, None]).numpy()


def _edge_mask(
    pts3d: np.ndarray,
    depth_z: np.ndarray,
    valid_mask: np.ndarray,
    edge_normal_threshold: float,
    edge_depth_threshold: float,
) -> np.ndarray:
    # Per batch element, return True where a pixel is NOT both a normal edge and a depth edge.
    edge_masks = []
    for batch_valid, batch_pts3d, batch_depth in zip(valid_mask, pts3d, depth_z):
        if not batch_valid.any():
            # No valid points, keep all as invalid
            edge_masks.append(np.zeros_like(batch_valid, dtype=bool))
            continue
        normals, normals_mask = points_to_normals(batch_pts3d, mask=batch_valid)
        normal_edges = normals_edge(
            normals, tol=edge_normal_threshold, mask=normals_mask
        )
        depth_edges = depth_edge(
            batch_depth, rtol=edge_depth_threshold, mask=batch_valid
        )
        # Only pixels flagged by both criteria are treated as edges.
        edge_masks.append(~(depth_edges & normal_edges))
    return np.stack(edge_masks, axis=0)  # (B, H, W)


def postprocess_model_outputs_for_inference(
    raw_outputs: List[Dict[str, torch.Tensor]],
    input_views: List[Dict[str, Any]],
    apply_mask: bool = True,
    mask_edges: bool = True,
    edge_normal_threshold: float = 5.0,
    edge_depth_threshold: float = 0.03,
    apply_confidence_mask: bool = False,
    confidence_percentile: float = 10,
) -> List[Dict[str, torch.Tensor]]:
    """
    Post-process raw model outputs by copying raw outputs and adding essential derived fields.

    This function simplifies the raw model outputs by:
    1. Copying all raw outputs as-is (shallow copy; input dictionaries are not modified)
    2. Adding denormalized images (img_no_norm)
    3. Adding Z depth (depth_z) from camera frame points
    4. Recovering pinhole camera intrinsics from ray directions
    5. Adding camera pose matrices (camera_poses) if pose data is available
    6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)

    Masking (when ``apply_mask`` is True) starts from ``non_ambiguous_mask`` if the
    output has it, is AND-ed with the confidence mask if requested and ``conf`` is
    present, and is then AND-ed with the edge mask. Edge masking only runs when a
    base mask exists and both ``pts3d`` and ``depth_z`` are available. The final
    mask zeroes out ``pts3d``, ``pts3d_cam``, ``depth_along_ray`` and ``depth_z``.

    Args:
        raw_outputs: List of raw model output dictionaries, one per view
        input_views: List of original input view dictionaries, one per view
        apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
        mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
        edge_normal_threshold: Passed as ``tol`` to ``normals_edge``
            (presumably an angle in degrees - confirm against the geometry utils). Defaults to 5.0.
        edge_depth_threshold: Passed as ``rtol`` to ``depth_edge``
            (presumably a relative depth tolerance). Defaults to 0.03.
        apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
        confidence_percentile: The percentile (0-100) of each batch element's confidences
            used as threshold; only pixels strictly above it are kept. Defaults to 10.

    Returns:
        List of processed output dictionaries containing:
            - All original raw outputs (after masking dense geometry outputs if requested)
            - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
            - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
            - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
            - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
            - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested

    """
    processed_outputs = []

    for raw_output, original_view in zip(raw_outputs, input_views):
        # Start by copying all raw outputs
        processed_output = dict(raw_output)

        # 1. Add denormalized images
        img = original_view["img"]  # Shape: (B, 3, H, W)
        data_norm_type = original_view["data_norm_type"][0]
        img_hwc = rgb(img, data_norm_type)
        # rgb may return numpy; keep outputs as torch tensors on the image device
        if isinstance(img_hwc, np.ndarray):
            img_hwc = torch.from_numpy(img_hwc).to(img.device)
        processed_output["img_no_norm"] = img_hwc

        # 2. Add Z depth if we have camera frame points
        if "pts3d_cam" in processed_output:
            processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]

        # 3. Recover pinhole camera intrinsics from ray directions if available
        if "ray_directions" in processed_output:
            processed_output["intrinsics"] = (
                recover_pinhole_intrinsics_from_ray_directions(
                    processed_output["ray_directions"]
                )
            )

        # 4. Add camera pose matrices if both translation and quaternions are available
        if "cam_trans" in processed_output and "cam_quats" in processed_output:
            processed_output["camera_poses"] = _pose_matrices_from_quats_trans(
                processed_output["cam_quats"],  # (B, 4)
                processed_output["cam_trans"],  # (B, 3)
                img.device,
            )  # (B, 4, 4)

        # 5. Apply comprehensive mask to dense geometry outputs if requested
        if apply_mask:
            final_mask = None

            # Start with non-ambiguous mask if available
            if "non_ambiguous_mask" in processed_output:
                final_mask = (
                    processed_output["non_ambiguous_mask"].detach().cpu().numpy()
                )  # (B, H, W)

            # Apply confidence mask if requested and available
            if apply_confidence_mask and "conf" in processed_output:
                conf_mask = _confidence_mask(
                    processed_output["conf"], confidence_percentile
                )
                final_mask = conf_mask if final_mask is None else final_mask & conf_mask

            # Apply edge mask if requested and we have the required data
            if (
                mask_edges
                and final_mask is not None
                and "pts3d" in processed_output
                and "depth_z" in processed_output
            ):
                pred_pts3d = processed_output["pts3d"].detach().cpu().numpy()  # (B, H, W, 3)
                depth_z = (
                    processed_output["depth_z"].detach().squeeze(-1).cpu().numpy()
                )  # (B, H, W)
                final_mask = final_mask & _edge_mask(
                    pred_pts3d,
                    depth_z,
                    final_mask,
                    edge_normal_threshold,
                    edge_depth_threshold,
                )

            # Apply final mask to dense geometry outputs if we have a mask
            if final_mask is not None:
                # Fall back to the image device when no pts3d output exists.
                mask_device = (
                    processed_output["pts3d"].device
                    if "pts3d" in processed_output
                    else img.device
                )
                final_mask_torch = (
                    torch.from_numpy(final_mask).to(mask_device).unsqueeze(-1)
                )  # (B, H, W, 1)

                # Apply mask to dense geometry outputs (zero out invalid regions)
                for key in ("pts3d", "pts3d_cam", "depth_along_ray", "depth_z"):
                    if key in processed_output:
                        processed_output[key] = processed_output[key] * final_mask_torch

                processed_output["mask"] = final_mask_torch

        processed_outputs.append(processed_output)

    return processed_outputs