DA-2 / app.py
haodongli's picture
minor
327e881
import spaces # must be first!
import os
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed
)
import torch
from contextlib import nullcontext
import trimesh
import gradio as gr
from gradio_imageslider import ImageSlider
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import (
read_cv2_image,
torch_transform,
tensorize
)
from da2.utils.vis import colorize_distance
from da2.utils.d2pc import distance2pointcloud
from datetime import (
timedelta,
datetime
)
import cv2
import numpy as np
def prepare_to_run_demo():
    """Load the inference config and build the Accelerator for the demo.

    Reads ``configs/infer.json`` through ``load_config``, constructs an
    ``Accelerator`` from the values under the config's ``'accelerator'`` key,
    stores a logger at ``config['env']['logger']`` and seeds the RNGs with
    ``config['env']['seed']``.

    Returns:
        tuple: ``(config, accelerator)``.
    """
    config = load_config('configs/infer.json')
    accel_cfg = config['accelerator']

    # Process-group timeout is configured in seconds.
    process_group_kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=accel_cfg['timeout'])
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=accel_cfg['accumulation_nsteps'],
        mixed_precision=accel_cfg['mixed_precision'],
        log_with=accel_cfg['report_to'],
        project_config=ProjectConfiguration(project_dir='files/tmp'),
        kwargs_handlers=[process_group_kwargs],
    )

    # The logger travels with the config so downstream helpers can reach it.
    config['env']['logger'] = get_logger(__name__, log_level='INFO')
    set_seed(config['env']['seed'])
    return config, accelerator
def read_mask_demo(mask_path, shape):
    """Return a boolean validity mask for the demo input.

    Args:
        mask_path: Path to a mask image, or ``None`` when no mask was given.
        shape: Shape of the image the mask belongs to. Only used when
            ``mask_path`` is ``None``; every dimension after the first one
            becomes the shape of the returned mask.

    Returns:
        A boolean NumPy array. It is all ``True`` when ``mask_path`` is
        ``None``; otherwise it is ``True`` wherever the mask image, read as
        grayscale, is non-zero.

    Raises:
        ValueError: If ``mask_path`` is given but the file cannot be read or
            decoded as an image.
    """
    if mask_path is None:
        # No mask supplied: treat every pixel as valid.
        return np.ones(shape[1:], dtype=bool)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of a TypeError on `None > 0`.
        raise ValueError(f"Could not read mask image: {mask_path!r}")
    return mask > 0
def load_infer_data_demo(image, mask, model_dtype, device):
    """Prepare the model input, the raw image and the validity mask.

    Args:
        image: Image source forwarded to ``read_cv2_image``.
        mask: Mask path (or ``None``) forwarded to ``read_mask_demo``.
        model_dtype: Dtype argument forwarded to ``tensorize``.
        device: Device argument forwarded to ``tensorize``.

    Returns:
        tuple: ``(image_tensor, cv2_image, mask)`` where ``image_tensor`` is
        the output of ``tensorize``, ``cv2_image`` is the output of
        ``read_cv2_image`` and ``mask`` is the boolean mask from
        ``read_mask_demo``.
    """
    cv2_image = read_cv2_image(image)
    transformed = torch_transform(cv2_image)
    # The mask is built against the transformed image's shape, before the
    # image is converted by tensorize.
    valid_mask = read_mask_demo(mask, transformed.shape)
    image_tensor = tensorize(transformed, model_dtype, device)
    return image_tensor, cv2_image, valid_mask
def ply2glb(ply_path, glb_path):
    """Re-export the geometry stored at ``ply_path`` as a point cloud.

    Loads ``ply_path`` with trimesh, copies its vertices and per-vertex
    colors into a ``trimesh.points.PointCloud`` and exports that cloud to
    ``glb_path``.

    Args:
        ply_path: Path of the file to load.
        glb_path: Destination path handed to ``PointCloud.export``.
    """
    source = trimesh.load(ply_path)
    trimesh.points.PointCloud(
        vertices=np.asarray(source.vertices),
        colors=np.asarray(source.visual.vertex_colors),
    ).export(glb_path)
