# utils/mask_bridge.py
"""Bridge SAM2 mask outputs into the single mask tensor handed to MatAnyone."""
from __future__ import annotations

import torch


def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line debug summary of *t*: shape, dtype, device and value range.

    Empty tensors report a range of ``[nan,nan]``. Logging is strictly
    best-effort: any failure while summarising or printing is swallowed so a
    diagnostics call can never break the caller.
    """
    try:
        has_values = t.numel() > 0
        lo = float(t.min()) if has_values else float("nan")
        hi = float(t.max()) if has_values else float("nan")
        print(
            f"[mask_bridge] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"range=[{lo:.4f},{hi:.4f}]"
        )
    except Exception:
        # Deliberately silent: this is debug output only.
        pass


def _select_candidate(candidates: torch.Tensor, iou_scores: torch.Tensor | None) -> int:
    """Return the index of the preferred candidate in an ``(M, H, W)`` stack.

    Uses the argmax of ``iou_scores[0]`` when ``iou_scores`` has shape exactly
    ``(1, M)``; otherwise picks the candidate with the largest summed value.
    """
    num_candidates = candidates.shape[0]
    if (
        iou_scores is not None
        and iou_scores.ndim == 2
        and iou_scores.shape[0] == 1
        # A length mismatch could make argmax point past the last candidate,
        # which would slice out an empty mask downstream.
        and iou_scores.shape[1] == num_candidates
    ):
        return int(torch.argmax(iou_scores[0]).item())
    # No usable scores: fall back to the candidate with the largest mask sum.
    areas = candidates.sum(dim=(-2, -1))
    return int(torch.argmax(areas).item())


def sam2_to_matanyone_mask(
    sam2_masks: torch.Tensor,  # (B,M,H,W) after post_process
    iou_scores: torch.Tensor | None,  # (B,M) or None
    threshold: float = 0.5,
    return_mode: str = "single",  # "single"→(1,H,W) or "multi"→(C,H,W)
    keep_soft: bool = False,
) -> torch.Tensor:
    """Convert SAM2 mask candidates for one image into a float32 mask tensor.

    Args:
        sam2_masks: Mask candidates shaped ``(B, M, H, W)``; ``B`` must be 1.
        iou_scores: Per-candidate scores shaped ``(1, M)``, or ``None``. When it
            is ``None`` or not exactly that shape, the candidate with the
            largest summed mask value is selected instead.
        threshold: Binarisation cut-off (``value >= threshold`` -> 1.0).
        return_mode: ``"multi"`` returns all ``M`` candidates as ``(M, H, W)``;
            any other value returns only the selected candidate as ``(1, H, W)``.
        keep_soft: If True, skip binarisation and only clamp values to [0, 1].

    Returns:
        A new contiguous float32 tensor with values in [0, 1]. The input
        tensor is never modified.

    Raises:
        AssertionError: if ``sam2_masks`` is not 4-D or its batch size is not 1.
    """
    assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
    assert sam2_masks.shape[0] == 1, "Bridge expects B=1 for first-frame bootstrapping"

    candidates = sam2_masks[0]  # (M,H,W) -- a view into the caller's tensor

    if return_mode == "multi":
        out = candidates  # (M,H,W) treat as (C,H,W)
    else:
        best_idx = _select_candidate(candidates, iou_scores)
        out = candidates[best_idx:best_idx + 1]  # (1,H,W)

    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).float()
    # Out-of-place clamp: `out` may still alias `sam2_masks` (the .to() above is
    # a no-op for float32 input), so clamp_ would silently edit the caller's data.
    out = out.clamp(0.0, 1.0).contiguous()
    log_shape("sam2→mat.mask", out)
    return out