Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,698 Bytes
5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 69b2678 5458ff3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
# code_depth/depth_infer.py
import os
from pathlib import Path
import numpy as np
import torch
import cv2
import matplotlib.cm as cm
from PIL import Image
# Add this file's directory to sys.path so that
# `from video_depth_anything.video_depth import VideoDepthAnything` can be
# resolved (the package is presumably located next to this file — confirm).
# NOTE(review): append() puts this directory *last* on sys.path, so another
# importable package named `video_depth_anything` would take precedence;
# confirm that is intended (insert(0, ...) would prefer the local copy).
HERE = Path(__file__).resolve().parent
import sys
if str(HERE) not in sys.path:
    sys.path.append(str(HERE))
# Must come after the sys.path tweak above, hence not at the top of the file
# (the noqa silences the linter's import-position warning).
from video_depth_anything.video_depth import VideoDepthAnything # noqa
_MODEL_CFGS = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
class DepthModel:
    """Single-image front end for the VideoDepthAnything model.

    The checkpoint is read from
    ``<repo_root>/code_depth/checkpoints/video_depth_anything_<encoder>.pth``
    and :meth:`infer` turns one image into a color (or grayscale) depth
    visualization.
    """

    def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
        """Build the model and load its weights.

        Args:
            repo_root: Repository root; the checkpoint is looked up under
                ``repo_root / "code_depth" / "checkpoints"``.
            encoder: Key into ``_MODEL_CFGS`` selecting the backbone
                (``"vits"`` or ``"vitl"``).
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            KeyError: If ``encoder`` is not a key of ``_MODEL_CFGS``.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.encoder = encoder
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Look the config up first so an unknown encoder still raises KeyError,
        # then verify the checkpoint exists *before* building the (potentially
        # expensive) model and moving it to the device.
        cfg = _MODEL_CFGS[encoder]
        ckpt = repo_root / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
        self.model = VideoDepthAnything(**cfg).to(self.device).eval()
        # Deserialize on CPU; load_state_dict then copies the tensors into the
        # model parameters, which already live on self.device.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state = torch.load(str(ckpt), map_location="cpu")
        self.model.load_state_dict(state, strict=True)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image | np.ndarray,
        max_res: int = 1280,
        input_size: int = 518,
        fp32: bool = False,
        grayscale: bool = False,
        colormap: str = "inferno",
    ) -> Image.Image:
        """Return a depth visualization (PIL image) for a single image.

        Args:
            image: A PIL image (converted to RGB), or a NumPy array of shape
                ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)`` assumed to hold
                gray / RGB / RGBA data. Grayscale arrays are replicated to
                three channels and an alpha channel is dropped.
            max_res: If the longer side exceeds this, the image is first
                downscaled (aspect ratio preserved) to roughly ``max_res``.
            input_size: Forwarded to ``infer_video_depth``.
            fp32: Forwarded to ``infer_video_depth``.
            grayscale: Return a single-channel 8-bit (``"L"``) image instead
                of a colormapped RGB image.
            colormap: Matplotlib colormap name used when ``grayscale`` is False.

        Returns:
            The predicted depth min-max normalized to 0..255, either as an
            ``"L"`` image or as an RGB image rendered through ``colormap``.

        Raises:
            ValueError: If a NumPy input does not have one of the shapes above.
        """
        if isinstance(image, Image.Image):
            rgb = np.array(image.convert("RGB"))
        else:
            arr = np.asarray(image)
            if arr.ndim == 2:
                # Single-channel array: replicate into three identical channels.
                rgb = np.repeat(arr[..., np.newaxis], 3, axis=2)
            elif arr.ndim == 3 and arr.shape[2] in (3, 4):
                # Drop alpha; copy so the caller's array is never modified.
                rgb = arr[..., :3].copy()
            else:
                raise ValueError(f"Expect HxWxC image, got array of shape {arr.shape}")

        h, w = rgb.shape[:2]
        if max(h, w) > max_res:
            scale = max_res / max(h, w)
            # cv2.resize takes (width, height); clamp to 1 so extreme aspect
            # ratios cannot produce a zero-sized target, which cv2 rejects.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            rgb = cv2.resize(rgb, new_size, interpolation=cv2.INTER_LINEAR)

        # The model API works on videos, so feed the frame as a one-frame clip
        # of shape (1, H, W, 3).
        frames = rgb[np.newaxis]
        # NOTE(review): the positional 32 is presumably the clip's target FPS,
        # which should not matter for a single frame — confirm.
        depths, _ = self.model.infer_video_depth(
            frames, 32, input_size=input_size, device=self.device, fp32=fp32
        )
        depth = depths[0]

        # Min-max normalize to 0..255; the epsilon avoids a division by zero
        # when the depth map is constant.
        d_min, d_max = depth.min(), depth.max()
        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
        if grayscale:
            # A 2-D uint8 array is mapped to mode "L" automatically; passing
            # mode= explicitly is deprecated in recent Pillow releases.
            return Image.fromarray(depth_norm)
        lut = self._colormap_lut(colormap)
        return Image.fromarray((lut[depth_norm] * 255).astype(np.uint8))

    @staticmethod
    def _colormap_lut(name: str) -> np.ndarray:
        """Return a ``(256, 3)`` float RGB lookup table in [0, 1] for colormap ``name``.

        ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
        in 3.9, so the ``matplotlib.colormaps`` registry is used when present.
        """
        import matplotlib  # already a file dependency via matplotlib.cm

        try:
            cmap = matplotlib.colormaps[name]
        except AttributeError:  # Matplotlib < 3.5 has no colormap registry
            cmap = cm.get_cmap(name)
        # 256 evenly spaced samples match the entries of a 256-color
        # ListedColormap such as "inferno"; [:, :3] drops the alpha column.
        return cmap(np.linspace(0.0, 1.0, 256))[:, :3]
|