@spaces.GPU
def fn(image_path, mask_path):
    """Run distance estimation on one image and build the demo outputs.

    Args:
        image_path: Filesystem path of the input image.
        mask_path: Filesystem path of an optional mask image, or ``None``.

    Returns:
        tuple: ``(glb_path, [normal_image, distance_vis])`` where ``glb_path``
        is the exported point cloud file, ``normal_image`` is the value
        returned by ``distance2pointcloud`` (called with
        ``return_normal=True``) and ``distance_vis`` is the output of
        ``colorize_distance``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image_path is None:
        # Surface a readable message in the UI instead of a TypeError from
        # os.path.basename(None).
        raise gr.Error("Please provide an input image.")

    name_base, _ = os.path.splitext(os.path.basename(image_path))
    config, accelerator = prepare_to_run_demo()

    # Use one device for both the model and its input. Picking "cuda"/"cpu"
    # by hand for the model while the input goes to accelerator.device lets
    # the two diverge (e.g. on MPS, which is handled below).
    device = accelerator.device
    model = load_model(config, accelerator)
    model = model.to(device)
    image, cv2_image, mask = load_infer_data_demo(
        image_path,
        mask_path,
        model_dtype=config['spherevit']['dtype'],
        device=device,
    )

    # Autocast is skipped whenever MPS is available.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(device.type)
    with autocast_ctx, torch.no_grad():
        distance = model(image).cpu().numpy()[0]
    distance_vis = colorize_distance(distance, mask)

    # Make sure the output directory exists before anything is written.
    cache_dir = 'files/cache'
    os.makedirs(cache_dir, exist_ok=True)
    # Build both paths from the base name; deriving one from the other with
    # str.replace would also rewrite a ".glb" occurring inside the name.
    save_path = f'{cache_dir}/{name_base}.glb'
    ply_path = f'{cache_dir}/{name_base}.ply'

    normal_image = distance2pointcloud(
        distance,
        cv2_image,
        mask,
        save_path=ply_path,
        return_normal=True,
        save_distance=False,
    )
    ply2glb(ply_path, save_path)
    return save_path, [normal_image, distance_vis]
# Input widgets: both hand a file path to `fn`.
inputs = [
    gr.Image(label="Input Image", type="filepath"),
    gr.Image(label="Input Mask", type="filepath"),
]

# Output widgets: a 3D viewer for the point cloud and a slider comparing the
# two images returned by `fn`.
outputs = [
    gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud"),
    gr.ImageSlider(
        label="Output Depth / Normal (transformed from the depth)",
        type="pil",
        slider_position=20,
    ),
]

_APP_DIR = os.path.dirname(__file__)

# Example images in display order. In this set only the "b*" images come
# with a mask file; every other example passes None for the mask.
_EXAMPLE_STEMS = [
    "a1", "a2", "a3", "a4",
    "b0", "b1",
    "a5", "a6", "a7", "a8",
    "b2", "b3",
    "a9", "a10", "a11", "a0",
    "b4", "b5",
]


def _example_row(stem):
    """Return the ``[image_path, mask_path_or_None]`` row for one example."""
    image_file = os.path.join(_APP_DIR, f"assets/demos/{stem}.png")
    if stem.startswith("b"):
        mask_file = os.path.join(_APP_DIR, f"assets/masks/{stem}.png")
    else:
        mask_file = None
    return [image_file, mask_file]


demo = gr.Interface(
    fn=fn,
    title="DA<sup>2</sup>: <u>D</u>epth <u>A</u>nything in <u>A</u>ny <u>D</u>irection",
    description="""
<strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/DA-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful!</strong>
Note: the "Input Mask" is optional, all pixels are assumed to be valid if mask is None.
""",
    inputs=inputs,
    outputs=outputs,
    examples=[_example_row(stem) for stem in _EXAMPLE_STEMS],
    examples_per_page=20,
)

demo.launch(
    # server_name="0.0.0.0",
    # server_port=6381,
)