|
|
|
|
|
from __future__ import annotations |
|
|
import torch |
|
|
|
|
|
def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line debug summary of a tensor; never raises.

    The printed line has the form::

        [mask_bridge] <tag>: shape=(...) dtype=... device=... range=[min,max]

    with min/max formatted to 4 decimals (``nan`` for empty tensors).

    Logging is best-effort:

    * If the min/max reduction or its ``float()`` conversion raises (e.g. for a
      dtype that ``t.min()`` does not support), shape/dtype/device are still
      printed, with ``range=[n/a]``.
    * If ``t`` does not expose ``shape``/``dtype``/``device``, nothing is
      printed.

    Args:
        tag: Label identifying the tensor in the log line.
        t: Tensor to summarize.
    """
    try:
        desc = f"shape={tuple(t.shape)} dtype={t.dtype} device={t.device}"
    except Exception:
        return  # not tensor-like; a debug log must never break the caller

    try:
        if t.numel():
            mn, mx = float(t.min()), float(t.max())
        else:
            # min()/max() are undefined on empty tensors.
            mn = mx = float("nan")
        rng = f"[{mn:.4f},{mx:.4f}]"
    except Exception:
        # Keep the shape/dtype/device info even when the range is unavailable.
        rng = "[n/a]"

    try:
        print(f"[mask_bridge] {tag}: {desc} range={rng}")
    except Exception:
        pass  # e.g. stdout unusable; logging stays best-effort
|
|
|
|
|
def sam2_to_matanyone_mask(
    sam2_masks: torch.Tensor,
    iou_scores: torch.Tensor | None,
    threshold: float = 0.5,
    return_mode: str = "single",
    keep_soft: bool = False,
) -> torch.Tensor:
    """Convert SAM2 candidate masks into a float32 mask tensor (sam2 -> mat bridge).

    Args:
        sam2_masks: Candidate masks shaped ``(B, M, H, W)``; only ``B == 1`` is
            supported. NOTE(review): the default ``threshold`` of 0.5 suggests
            these are probabilities in [0, 1] rather than logits — confirm.
        iou_scores: Optional per-candidate scores. They are used only when shaped
            exactly ``(1, M)``, in which case the highest-scoring candidate is
            selected. Otherwise (``None``, wrong rank, or a length that does not
            match ``M``) the candidate with the largest summed mask value is
            selected.
        threshold: Binarization cutoff. Values ``>= threshold`` become 1.0, all
            others 0.0.
        return_mode: ``"multi"`` returns all ``M`` candidates. Any other value
            returns only the selected candidate.
        keep_soft: If True, skip binarization and only clamp values into [0, 1].

    Returns:
        A new contiguous float32 tensor on the same device as ``sam2_masks``,
        shaped ``(1, H, W)`` in single mode or ``(M, H, W)`` in multi mode.
        ``sam2_masks`` itself is never modified.

    Raises:
        AssertionError: If ``sam2_masks`` is not 4-D or its batch size is not 1.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
    # AssertionError is kept so the exception type callers see is unchanged.
    if sam2_masks.ndim != 4:
        raise AssertionError(f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}")
    B, M, _, _ = sam2_masks.shape
    if B != 1:
        raise AssertionError("Bridge expects B=1 for first-frame bootstrapping")

    candidates = sam2_masks[0]  # (M, H, W) view into the caller's tensor

    if return_mode == "multi":
        out = candidates
    else:
        # Trust the scores only when there is exactly one per candidate.
        # Otherwise argmax could point past the last candidate, and the slice
        # below would silently produce an empty (0, H, W) mask.
        if iou_scores is not None and tuple(iou_scores.shape) == (1, M):
            best_idx = int(torch.argmax(iou_scores[0]).item())
        else:
            areas = candidates.sum(dim=(-2, -1))
            best_idx = int(torch.argmax(areas).item())
        out = candidates[best_idx:best_idx + 1]

    # .to() returns the same tensor (still a view of sam2_masks) when it is
    # already float32, so everything below must be out-of-place to avoid
    # mutating the caller's input.
    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).float()
    out = out.clamp(0.0, 1.0).contiguous()

    log_shape("sam2→mat.mask", out)
    return out
|
|
|