# code_depth/depth_infer.py
import os
from pathlib import Path
import numpy as np
import torch
import cv2
import matplotlib.cm as cm
from PIL import Image
# Make `from video_depth_anything.video_depth import VideoDepthAnything` resolvable
HERE = Path(__file__).resolve().parent
import sys
if str(HERE) not in sys.path:
    sys.path.append(str(HERE))
from video_depth_anything.video_depth import VideoDepthAnything  # noqa

_MODEL_CFGS = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}


class DepthModel:
    """Thin wrapper around Video-Depth-Anything for single-image depth inference."""

    def __init__(self, repo_root: Path, encoder: str = "vitl", device: str | None = None):
        self.encoder = encoder
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = VideoDepthAnything(**_MODEL_CFGS[encoder]).to(self.device).eval()
        ckpt = repo_root / "code_depth" / "checkpoints" / f"video_depth_anything_{encoder}.pth"
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
        state = torch.load(str(ckpt), map_location="cpu")
        self.model.load_state_dict(state, strict=True)

    @torch.inference_mode()
    def infer(
        self,
        image: Image.Image | np.ndarray,
        max_res: int = 1280,
        input_size: int = 518,
        fp32: bool = False,
        grayscale: bool = False,
    ) -> Image.Image:
        """Return a depth visualization of the input image as a PIL.Image."""
        if isinstance(image, Image.Image):
            rgb = np.array(image.convert("RGB"))
        else:
            # Assume a numpy RGB (or RGBA) image in HWC layout
            assert image.ndim == 3 and image.shape[2] in (3, 4), "Expect HxWxC image"
            rgb = image[..., :3].copy()
        h, w = rgb.shape[:2]
        if max(h, w) > max_res:
            scale = max_res / max(h, w)
            rgb = cv2.resize(rgb, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
        # The model exposes a video-depth interface, so stack the single frame into a batch of one
        frame_tensor = np.stack([rgb], axis=0)
        depths, _ = self.model.infer_video_depth(
            frame_tensor, 32, input_size=input_size, device=self.device, fp32=fp32
        )
        depth = depths[0]
        # Visualization: min-max normalize to [0, 255], then grayscale or "inferno" colormap
        d_min, d_max = depth.min(), depth.max()
        depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
        if grayscale:
            return Image.fromarray(depth_norm, mode="L")
        # Note: cm.get_cmap is deprecated in recent matplotlib; matplotlib.colormaps["inferno"]
        # is the newer accessor if this ever breaks.
        cmap = np.array(cm.get_cmap("inferno").colors)
        depth_vis = (cmap[depth_norm] * 255).astype(np.uint8)
        return Image.fromarray(depth_vis)
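

# Minimal usage sketch (not part of the original module), assuming the checkpoint layout
# described in __init__ and an image path supplied on the command line; the CLI flags and
# the default output name below are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Single-image depth visualization")
    parser.add_argument("image", help="path to an input RGB image")
    parser.add_argument("--encoder", default="vitl", choices=list(_MODEL_CFGS))
    parser.add_argument("--out", default="depth_vis.png", help="where to save the visualization")
    parser.add_argument("--grayscale", action="store_true", help="save a grayscale depth map")
    args = parser.parse_args()

    # Assumes this file sits at <repo_root>/code_depth/depth_infer.py, so the repo root is
    # the parent of this file's directory.
    repo_root = HERE.parent
    model = DepthModel(repo_root, encoder=args.encoder)
    vis = model.infer(Image.open(args.image), grayscale=args.grayscale)
    vis.save(args.out)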