import spaces  # must be imported first, before torch initializes CUDA (HF Spaces ZeroGPU requirement)
import os
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed
)
import torch
from contextlib import nullcontext
import trimesh
import gradio as gr
from gradio_imageslider import ImageSlider
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import (
    read_cv2_image,
    torch_transform,
    tensorize
)
from da2.utils.vis import colorize_distance
from da2.utils.d2pc import distance2pointcloud
from datetime import (
    timedelta,
    datetime
)
import cv2
import numpy as np

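# Build the Accelerate runtime described by configs/infer.json: process-group
# timeout, gradient-accumulation steps, mixed precision, and logging backend.
# Also seeds all RNGs so demo outputs are reproducible.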
def prepare_to_run_demo():
    config = load_config('configs/infer.json')
    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
    accu_steps = config['accelerator']['accumulation_nsteps']
    accelerator = Accelerator(
        gradient_accumulation_steps=accu_steps,
        mixed_precision=config['accelerator']['mixed_precision'],
        log_with=config['accelerator']['report_to'],
        project_config=ProjectConfiguration(project_dir='files/tmp'),
        kwargs_handlers=[kwargs]
    )
    logger = get_logger(__name__, log_level='INFO')
    config['env']['logger'] = logger
    set_seed(config['env']['seed'])
    return config, accelerator

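# Load an optional validity mask. With no mask, every pixel is treated as valid
# (an all-True array over the image's spatial dims); otherwise any non-zero
# pixel of the grayscale mask image counts as valid.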
def read_mask_demo(mask_path, shape):
    if mask_path is None:
        return np.ones(shape[1:]) > 0
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = mask > 0
    return mask

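# Read the image with OpenCV, apply the model's input transform, build the
# validity mask, and move the tensor to the model's dtype and device.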
def load_infer_data_demo(image, mask, model_dtype, device):
    cv2_image = read_cv2_image(image)
    image = torch_transform(cv2_image)
    mask = read_mask_demo(mask, image.shape)
    image = tensorize(image, model_dtype, device)
    return image, cv2_image, mask

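# Convert the intermediate .ply point cloud to .glb (with vertex colors
# preserved) for display in the gr.Model3D output component.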
def ply2glb(ply_path, glb_path):
    pcd = trimesh.load(ply_path)
    points = np.asarray(pcd.vertices)
    colors = np.asarray(pcd.visual.vertex_colors)
    cloud = trimesh.points.PointCloud(vertices=points, colors=colors)
    cloud.export(glb_path)

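# @spaces.GPU requests a GPU worker for the duration of each call on HF Spaces
# (ZeroGPU); the model is loaded per request rather than cached globally.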
@spaces.GPU
def fn(image_path, mask_path):
    name_base, _ = os.path.splitext(os.path.basename(image_path))
    config, accelerator = prepare_to_run_demo()
    device = accelerator.device  # single source of truth: cuda, mps, or cpu
    model = load_model(config, accelerator)
    model = model.to(device)
    # Keep the input tensor on the same device as the model.
    image, cv2_image, mask = load_infer_data_demo(
        image_path, mask_path,
        model_dtype=config['spherevit']['dtype'], device=device
    )
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()  # torch.autocast does not cover MPS on older PyTorch
    else:
        autocast_ctx = torch.autocast(device.type)
    with autocast_ctx, torch.no_grad():
        distance = model(image).cpu().numpy()[0]
        distance_vis = colorize_distance(distance, mask)
        os.makedirs('files/cache', exist_ok=True)  # ensure the cache dir exists
        save_path = f'files/cache/{name_base}.glb'
        ply_path = save_path.replace('.glb', '.ply')
        normal_image = distance2pointcloud(
            distance, cv2_image, mask,
            save_path=ply_path, return_normal=True, save_distance=False
        )
        ply2glb(ply_path, save_path)
        return save_path, [normal_image, distance_vis]

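# type="filepath" hands fn() on-disk paths, which read_cv2_image and
# cv2.imread expect (rather than PIL images or numpy arrays).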
inputs = [
    gr.Image(label="Input Image", type="filepath"),
    gr.Image(label="Input Mask", type="filepath"),
]
outputs = [
    gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud"),
    ImageSlider(  # the custom component imported from gradio_imageslider above
        label="Output Depth / Normal (transformed from the depth)",
        type="pil",
        slider_position=20,
    )
]

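# Masked examples (b*.png) pair an image with a validity mask; the rest pass
# None for the mask, which read_mask_demo treats as all-valid.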
demo = gr.Interface(
    fn=fn,
    title="DA<sup>2</sup>: <u>D</u>epth <u>A</u>nything in <u>A</u>ny <u>D</u>irection",
    description="""
        <strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/DA-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful!</strong>

        Note: the "Input Mask" is optional, all pixels are assumed to be valid if mask is None.
    """,
    inputs=inputs,
    outputs=outputs,
    examples=[
        [os.path.join(os.path.dirname(__file__), "assets/demos/a1.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a2.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a3.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a4.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b0.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b0.png")],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b1.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b1.png")],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a5.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a6.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a7.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a8.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b2.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b2.png")],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b3.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b3.png")],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a9.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a10.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a11.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/a0.png"), None],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b4.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b4.png")],
        [os.path.join(os.path.dirname(__file__), "assets/demos/b5.png"), 
         os.path.join(os.path.dirname(__file__), "assets/masks/b5.png")],
    ],
    examples_per_page=20
)

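# launch() binds to localhost and the default Gradio port; uncomment the
# kwargs below to pin a host/port when running outside HF Spaces.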
demo.launch(
    # server_name="0.0.0.0",
    # server_port=6381,
)
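
# Offline usage sketch (hypothetical; not part of the hosted demo): fn() can be
# called directly with an image path to skip the UI, assuming the spaces
# decorator no-ops outside HF Spaces and a local GPU/CPU setup:
#
#   glb_path, (normal_image, distance_vis) = fn("assets/demos/a0.png", None)
#   print(f"Point cloud written to {glb_path}")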