File size: 3,268 Bytes
967f336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# utils/interop.py
from __future__ import annotations
import torch

def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line debug summary of ``t`` prefixed with ``[interop] {tag}``.

    The summary lists shape, dtype, device and the [min, max] value range
    (``nan`` for both ends when the tensor has no elements). This helper is
    best-effort: any exception raised while building or printing the summary
    is reported on stdout instead of being propagated.
    """
    try:
        is_empty = t.numel() == 0
        lo = float("nan") if is_empty else float(t.min())
        hi = float("nan") if is_empty else float(t.max())
        line = (
            f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"range=[{lo:.4f},{hi:.4f}]"
        )
        print(line)
    except Exception as exc:  # deliberate: logging must never break the caller
        print(f"[interop] {tag}: <log failed: {exc!r}>")

def _to_float01(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.float32)
    if x.max() > 1.0:
        x = x / 255.0
    return x.clamp_(0.0, 1.0)

def _squeeze_bt(x: torch.Tensor) -> torch.Tensor:
    # Drop singleton Time and extra Batch: (B,T,C,H,W) → (B,C,H,W) or (C,H,W)
    if x.ndim == 5:
        if x.shape[1] == 1:
            x = x.squeeze(1)  # drop T
        if x.ndim == 5 and x.shape[0] == 1:
            x = x.squeeze(0)  # drop B
    # Edge case: (1,1,3,H,W)
    if x.ndim == 4 and x.shape[0] == 1 and x.shape[1] == 1 and x.shape[-3] == 3:
        x = x.squeeze(1)  # → (1,3,H,W)
    return x

def ensure_image_nchw(
    img: torch.Tensor,
    device: torch.device | str = "cuda",
    want_batched: bool = True,
) -> torch.Tensor:
    """Move an image tensor to ``device`` and normalise it to float32 NCHW/CHW.

    A 5-D input first has one singleton B/T axis dropped by ``_squeeze_bt``.
    The result is then normalized according to its rank.

    - 4-D: treated as NCHW when dim 1 has size 3, otherwise as NHWC when
      dim 3 has size 3 (and permuted). Any other 4-D shape raises. The result
      is always 4-D; ``want_batched`` is not consulted on this path.
    - 3-D: treated as CHW when dim 0 has size 1 or 3, otherwise as HWC and
      permuted to CHW (the trailing size is not checked). A leading batch
      axis is added only when ``want_batched`` is true.

    Values go through ``_to_float01`` (float32, /255 when max > 1, clamp to
    [0, 1]).

    Raises:
        AssertionError: for ranks other than 3/4, or a 4-D tensor with no
            size-3 axis at position 1 or 3.
    """
    x = _squeeze_bt(img.to(device))
    rank = x.ndim

    if rank == 4:
        channels_first = x.shape[1] == 3
        channels_last = x.shape[3] == 3
        if not (channels_first or channels_last):
            raise AssertionError(f"Cannot infer channels in image: {tuple(x.shape)}")
        if not channels_first:
            x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        return _to_float01(x.contiguous())

    if rank != 3:
        raise AssertionError(f"Image must be 3D/4D; got {tuple(x.shape)}")

    # Leading size 1 or 3 is read as channels; anything else is assumed HWC.
    if x.shape[0] not in (1, 3):
        x = x.permute(2, 0, 1)  # HWC -> CHW
    x = _to_float01(x.contiguous())
    return x.unsqueeze(0) if want_batched else x

def ensure_mask_for_matanyone(
    mask: torch.Tensor,
    *,
    idx_mask: bool = False,
    threshold: float = 0.5,
    keep_soft: bool = False,
    device: torch.device | str = "cuda",
) -> torch.Tensor:
    """Normalise a mask tensor into either an index map or a 1-channel mask.

    The mask is moved to ``device`` and one singleton B/T axis of a 5-D input is
    dropped by ``_squeeze_bt``. Any remaining leading size-1 axes are then
    stripped while the mask is more than 3-D, so e.g. a (1, C, H, W) mask is
    accepted (it previously raised).

    Args:
        mask: mask that is 2-D (H, W) or 3-D (C, H, W) after the squeezing above.
        idx_mask: if True, return an (H, W) ``torch.long`` map with values in
            {0, 1}; otherwise return a (1, H, W) float32 mask.
        threshold: ``>=`` cut-off used for 2-D / single-channel inputs and,
            unless ``keep_soft``, for binarising the channel-mask output. It is
            not used for multi-channel input in ``idx_mask`` mode (argmax is).
        keep_soft: in channel-mask mode, keep float values clamped to [0, 1]
            instead of binarising at ``threshold``. NOTE: values are clamped,
            not rescaled, so a 0..255 mask saturates to 1 here.
        device: target device.

    Returns:
        ``idx_mask``: (H, W) long tensor. Otherwise: contiguous (1, H, W) float32
        tensor that does not alias ``mask``.

    Raises:
        AssertionError: if the mask is not 2-D/3-D after squeezing.
    """
    mask = _squeeze_bt(mask.to(device))
    # Accept batched layouts like (1, C, H, W) by dropping leading unit axes.
    while mask.ndim > 3 and mask.shape[0] == 1:
        mask = mask.squeeze(0)

    if idx_mask:
        # Index-map path -> (H, W) labels {0, 1}.
        if mask.ndim == 2:
            return (mask >= threshold).to(torch.long)
        if mask.ndim == 3:
            if mask.shape[0] == 1:
                return (mask[0] >= threshold).to(torch.long)
            # Multi-channel: foreground wherever the winning channel is not 0.
            winner = torch.argmax(mask, dim=0)
            return (winner > 0).to(torch.long)
        raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}")

    # Channel-mask path -> (1, H, W) float in [0, 1].
    if mask.ndim == 2:
        out = mask.unsqueeze(0)
    elif mask.ndim == 3:
        if mask.shape[0] == 1:
            out = mask
        else:
            # Keep the channel with the largest total area (first on ties).
            best = int(mask.sum(dim=(-2, -1)).argmax())
            out = mask[best:best + 1]
    else:
        raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")

    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).to(torch.float32)
    # Out-of-place clamp: with keep_soft, ``out`` can still be the caller's
    # tensor or a view of it (no-op .to, unsqueeze, slice), and clamp_ would
    # modify the caller's data.
    return out.clamp(0.0, 1.0).contiguous()