import spaces  # must be first!

import os
from contextlib import nullcontext
from datetime import timedelta

import cv2
import numpy as np
import torch
import trimesh
import gradio as gr
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed
)

from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import (
    read_cv2_image,
    torch_transform,
    tensorize
)
from da2.utils.vis import colorize_distance
from da2.utils.d2pc import distance2pointcloud


def prepare_to_run_demo():
    """Load the inference config and set up the accelerate environment."""
    config = load_config('configs/infer.json')
    kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=config['accelerator']['timeout'])
    )
    accu_steps = config['accelerator']['accumulation_nsteps']
    accelerator = Accelerator(
        gradient_accumulation_steps=accu_steps,
        mixed_precision=config['accelerator']['mixed_precision'],
        log_with=config['accelerator']['report_to'],
        project_config=ProjectConfiguration(project_dir='files/tmp'),
        kwargs_handlers=[kwargs]
    )
    logger = get_logger(__name__, log_level='INFO')
    config['env']['logger'] = logger
    set_seed(config['env']['seed'])
    return config, accelerator


def read_mask_demo(mask_path, shape):
    """Read a validity mask; without a mask, every pixel is treated as valid."""
    if mask_path is None:
        return np.ones(shape[1:]) > 0
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    return mask > 0


def load_infer_data_demo(image, mask, model_dtype, device):
    """Load an image (and optional mask) and tensorize it for the model."""
    cv2_image = read_cv2_image(image)
    image = torch_transform(cv2_image)
    mask = read_mask_demo(mask, image.shape)
    image = tensorize(image, model_dtype, device)
    return image, cv2_image, mask


def ply2glb(ply_path, glb_path):
    """Convert a PLY point cloud to GLB so gr.Model3D can render it."""
    pcd = trimesh.load(ply_path)
    points = np.asarray(pcd.vertices)
    colors = np.asarray(pcd.visual.vertex_colors)
    cloud = trimesh.points.PointCloud(vertices=points, colors=colors)
    cloud.export(glb_path)


@spaces.GPU
def fn(image_path, mask_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    name_base, _ = os.path.splitext(os.path.basename(image_path))
    config, accelerator = prepare_to_run_demo()
    model = load_model(config, accelerator)
    model = model.to(device)
    image, cv2_image, mask = load_infer_data_demo(
        image_path, mask_path,
        model_dtype=config['spherevit']['dtype'],
        device=accelerator.device
    )
    # Autocast is unsupported on MPS, so fall back to a no-op context there.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)
    with autocast_ctx, torch.no_grad():
        distance = model(image).cpu().numpy()[0]
    distance_vis = colorize_distance(distance, mask)
    os.makedirs('files/cache', exist_ok=True)  # ensure the cache dir exists before writing
    save_path = f'files/cache/{name_base}.glb'
    normal_image = distance2pointcloud(
        distance, cv2_image, mask,
        save_path=save_path.replace('.glb', '.ply'),
        return_normal=True,
        save_distance=False
    )
    ply2glb(save_path.replace('.glb', '.ply'), save_path)
    return save_path, [normal_image, distance_vis]


def asset(path):
    """Resolve a demo asset path relative to this file."""
    return os.path.join(os.path.dirname(__file__), path)


inputs = [
    gr.Image(label="Input Image", type="filepath"),
    gr.Image(label="Input Mask", type="filepath"),
]
outputs = [
    gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud"),
    gr.ImageSlider(
        label="Output Depth / Normal (transformed from the depth)",
        type="pil",
        slider_position=20,
    )
]

demo = gr.Interface(
    fn=fn,
    title="DA2: Depth Anything in Any Direction",
    description="""
    Please consider starring our GitHub repo if you find this demo useful!
    Note: the "Input Mask" is optional; if no mask is provided, all pixels are assumed to be valid.
    """,
    inputs=inputs,
    outputs=outputs,
    examples=[
        [asset("assets/demos/a1.png"), None],
        [asset("assets/demos/a2.png"), None],
        [asset("assets/demos/a3.png"), None],
        [asset("assets/demos/a4.png"), None],
        [asset("assets/demos/b0.png"), asset("assets/masks/b0.png")],
        [asset("assets/demos/b1.png"), asset("assets/masks/b1.png")],
        [asset("assets/demos/a5.png"), None],
        [asset("assets/demos/a6.png"), None],
        [asset("assets/demos/a7.png"), None],
        [asset("assets/demos/a8.png"), None],
        [asset("assets/demos/b2.png"), asset("assets/masks/b2.png")],
        [asset("assets/demos/b3.png"), asset("assets/masks/b3.png")],
        [asset("assets/demos/a9.png"), None],
        [asset("assets/demos/a10.png"), None],
        [asset("assets/demos/a11.png"), None],
        [asset("assets/demos/a0.png"), None],
        [asset("assets/demos/b4.png"), asset("assets/masks/b4.png")],
        [asset("assets/demos/b5.png"), asset("assets/masks/b5.png")],
    ],
    examples_per_page=20
)

demo.launch(
    # server_name="0.0.0.0",
    # server_port=6381,
)