import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import os
from PIL import Image
import tempfile
import torch
from transformers import GLPNImageProcessor, GLPNForDepthEstimation

def predict_depth(image):
    # load the pretrained GLPN depth-estimation model and its image processor
    feature_extractor = GLPNImageProcessor.from_pretrained("vinvino02/glpn-nyu")
    model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-nyu")
    # resize the input image: cap the height at 480 px and snap both dimensions to multiples of 32
    new_height = 480 if image.height > 480 else image.height
    new_height -= (new_height % 32)
    new_width = int(new_height * image.width / image.height)
    diff = new_width % 32
    new_width = new_width - diff if diff < 16 else new_width + 32 - diff
    new_size = (new_width, new_height)
    image = image.resize(new_size)
    # prepare image for the model
    inputs = feature_extractor(images=image, return_tensors="pt")
    # get the prediction from the model
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth
    output = predicted_depth.squeeze().cpu().numpy() * 1000.0
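    # GLPN-NYU regresses metric depth (metres), so the factor of 1000 expresses it in
    # millimetres; the absolute scale is not critical here because predict() below
    # normalises the map to 8-bit before meshing and display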
    # remove borders
    pad = 16
    output = output[pad:-pad, pad:-pad]
    image = image.crop((pad, pad, image.width - pad, image.height - pad))
    return image, output

def generate_mesh(image, depth_image, quality):
    width, height = image.size
    image = np.array(image)
    # create an RGBD image from the colour image and the 8-bit depth map
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(image_o3d, depth_o3d,
                                                                    convert_rgb_to_intensity=False)
    # camera settings
    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(width, height, 500, 500, width / 2, height / 2)
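    # fx = fy = 500 px with the principal point at the image centre are nominal,
    # assumed intrinsics; the source camera is unknown, so the reconstructed scale
    # and perspective are only approximate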
    # create the point cloud
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)
    # remove statistical outliers
    cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=20.0)
    pcd = pcd.select_by_index(ind)
    # estimate and orient normals (required by Poisson reconstruction)
    pcd.estimate_normals()
    pcd.orient_normals_to_align_with_direction(orientation_reference=(0., 0., -1.))
    # surface reconstruction
    mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=quality, n_threads=1)[0]
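    # 'depth' is the octree depth of the Poisson reconstruction: larger values produce
    # a finer mesh at higher compute cost (the UI quality slider maps to it via quality + 5)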
    # rotate the mesh
    rotation = mesh.get_rotation_matrix_from_xyz((np.pi, np.pi, 0))
    mesh.rotate(rotation, center=(0, 0, 0))
    # save the mesh to a temporary .obj file (NamedTemporaryFile replaces the private
    # tempfile._get_candidate_names helper)
    temp_file = tempfile.NamedTemporaryFile(suffix='.obj', delete=False)
    temp_file.close()
    o3d.io.write_triangle_mesh(temp_file.name, mesh)
    return temp_file.name

def predict(image, quality):
    image, depth_map = predict_depth(image)
    depth_image = (depth_map * 255 / np.max(depth_map)).astype('uint8')
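    # the 8-bit depth map serves two purposes: it is the depth channel of the RGBD
    # image used for meshing, and (colour-mapped below) the preview returned to the UI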
    mesh_path = generate_mesh(image, depth_image, quality + 5)
    colormap = plt.get_cmap('plasma')
    depth_image = (colormap(depth_image) * 255).astype('uint8')
    depth_image = Image.fromarray(depth_image)
    return depth_image, mesh_path

# GUI
title = 'Image2Mesh'
description = 'This demo predicts the depth map of an image and then generates a 3D mesh from it. ' \
              'Choosing a higher quality increases the time needed to generate the mesh. You can download ' \
              'the mesh by clicking the top-right button on the 3D viewer.'
examples = [[f'examples/{name}', 3] for name in sorted(os.listdir('examples'))]
# example image source:
# N. Silberman, D. Hoiem, P. Kohli, and R. Fergus,
# "Indoor Segmentation and Support Inference from RGBD Images", ECCV 2012
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='pil', label='Input Image'),
        gr.Slider(1, 5, step=1, value=3, label='Mesh quality')
    ],
    outputs=[
        gr.Image(label='Depth'),
        gr.Model3D(label='3D Model', clear_color=[0.0, 0.0, 0.0, 0.0])
    ],
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
    title=title,
    description=description
)
iface.launch()
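
# Minimal smoke test of the pipeline without the web UI, kept commented out because
# iface.launch() above blocks. The example path is hypothetical; substitute any RGB image.
# test_image = Image.open('examples/classroom.jpg').convert('RGB')
# depth_preview, mesh_file = predict(test_image, quality=3)
# depth_preview.save('depth_preview.png')
# print('mesh written to', mesh_file)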