|
|
|
|
|
from __future__ import annotations |
|
|
import torch |
|
|
|
|
|
def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line diagnostic summary of *t*, prefixed with *tag*.

    The line reports shape, dtype, device and the [min, max] value range
    (``nan`` for both bounds when *t* has no elements). Any exception raised
    while building the summary is caught and a short failure line is printed
    instead of propagating it.
    """
    try:
        if t.numel():
            lo = float(t.min())
            hi = float(t.max())
        else:
            lo = hi = float("nan")
        header = f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
        print(header + f"range=[{lo:.4f},{hi:.4f}]")
    except Exception as err:
        print(f"[interop] {tag}: <log failed: {err!r}>")
|
|
|
|
|
def _to_float01(x: torch.Tensor) -> torch.Tensor: |
|
|
x = x.to(torch.float32) |
|
|
if x.max() > 1.0: |
|
|
x = x / 255.0 |
|
|
return x.clamp_(0.0, 1.0) |
|
|
|
|
|
def _squeeze_bt(x: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
if x.ndim == 5: |
|
|
if x.shape[1] == 1: |
|
|
x = x.squeeze(1) |
|
|
if x.ndim == 5 and x.shape[0] == 1: |
|
|
x = x.squeeze(0) |
|
|
|
|
|
if x.ndim == 4 and x.shape[0] == 1 and x.shape[1] == 1 and x.shape[-3] == 3: |
|
|
x = x.squeeze(1) |
|
|
return x |
|
|
|
|
|
def ensure_image_nchw(
    img: torch.Tensor,
    device: torch.device | str = "cuda",
    want_batched: bool = True,
) -> torch.Tensor:
    """Normalise an image tensor to float32, channel-first layout on *device*.

    The tensor is moved to *device* and passed through ``_squeeze_bt`` (which
    may drop a singleton axis of 5-D input). It is then interpreted as:

      * 3-D: ``(C, H, W)`` when the first axis has size 1 or 3; otherwise it is
        assumed to be ``(H, W, C)`` and permuted (the channel count is not
        validated in that case).
      * 4-D: ``(N, C, H, W)`` or ``(N, H, W, C)``. A 3-channel axis is looked
        for first (axis 1, then the last axis), then a 1-channel axis in the
        same order.

    Values are converted with ``_to_float01`` (mapped into [0, 1]).

    Args:
        img: Input image tensor.
        device: Target device.
        want_batched: For 3-D input, whether to add a leading batch axis of
            size 1. 4-D input is always returned with its batch axis; this
            flag is ignored for it.

    Returns:
        ``(1, C, H, W)`` / ``(C, H, W)`` for 3-D input, ``(N, C, H, W)`` for
        4-D input.

    Raises:
        AssertionError: if the rank is not 3 or 4 after squeezing, or a 4-D
            tensor has no channel axis of size 1 or 3.
    """
    img = _squeeze_bt(img.to(device))

    if img.ndim == 3:
        # A leading axis of size 1 or 3 is read as channels (CHW); any other
        # shape is assumed to be HWC.
        chw = img if img.shape[0] in (1, 3) else img.permute(2, 0, 1)
        chw = _to_float01(chw.contiguous())
        return chw.unsqueeze(0) if want_batched else chw

    if img.ndim == 4:
        _, a, _, c = img.shape
        # RGB layouts are checked before single-channel ones, so an ambiguous
        # shape such as (N, 1, W, 3) is still read as NHWC RGB. Single-channel
        # images are accepted for consistency with the 3-D branch.
        if a == 3:
            nchw = img
        elif c == 3:
            nchw = img.permute(0, 3, 1, 2)
        elif a == 1:
            nchw = img
        elif c == 1:
            nchw = img.permute(0, 3, 1, 2)
        else:
            raise AssertionError(f"Cannot infer channels in image: {tuple(img.shape)}")
        return _to_float01(nchw.contiguous())

    raise AssertionError(f"Image must be 3D/4D; got {tuple(img.shape)}")
|
|
|
|
|
def ensure_mask_for_matanyone(
    mask: torch.Tensor,
    *,
    idx_mask: bool = False,
    threshold: float = 0.5,
    keep_soft: bool = False,
    device: torch.device | str = "cuda",
) -> torch.Tensor:
    """Convert a mask tensor to a single-object mask on *device*.

    The name suggests this is the layout MatAnyone expects -- confirm against
    the caller. The mask is moved to *device* and passed through
    ``_squeeze_bt``. A remaining leading batch axis of size 1 on a 4-D mask
    (e.g. ``(1, K, H, W)``) is then dropped, so the input must end up 2-D
    ``(H, W)`` or 3-D ``(K, H, W)``.

    Args:
        mask: Input mask tensor.
        idx_mask: If True, return an integer index map instead of a float mask.
        threshold: Foreground cut-off (``>= threshold``). Used for 2-D and
            single-channel input, and by the float path unless *keep_soft*.
            It is not used for multi-channel input when *idx_mask* is True.
        keep_soft: Float path only; keep the (clamped) soft values instead of
            binarising at *threshold*.
        device: Target device.

    Returns:
        * ``idx_mask=True``: ``(H, W)`` int64 tensor of 0/1. For multi-channel
          input, 1 marks pixels whose argmax channel is not channel 0.
        * otherwise: contiguous ``(1, H, W)`` float32 tensor in [0, 1]. For
          multi-channel input, only the channel with the largest summed value
          is kept.
        The input tensor is never modified.

    Raises:
        AssertionError: if the mask is not 2-D/3-D after squeezing.
    """
    mask = _squeeze_bt(mask.to(device))
    # _squeeze_bt only touches 5-D input; drop a singleton batch axis on 4-D too.
    if mask.ndim == 4 and mask.shape[0] == 1:
        mask = mask.squeeze(0)

    if idx_mask:
        if mask.ndim == 2:
            return (mask >= threshold).to(torch.long)
        if mask.ndim == 3:
            if mask.shape[0] == 1:
                return (mask[0] >= threshold).to(torch.long)
            # Multi-channel: winning channel per pixel, collapsed to a
            # foreground/background map (channel 0 counts as background).
            winner = torch.argmax(mask, dim=0)
            return (winner > 0).to(torch.long)
        raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}")

    if mask.ndim == 2:
        out = mask.unsqueeze(0)
    elif mask.ndim == 3:
        if mask.shape[0] == 1:
            out = mask
        else:
            # Several channels: keep the one with the largest total area.
            best = int(mask.sum(dim=(-2, -1)).argmax())
            out = mask[best : best + 1]
    else:
        raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")

    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).to(torch.float32)
    # Out-of-place clamp: `out` may still alias the caller's tensor
    # (unsqueeze/slicing return views and .to() is a no-op for float32 input),
    # so an in-place clamp_ would overwrite the caller's mask.
    return out.clamp(0.0, 1.0).contiguous()
|
|
|