File size: 2,698 Bytes
5458ff3
69b2678
5458ff3
 
69b2678
5458ff3
 
69b2678
 
5458ff3
 
69b2678
5458ff3
 
69b2678
5458ff3
69b2678
5458ff3
 
 
 
69b2678
5458ff3
 
 
 
 
69b2678
5458ff3
 
 
 
 
69b2678
5458ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69b2678
5458ff3
 
 
 
69b2678
5458ff3
 
 
 
 
 
69b2678
5458ff3
 
 
69b2678
5458ff3
 
69b2678
5458ff3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# code_depth/depth_infer.py
import os
from pathlib import Path
import numpy as np
import torch
import cv2
import matplotlib.cm as cm
from PIL import Image

# Make `from video_depth_anything.video_depth import VideoDepthAnything` resolvable:
# append this file's directory to sys.path (only once) before the import below runs.
# NOTE(review): presumably the `video_depth_anything` package lives next to this file — confirm.
HERE = Path(__file__).resolve().parent
import sys
if str(HERE) not in sys.path:
    sys.path.append(str(HERE))

from video_depth_anything.video_depth import VideoDepthAnything  # noqa

# Keyword arguments for the VideoDepthAnything constructor, keyed by encoder name.
# DepthModel.__init__ selects one entry by its `encoder` argument and unpacks it with `**`,
# so an encoder name missing from this table raises KeyError there.
_MODEL_CFGS = {
    'vits': {'encoder': 'vits', 'features': 64,  'out_channels': [48, 96, 192, 384]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}

class DepthModel:
    """Single-image front end for ``VideoDepthAnything`` that renders depth maps.

    The constructor builds the network for the requested encoder, loads its
    checkpoint from ``<repo_root>/code_depth/checkpoints`` and moves it to the
    selected device. :meth:`infer` then turns one RGB image into a depth
    visualization.
    """

    # Lazily built uint8 RGB lookup table for the "inferno" colormap (one row per
    # uint8 depth level). Cached on the class so it is computed at most once.
    _inferno_lut: np.ndarray | None = None

    def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
        """Build the model and load its weights.

        Args:
            repo_root: Repository root; the checkpoint is expected at
                ``repo_root / "code_depth" / "checkpoints" /
                f"video_depth_anything_{encoder}.pth"``. A plain string path is
                accepted as well.
            encoder: Key into ``_MODEL_CFGS`` selecting the network configuration.
            device: Torch device string. When falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            KeyError: If ``encoder`` is not a key of ``_MODEL_CFGS``.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        self.encoder = encoder
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Look the configuration up first so an unknown encoder still surfaces as
        # KeyError (as before) rather than as a missing-checkpoint error.
        cfg = _MODEL_CFGS[encoder]

        # Fail fast on a missing checkpoint before paying for model construction.
        ckpt = Path(repo_root) / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

        self.model = VideoDepthAnything(**cfg).to(self.device).eval()

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source (consider weights_only=True if the installed
        # torch version supports it).
        state = torch.load(str(ckpt), map_location="cpu")
        self.model.load_state_dict(state, strict=True)

    @classmethod
    def _get_inferno_lut(cls) -> np.ndarray:
        """Return the cached uint8 RGB lookup table of the "inferno" colormap."""
        if cls._inferno_lut is None:
            try:
                # Colormap registry (Matplotlib >= 3.5). `cm.get_cmap` was
                # removed in Matplotlib 3.9, so it is only a fallback for old
                # versions that lack the registry.
                import matplotlib

                cmap = matplotlib.colormaps["inferno"]
            except AttributeError:
                cmap = cm.get_cmap("inferno")
            # Scaling the table once gives the same bytes as scaling every
            # looked-up pixel, at a fraction of the cost.
            cls._inferno_lut = (np.array(cmap.colors) * 255).astype(np.uint8)
        return cls._inferno_lut

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image | np.ndarray,
        max_res: int = 1280,
        input_size: int = 518,
        fp32: bool = False,
        grayscale: bool = False,
    ) -> Image.Image:
        """Return a depth visualization of ``image`` as a ``PIL.Image``.

        Args:
            image: A PIL image (converted to RGB) or an ``H x W x C`` array with
                3 or 4 channels; only the first three channels are used and are
                assumed to be in RGB order.
            max_res: If the longer image side exceeds this value, the image is
                downscaled (aspect ratio kept) so that side equals ``max_res``.
                A non-positive value disables the limit.
            input_size: Forwarded to ``infer_video_depth``.
            fp32: Forwarded to ``infer_video_depth``.
            grayscale: When true, return an 8-bit single-channel ("L") image;
                otherwise return an RGB image colored with the "inferno" map.

        Returns:
            The depth map, min-max normalized to 0..255 and rendered as an image.

        Raises:
            ValueError: If a non-PIL ``image`` is not a 3-D array with 3 or 4
                channels.
        """
        if isinstance(image, Image.Image):
            rgb = np.array(image.convert("RGB"))
        else:
            # Explicit check instead of `assert`, which is stripped under -O.
            if getattr(image, "ndim", None) != 3 or image.shape[2] not in (3, 4):
                raise ValueError("Expect HxWxC image")
            rgb = image[..., :3].copy()

        h, w = rgb.shape[:2]
        longest = max(h, w)
        if max_res > 0 and longest > max_res:
            scale = max_res / longest
            # Clamp to 1 px: truncation can yield 0 for extreme aspect ratios,
            # which cv2.resize rejects.
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            rgb = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # The model API works on videos, so wrap the single frame in a batch of one.
        # NOTE(review): the positional 32 is presumably a target-fps hint that is
        # irrelevant for one frame — confirm against infer_video_depth.
        frame_tensor = np.stack([rgb], axis=0)
        depths, _ = self.model.infer_video_depth(
            frame_tensor, 32, input_size=input_size, device=self.device, fp32=fp32
        )
        depth = depths[0]

        # Visualization: min-max normalize to uint8; the epsilon avoids a
        # division by zero when the depth map is constant.
        d_min, d_max = depth.min(), depth.max()
        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)

        if grayscale:
            return Image.fromarray(depth_norm, mode="L")

        depth_vis = self._get_inferno_lut()[depth_norm]
        return Image.fromarray(depth_vis)