# Commit 967f336 by MogensR: "Create utils/interop.py"
# utils/interop.py
from __future__ import annotations
import torch
def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line debug summary of ``t`` prefixed with ``[interop]``.

    The summary has shape, dtype, device and the min/max value range. An empty
    tensor reports its range as ``nan``. This is best-effort: if anything goes
    wrong while summarizing, a short failure note is printed instead of
    raising.
    """
    try:
        if t.numel():
            lo = float(t.min())
            hi = float(t.max())
        else:
            lo = hi = float("nan")
        print(
            f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"range=[{lo:.4f},{hi:.4f}]"
        )
    except Exception as err:  # deliberate: logging must never break the caller
        print(f"[interop] {tag}: <log failed: {err!r}>")
def _to_float01(x: torch.Tensor) -> torch.Tensor:
    """Convert an image tensor to float32 with values clamped to [0, 1].

    ``uint8`` input is always divided by 255. For any other dtype, values are
    divided by 255 only when the maximum exceeds 1.0, i.e. the data looks like
    it is on a 0..255 scale. Otherwise the data is assumed to already be in
    [0, 1]. An empty tensor comes back as an empty float32 tensor.

    The input tensor is never modified.
    """
    # Decide from the dtype *before* casting: a uint8 frame whose max is <= 1
    # (e.g. an almost-black frame) is still on the 0..255 scale.
    scale_255 = x.dtype == torch.uint8
    x = x.to(torch.float32)
    # max() raises on an empty tensor, so only probe the range when there is data.
    if not scale_255 and x.numel() > 0:
        scale_255 = bool(x.max() > 1.0)
    if scale_255:
        x = x / 255.0
    # Out-of-place clamp: Tensor.to() returns the input itself when it is already
    # float32 on the same device, so clamp_() would mutate the caller's tensor.
    return x.clamp(0.0, 1.0)
def _squeeze_bt(x: torch.Tensor) -> torch.Tensor:
    """Collapse one singleton axis of a 5D ``(B, T, C, H, W)`` tensor.

    - ``T == 1``: the time axis is dropped, ``(B, 1, C, H, W) -> (B, C, H, W)``.
      This also covers ``(1, 1, C, H, W) -> (1, C, H, W)``.
    - otherwise ``B == 1``: the batch axis is dropped,
      ``(1, T, C, H, W) -> (T, C, H, W)``.
    - otherwise (``B > 1`` and ``T > 1``), and for any tensor that is not 5D,
      the input is returned unchanged.

    At most one axis is removed. The result is the input or a ``squeeze``
    view of it.
    """
    # The previous "(1,1,3,H,W)" special case tested a 4D tensor for
    # shape[1] == 1 and shape[-3] (== shape[1]) == 3, which can never hold. That
    # input is already handled by the T == 1 branch, so the dead code is gone.
    if x.ndim != 5:
        return x
    if x.shape[1] == 1:
        return x.squeeze(1)  # drop T
    if x.shape[0] == 1:
        return x.squeeze(0)  # drop B
    return x
def ensure_image_nchw(
    img: torch.Tensor,
    device: torch.device | str = "cuda",
    want_batched: bool = True,
) -> torch.Tensor:
    """Return ``img`` channels-first, converted by ``_to_float01``.

    The tensor is moved to ``device`` and passed through ``_squeeze_bt``, then:

    - 3D input: a leading dim of size 1 or 3 is read as channels (CHW).
      Anything else is assumed to be HWC and permuted to CHW. The result is
      ``(1, C, H, W)`` when ``want_batched`` is True, else ``(C, H, W)``.
    - 4D input: dim 1 of size 3 means NCHW (kept as is). Otherwise, a last dim
      of size 3 means NHWC (permuted to NCHW). The result is always 4D;
      ``want_batched`` has no effect here.

    The result is made contiguous before ``_to_float01`` converts it.

    Raises:
        AssertionError: if a 4D tensor has no size-3 axis at dim 1 or dim 3,
            or if the tensor is neither 3D nor 4D after squeezing.
    """
    t = _squeeze_bt(img.to(device))
    rank = t.ndim

    if rank == 3:
        # Heuristic layout detection: CHW when the first dim looks like channels.
        chw = t if t.shape[0] in (1, 3) else t.permute(2, 0, 1)
        chw = _to_float01(chw.contiguous())
        if want_batched:
            return chw.unsqueeze(0)
        return chw

    if rank == 4:
        _, dim1, _, dim3 = t.shape
        if dim1 == 3:
            nchw = t
        elif dim3 == 3:
            nchw = t.permute(0, 3, 1, 2)  # NHWC -> NCHW
        else:
            raise AssertionError(f"Cannot infer channels in image: {tuple(t.shape)}")
        return _to_float01(nchw.contiguous())

    raise AssertionError(f"Image must be 3D/4D; got {tuple(t.shape)}")
def ensure_mask_for_matanyone(
    mask: torch.Tensor,
    *,
    idx_mask: bool = False,
    threshold: float = 0.5,
    keep_soft: bool = False,
    device: torch.device | str = "cuda",
) -> torch.Tensor:
    """Normalize a mask tensor into a single-object label map or channel mask.

    The mask is moved to ``device`` and passed through ``_squeeze_bt``. If the
    result is 4D with a leading axis of size 1 (e.g. ``(1, 1, H, W)``, which
    is also what ``_squeeze_bt`` makes of ``(1, 1, 1, H, W)``), that batch axis
    is dropped too.

    Args:
        mask: ``(H, W)`` or ``(C, H, W)``, or a batched/time-stacked variant
            that reduces to one of those.
        idx_mask: If True, return an index (label) map instead of a channel mask.
        threshold: Foreground cutoff, applied with ``>=``.
        keep_soft: Channel-mask path only. If True, skip thresholding and keep
            the values, clamped to [0, 1].
        device: Device of the returned tensor.

    Returns:
        - ``idx_mask=True``: an ``(H, W)`` ``torch.long`` tensor with values in
          {0, 1}. For multi-channel input, a pixel is 1 when its argmax channel
          is not channel 0.
        - otherwise: a contiguous ``(1, H, W)`` ``float32`` tensor in [0, 1].
          For multi-channel input, the channel with the largest summed value
          is kept.
        The input tensor is not modified.

    Raises:
        AssertionError: if the mask cannot be reduced to 2D/3D.
    """
    mask = _squeeze_bt(mask.to(device))
    # A redundant leading batch axis of size 1 would otherwise make a plain
    # (1, 1, H, W) mask fail the 2D/3D checks below.
    if mask.ndim == 4 and mask.shape[0] == 1:
        mask = mask[0]

    if idx_mask:
        # Label map (H, W) with values {0, 1}.
        if mask.ndim == 3:
            if mask.shape[0] == 1:
                idx = (mask[0] >= threshold).to(torch.long)
            else:
                idx = torch.argmax(mask, dim=0).to(torch.long)
                idx = (idx > 0).to(torch.long)
        elif mask.ndim == 2:
            idx = (mask >= threshold).to(torch.long)
        else:
            raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}")
        return idx

    # Channel mask path -> (1, H, W) float in [0, 1].
    if mask.ndim == 2:
        out = mask.unsqueeze(0)
    elif mask.ndim == 3:
        if mask.shape[0] == 1:
            out = mask
        else:
            # Keep the channel with the largest total value ("area").
            areas = mask.sum(dim=(-2, -1))
            k = int(areas.argmax())
            out = mask[k : k + 1]
    else:
        raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")

    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).to(torch.float32)
    # Out-of-place clamp: `out` may still be the caller's tensor (or a view of it)
    # because .to() is a no-op when dtype/device already match, so clamp_()
    # would silently rewrite the caller's mask when keep_soft=True.
    return out.clamp(0.0, 1.0).contiguous()