Spaces:
Running
on
Zero
Running
on
Zero
| # code_depth/depth_infer.py | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| import matplotlib.cm as cm | |
| from PIL import Image | |
import sys

# Put this file's directory on sys.path so that
# `from video_depth_anything.video_depth import VideoDepthAnything` can be resolved.
HERE = Path(__file__).resolve().parent
if not any(entry == str(HERE) for entry in sys.path):
    sys.path.append(str(HERE))
| from video_depth_anything.video_depth import VideoDepthAnything # noqa | |
# Keyword arguments passed to VideoDepthAnything(...), keyed by encoder name.
_MODEL_CFGS = {
    "vits": {
        "encoder": "vits",
        "features": 64,
        "out_channels": [48, 96, 192, 384],
    },
    "vitl": {
        "encoder": "vitl",
        "features": 256,
        "out_channels": [256, 512, 1024, 1024],
    },
}
class DepthModel:
    """Single-image depth visualiser built on a ``VideoDepthAnything`` network.

    The network is built from the ``_MODEL_CFGS`` entry for the chosen encoder,
    moved to the target device in eval mode, and initialised from
    ``<repo_root>/code_depth/checkpoints/video_depth_anything_<encoder>.pth``.
    """

    def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
        """Build the network and load its checkpoint.

        Args:
            repo_root: Directory that contains ``code_depth/checkpoints``.
            encoder: Key into ``_MODEL_CFGS`` (``"vits"`` or ``"vitl"``).
            device: Torch device string. When omitted, ``"cuda"`` is used if
                available and ``"cpu"`` otherwise.

        Raises:
            KeyError: If ``encoder`` is not a key of ``_MODEL_CFGS``.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.encoder = encoder
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = VideoDepthAnything(**_MODEL_CFGS[encoder]).to(self.device).eval()

        ckpt = repo_root / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be placed here; consider weights_only=True once
        # the minimum supported torch version is confirmed.
        state = torch.load(str(ckpt), map_location="cpu")
        self.model.load_state_dict(state, strict=True)

    def infer(
        self,
        image: Image.Image | np.ndarray,
        max_res: int = 1280,
        input_size: int = 518,
        fp32: bool = False,
        grayscale: bool = False,
    ) -> Image.Image:
        """Return a depth visualisation of ``image`` as a PIL image.

        Args:
            image: A PIL image (converted to RGB), or an HxWxC array with 3 or
                4 channels, assumed to be in RGB(A) order; only the first
                three channels are used.
            max_res: Upper bound for the longer side. Larger inputs are
                downscaled with the aspect ratio preserved.
            input_size: Forwarded to ``infer_video_depth``.
            fp32: Forwarded to ``infer_video_depth``.
            grayscale: If true, return the normalised depth as an 8-bit
                single-channel image; otherwise colour it with the "inferno"
                palette.

        Returns:
            The min-max normalised depth of the (possibly downscaled) frame,
            as a mode "L" image when ``grayscale`` is set and an RGB image
            otherwise (assuming the model yields a 2-D depth map per frame).

        Raises:
            AssertionError: If an array input is not HxWx3 or HxWx4.
        """
        if isinstance(image, Image.Image):
            rgb = np.array(image.convert("RGB"))
        else:
            # Array input is assumed to be RGB in HWC layout.
            assert image.ndim == 3 and image.shape[2] in (3, 4), "Expect HxWxC image"
            rgb = image[..., :3].copy()

        h, w = rgb.shape[:2]
        longest = max(h, w)
        if longest > max_res:
            scale = max_res / longest
            # Clamp to 1 px: truncation would otherwise give a zero-sized
            # target for very elongated inputs, which cv2.resize rejects.
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            rgb = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # The model exposes a video interface, so a single frame is stacked
        # into a one-element batch along a new leading axis.
        frame_tensor = np.stack([rgb], axis=0)
        # The positional 32 is presumably the target fps expected by
        # infer_video_depth - TODO confirm against the model code.
        depths, _ = self.model.infer_video_depth(
            frame_tensor, 32, input_size=input_size, device=self.device, fp32=fp32
        )
        depth = depths[0]

        # Visualisation: min-max normalise to 0..255; the epsilon avoids a
        # division by zero when the depth map is constant.
        d_min, d_max = depth.min(), depth.max()
        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
        if grayscale:
            return Image.fromarray(depth_norm, mode="L")

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the colormap
        # attribute on the cm module works on both old and new releases.
        palette = np.asarray(cm.inferno.colors)
        depth_vis = (palette[depth_norm] * 255).astype(np.uint8)
        return Image.fromarray(depth_vis)