Spaces:
Running
Running
| import gradio as gr | |
| from PIL import Image | |
| import src.depth_pro as depth_pro | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import subprocess | |
| import spaces | |
| import torch | |
| import tempfile | |
| import os | |
# Fetch the pretrained DepthPro weights with the bundled helper script.
# Note: the script's exit status is not checked here.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Select the execution device: the first CUDA GPU when one is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network together with its matching input transform, move the
# weights onto the chosen device, and switch to evaluation mode.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
def resize_image(image_path, max_size=1536):
    """Copy an image to a temporary PNG, downscaling it if it is too large.

    The longer side is capped at ``max_size`` pixels and the aspect ratio is
    preserved. Images that already fit are re-encoded without resizing
    (never upscaled), matching the UI promise that only large images are
    resized.

    Args:
        image_path: Path to the source image (anything ``PIL.Image.open``
            accepts).
        max_size: Maximum length, in pixels, allowed for the longer side.

    Returns:
        Path of a new temporary ``.png`` file. The file is created with
        ``delete=False``, so the caller must remove it.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        longest = max(width, height)
        if longest > max_size:
            ratio = max_size / longest
            # round() avoids float truncation (e.g. 1535.999 -> 1535), and the
            # 1-px floor keeps extreme aspect ratios from producing a 0-sized
            # axis, which PIL rejects.
            new_size = (
                max(1, round(width * ratio)),
                max(1, round(height * ratio)),
            )
            img = img.resize(new_size, Image.LANCZOS)
        # PNG cannot store CMYK (common in print-oriented JPEGs).
        if img.mode == "CMYK":
            img = img.convert("RGB")
        # NOTE(review): re-encoding as PNG drops EXIF metadata. If
        # depth_pro.load_rgb reads focal length from EXIF, that information
        # is unavailable for this copy. Confirm whether it matters.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
        return temp_file.name
def _render_inverse_depth(depth, output_path, max_inverse_depth=10.0):
    """Save a colour-mapped, min-max normalized inverse-depth plot as a PNG.

    Args:
        depth: Depth map as a numpy array (metres, per the model output used
            by ``predict_depth``). It must be 2-D for ``imshow``.
        output_path: Destination file for the rendered figure.
        max_inverse_depth: Upper clip for inverse depth, in 1/m. The default
            of 10 means anything closer than 0.1 m saturates.

    Returns:
        ``output_path``.
    """
    # Zero depth gives +inf here; the clip below saturates it instead of
    # emitting a RuntimeWarning.
    with np.errstate(divide="ignore"):
        inverse_depth = 1.0 / depth
    inverse_depth = np.clip(inverse_depth, 0.0, max_inverse_depth)

    # Stretch to [0, 1]. A constant (or all-NaN) map renders as all zeros
    # instead of dividing by zero.
    lo = np.nanmin(inverse_depth)
    hi = np.nanmax(inverse_depth)
    span = hi - lo
    if span > 0:
        normalized = (inverse_depth - lo) / span
    else:
        normalized = np.zeros_like(inverse_depth)

    # Use an explicit Figure rather than pyplot's implicit "current figure",
    # and always close it so a failed render cannot leak figures.
    fig, ax = plt.subplots(figsize=(15.36, 15.36), dpi=100)  # ~1536x1536 px canvas
    try:
        mappable = ax.imshow(normalized, cmap="viridis", vmin=0.0, vmax=1.0)
        fig.colorbar(mappable, ax=ax, label="Normalized Inverse Depth")
        ax.set_title("Predicted Normalized Inverse Depth Map")
        ax.axis("off")
        fig.savefig(output_path, dpi=100, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    return output_path


def predict_depth(input_image):
    """Run DepthPro on an uploaded image and render its inverse depth map.

    Args:
        input_image: Filesystem path to the uploaded image (the Gradio input
            uses ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(image_path, message)`` tuple feeding the two Gradio outputs.
        - On success, ``image_path`` is the rendered PNG and ``message``
          reports the estimated focal length.
        - On failure, ``image_path`` is ``None`` and ``message`` describes
          the problem.
    """
    if input_image is None:
        return None, "Please upload an image."

    temp_file = None
    try:
        # Work on a temporary PNG copy, downscaled if the input is large.
        temp_file = resize_image(input_image)

        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()
        # Drop singleton batch/channel axes so imshow gets a 2-D array.
        if depth.ndim != 2:
            depth = depth.squeeze()

        # NOTE(review): this fixed file name is shared by every request. If
        # the app is ever configured to run requests concurrently, switch to
        # a per-request file.
        output_path = _render_inverse_depth(depth, "inverse_depth_map.png")
        # float() accepts plain numbers as well as single-element tensors.
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the full
        # traceback in the server log for debugging.
        import traceback

        traceback.print_exc()
        return None, f"An error occurred: {str(e)}"
    finally:
        # Always remove the temporary resized copy.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
# Sample inputs shown beneath the upload widget (relative file paths).
example_images = ["examples/lemur.jpg"]
# Build the Gradio UI. It takes one uploaded image (passed to predict_depth
# as a file path) and returns two outputs: the rendered depth visualisation
# and a text message with either the focal length or an error.
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Inverse Depth Map", height=768, width=768),
        gr.Textbox(label="Focal Length or Error Message"),
    ],
    title="DepthPro Demo",
    # Resizing preserves aspect ratio and only caps the longest side, so the
    # description must not claim a fixed 1536x1536 output size.
    description=(
        "[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth "
        "prediction model. Simply upload an image to predict its inverse depth map "
        "and focal length. Large images will be automatically resized so that their "
        "longest side is 1536 pixels (the aspect ratio is preserved)."
    ),
    examples=example_images,
)
# Launch the interface
iface.launch()