# code_depth/depth_infer.py
"""Single-image depth inference built on top of ``VideoDepthAnything``."""
import os
import sys
from functools import lru_cache
from pathlib import Path

import cv2
import matplotlib.cm as cm
import numpy as np
import torch
from PIL import Image

# Put this file's directory on sys.path so that
# `from video_depth_anything.video_depth import VideoDepthAnything` resolves.
HERE = Path(__file__).resolve().parent
if str(HERE) not in sys.path:
    sys.path.append(str(HERE))

from video_depth_anything.video_depth import VideoDepthAnything  # noqa

# Constructor keyword arguments for VideoDepthAnything, keyed by encoder name.
_MODEL_CFGS = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}


@lru_cache(maxsize=None)
def _colormap_lut(name: str) -> np.ndarray:
    """Return the colour list of matplotlib colormap *name* as a uint8 table.

    The table is computed once per name and cached; it is marked read-only
    because the same array object is handed out on every call.
    """
    try:
        # Registry introduced in matplotlib 3.5. `cm.get_cmap` was deprecated
        # in 3.7 and removed in 3.9, so prefer the registry when it exists.
        from matplotlib import colormaps
        cmap = colormaps[name]
    except ImportError:
        cmap = cm.get_cmap(name)
    lut = (np.array(cmap.colors) * 255).astype(np.uint8)
    lut.setflags(write=False)
    return lut


class DepthModel:
    """Wrap a ``VideoDepthAnything`` network for depth visualisation of one image."""

    def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
        """Build the network and load its checkpoint.

        Args:
            repo_root: Directory containing ``code_depth/checkpoints``.
            encoder: Key into ``_MODEL_CFGS`` (``"vits"`` or ``"vitl"``).
            device: Torch device string; when falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            KeyError: If *encoder* is not a key of ``_MODEL_CFGS``.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.encoder = encoder
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        cfg = _MODEL_CFGS[encoder]
        ckpt = repo_root / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
        # Fail before constructing the network: no point allocating the model
        # (possibly on the GPU) when there are no weights to load into it.
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

        self.model = VideoDepthAnything(**cfg).to(self.device).eval()
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; consider
        # passing weights_only=True where the installed torch supports it.
        state = torch.load(str(ckpt), map_location="cpu")
        self.model.load_state_dict(state, strict=True)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image | np.ndarray,
        max_res: int = 1280,
        input_size: int = 518,
        fp32: bool = False,
        grayscale: bool = False,
    ) -> Image.Image:
        """Return a depth visualisation of *image* as a ``PIL.Image``.

        Args:
            image: A PIL image (converted to RGB) or an HxWxC array with 3 or
                4 channels; only the first three channels are used and they
                are assumed to be in RGB order.
            max_res: If the longer side exceeds this, the image is scaled
                down so that the longer side is about ``max_res``.
            input_size: Forwarded to ``infer_video_depth``.
            fp32: Forwarded to ``infer_video_depth``.
            grayscale: When true, return the min-max normalised depth as an
                8-bit ``"L"`` image; otherwise colourise it with "inferno".

        Raises:
            ValueError: If an array input is not HxWxC with C in (3, 4).
        """
        if isinstance(image, Image.Image):
            rgb = np.array(image.convert("RGB"))
        else:
            # Explicit raise instead of `assert`, which is stripped under -O.
            if image.ndim != 3 or image.shape[2] not in (3, 4):
                raise ValueError("Expect HxWxC image")
            rgb = image[..., :3].copy()

        h, w = rgb.shape[:2]
        longest = max(h, w)
        if longest > max_res:
            scale = max_res / longest
            # Clamp to 1 px: truncation would otherwise yield a zero-sized
            # side for very elongated inputs and make cv2.resize fail.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            rgb = cv2.resize(rgb, new_size, interpolation=cv2.INTER_LINEAR)

        # The model API works on videos, so wrap the single frame in a
        # leading frame axis.
        frame_tensor = np.stack([rgb], axis=0)
        # NOTE(review): the positional 32 is presumably the target FPS, which
        # should not matter for a single frame - confirm against
        # infer_video_depth's signature.
        depths, _ = self.model.infer_video_depth(
            frame_tensor, 32, input_size=input_size, device=self.device, fp32=fp32
        )
        depth = depths[0]

        # Visualisation: min-max normalise to 0..255; the epsilon avoids a
        # division by zero when the depth map is constant.
        d_min, d_max = depth.min(), depth.max()
        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
        if grayscale:
            return Image.fromarray(depth_norm, mode="L")

        # Indexing the precomputed uint8 table yields the same pixels as
        # scaling the float colours after the lookup, at a fraction of the work.
        depth_vis = _colormap_lut("inferno")[depth_norm]
        return Image.fromarray(depth_vis)