# demo / app.py — Hugging Face Space by Tohru127
# Last commit: 05ad083 "Update app.py" (verified), 7.86 kB
import os
os.environ.setdefault("OMP_NUM_THREADS", "1") # silence libgomp spam on HF
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
from pathlib import Path
import io
import numpy as np
from PIL import Image
import torch
from transformers import GLPNForDepthEstimation, GLPNImageProcessor
import open3d as o3d
import gradio as gr
# ----------------------------
# Device & model (load once)
# ----------------------------
def _pick_device() -> torch.device:
    """Return the best available torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no torch.backends.mps attribute at all.
    has_mps_backend = getattr(torch.backends, "mps", None)
    if has_mps_backend and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _pick_device()

# Same checkpoint for the processor and the model; both are loaded once at import.
_GLPN_CHECKPOINT = "vinvino02/glpn-nyu"
PROCESSOR = GLPNImageProcessor.from_pretrained(_GLPN_CHECKPOINT)
MODEL = GLPNForDepthEstimation.from_pretrained(_GLPN_CHECKPOINT).to(DEVICE).eval()
# ----------------------------
# Helpers (faithful to main.py logic)
# ----------------------------
def _resize_like_main(pil_img: Image.Image, cap_h: int = 480, multiple: int = 32):
    """Resize an image the way main.py does before running GLPN.

    The height is capped at ``cap_h``, rounded down to a multiple of
    ``multiple`` and never allowed below ``multiple``. The width follows the
    original aspect ratio. It is clamped to at least ``multiple`` pixels
    because the GLPN image processor floors both sides to a multiple of its
    ``size_divisor`` (32 by default). A narrower image would therefore
    become a zero-width tensor.

    Args:
        pil_img: Source image.
        cap_h: Maximum output height in pixels.
        multiple: Height granularity and minimum side length, in pixels.

    Returns:
        ``(resized_image, (orig_width, orig_height))``.
    """
    orig_w, orig_h = pil_img.width, pil_img.height
    new_h = min(orig_h, cap_h)
    new_h -= new_h % multiple
    new_h = max(new_h, multiple)
    # Preserve aspect ratio, but never hand the processor a sliver it would
    # floor to zero width (e.g. very tall, very narrow images).
    new_w = max(int(new_h * orig_w / orig_h), multiple)
    return pil_img.resize((new_w, new_h), Image.BILINEAR), (orig_w, orig_h)
@torch.inference_mode()
def estimate_depth_glpn(pil_img: Image.Image) -> np.ndarray:
    """Run GLPN on ``pil_img`` and return a depth map at the original resolution.

    Does not rely on the processor's depth post-processing helper. The
    prediction is upsampled back to the input size manually with bilinear
    interpolation.

    Args:
        pil_img: RGB input image.

    Returns:
        float32 array of shape ``(pil_img.height, pil_img.width)``. Larger
        values are farther away.
    """
    resized, (orig_w, orig_h) = _resize_like_main(pil_img)
    inputs = PROCESSOR(images=resized, return_tensors="pt")
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    pred = MODEL(**inputs).predicted_depth
    # transformers' GLPN head squeezes the channel axis, so predicted_depth is
    # (B, H, W). Indexing it with [0, 0] would select only the first ROW of the
    # map. Also accept a (B, 1, H, W) tensor in case the channel axis is kept.
    if pred.dim() == 4:
        pred = pred[:, 0]
    depth = pred[0].float().detach().cpu().numpy()  # (h, w) at the processor's size
    # Upsample back to the original size so depth aligns pixel-for-pixel with
    # the RGB image in the Open3D steps; a 2-D float32 array becomes a PIL
    # mode "F" image, which keeps float precision through the resize.
    depth_full = Image.fromarray(depth).resize((orig_w, orig_h), Image.BILINEAR)
    return np.array(depth_full, dtype=np.float32)
def depth_vis(depth: np.ndarray) -> Image.Image:
    """Min-max normalize a depth map into an 8-bit grayscale preview image.

    The smallest finite depth maps to 0 and the largest to 255. NaN/inf
    entries are left out of the range computation and rendered as 0. Casting
    them to uint8 directly is undefined. A constant map, or one with no
    finite values, renders without dividing by zero.

    Args:
        depth: 2-D depth map.

    Returns:
        A PIL image built from the uint8 array (grayscale for 2-D input).
    """
    d = np.asarray(depth)
    finite = np.isfinite(d)
    if not finite.any():
        return Image.fromarray(np.zeros(d.shape, dtype=np.uint8))
    lo = d[finite].min()
    span = d[finite].max() - lo
    if span <= 0:
        span = 1.0
    scaled = np.where(finite, (d - lo) * 255.0 / span, 0.0)
    return Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))
def rgbd_from_rgb_depth(rgb: Image.Image, depth_f32: np.ndarray) -> o3d.geometry.RGBDImage:
    """Combine an RGB image and its depth map into an Open3D RGBD image.

    Mirrors main.py: depth is divided by its maximum, scaled to 0..255 and
    truncated to uint8 before it reaches Open3D. The depth channel used for
    back-projection is therefore this 8-bit quantization, not the raw float
    prediction. No depth_scale/depth_trunc are passed, so Open3D's defaults
    apply (documented as 1000 and 3.0). Colour stays RGB, with no
    intensity conversion.
    """
    # Keep the exact arithmetic order of main.py: (d * 255) / (max + eps).
    peak = depth_f32.max() + 1e-8
    depth_u8 = (depth_f32 * 255.0 / peak).astype(np.uint8)
    return o3d.geometry.RGBDImage.create_from_color_and_depth(
        o3d.geometry.Image(np.array(rgb)),
        o3d.geometry.Image(depth_u8),
        convert_rgb_to_intensity=False,
    )
def pointcloud_from_rgbd(
    rgbd: o3d.geometry.RGBDImage,
    w: int,
    h: int,
    fx: float = 500.0,
    fy: float = 500.0,
) -> o3d.geometry.PointCloud:
    """Back-project an RGBD image into a coloured point cloud with a pinhole camera.

    The principal point is the image centre ``(w / 2, h / 2)``. The focal
    lengths default to main.py's fixed ``fx = fy = 500`` pixels and can be
    overridden when the real camera intrinsics are known.

    Args:
        rgbd: Colour + depth image to back-project.
        w: Image width in pixels.
        h: Image height in pixels.
        fx: Horizontal focal length in pixels.
        fy: Vertical focal length in pixels.

    Returns:
        The back-projected point cloud.
    """
    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(w, h, fx, fy, w / 2.0, h / 2.0)
    return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic)
def filter_pointcloud(
    pcd: o3d.geometry.PointCloud,
    *,
    nb_neighbors: int = 20,
    std_ratio: float = 2.0,
    normal_radius: float = 0.05,
    normal_max_nn: int = 30,
) -> o3d.geometry.PointCloud:
    """Remove statistical outliers, then estimate normals on what remains.

    This approximates main.py's 'noise removal' step. The defaults are
    conservative and match the previously hard-coded values.

    Args:
        pcd: Input point cloud. It is not modified; a filtered copy is returned.
        nb_neighbors: Neighbours considered per point by the outlier filter.
        std_ratio: Standard-deviation threshold of the outlier filter.
        normal_radius: Search radius for normal estimation.
        normal_max_nn: Maximum neighbours used for normal estimation.

    Returns:
        The filtered cloud with normals. An empty input is returned unchanged,
        and an empty filtered result is returned without normals.
    """
    if len(pcd.points) == 0:
        return pcd
    # remove_statistical_outlier already returns the filtered cloud alongside
    # the kept indices, so there is no need to re-select by index.
    pcd_f, _ = pcd.remove_statistical_outlier(
        nb_neighbors=nb_neighbors, std_ratio=std_ratio
    )
    if len(pcd_f.points) == 0:
        return pcd_f
    pcd_f.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(
            radius=normal_radius, max_nn=normal_max_nn
        )
    )
    return pcd_f
def poisson_mesh(pcd: o3d.geometry.PointCloud, rotate_up=True) -> o3d.geometry.TriangleMesh:
    """Build a triangle mesh from a point cloud with Poisson reconstruction (octree depth 10).

    An empty cloud yields an empty mesh. When ``rotate_up`` is true, the mesh
    is rotated by pi about the X axis around the origin (the same flip
    main.py applies). Vertex normals are recomputed before returning.
    """
    if not len(pcd.points):
        return o3d.geometry.TriangleMesh()
    result = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=10, n_threads=1
    )
    mesh = result[0]  # result[1] holds per-vertex densities, unused here
    if rotate_up:
        flip_x = mesh.get_rotation_matrix_from_xyz((np.pi, 0.0, 0.0))
        mesh.rotate(flip_x, center=(0, 0, 0))
    mesh.compute_vertex_normals()
    return mesh
def o3d_to_ply_bytes(geom: o3d.geometry.Geometry) -> bytes:
    """Serialize an Open3D point cloud or triangle mesh to PLY bytes.

    Open3D's writers only accept a file path. The geometry is therefore
    written into a private temporary directory, read back and removed again.
    Each call gets its own directory, so concurrent requests cannot clobber
    each other's output. The directory is cleaned up even on error.

    Args:
        geom: A ``PointCloud``; anything else is written as a triangle mesh.

    Returns:
        The PLY file contents.

    Raises:
        OSError: If Open3D reports that the write failed.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp = Path(tmp_dir) / "out.ply"
        if isinstance(geom, o3d.geometry.PointCloud):
            ok = o3d.io.write_point_cloud(str(tmp), geom)
        else:
            ok = o3d.io.write_triangle_mesh(str(tmp), geom)
        if not ok:
            raise OSError(f"Open3D failed to write {type(geom).__name__} as PLY")
        return tmp.read_bytes()
def render_point_count(pcd: o3d.geometry.PointCloud) -> str:
    """Return a one-line summary of how many points ``pcd`` holds (thousands separated)."""
    n_points = len(pcd.points)
    return f"Points: {n_points:,}"
def render_face_count(mesh: o3d.geometry.TriangleMesh) -> str:
    """Return a one-line summary of the mesh size as vertex and triangle counts."""
    n_vertices = len(mesh.vertices)
    n_triangles = len(mesh.triangles)
    return f"Vertices: {n_vertices:,} | Triangles: {n_triangles:,}"
# ----------------------------
# Gradio pipeline
# ----------------------------
def _save_ply(geom: o3d.geometry.Geometry, out_path: Path) -> str:
    """Write a point cloud or triangle mesh to ``out_path`` as PLY; return the path as ``str``.

    Raises:
        gr.Error: If Open3D reports that the write failed.
    """
    if isinstance(geom, o3d.geometry.PointCloud):
        ok = o3d.io.write_point_cloud(str(out_path), geom)
    else:
        ok = o3d.io.write_triangle_mesh(str(out_path), geom)
    if not ok:
        raise gr.Error(f"Could not write {out_path.name}.")
    return str(out_path)


def pipeline(image: Image.Image):
    """Run the full 2D → 3D reconstruction for one uploaded image.

    The steps are:
    1. Depth estimation (GLPN).
    2. RGBD image.
    3. Back-projected point cloud.
    4. Outlier filtering and normals.
    5. Poisson mesh.
    6. PLY files for download.

    Args:
        image: The uploaded image (converted to RGB), or ``None`` when nothing
            was uploaded.

    Returns:
        A 7-tuple in the order of the button's ``outputs``:
        - RGB preview
        - depth preview
        - point-cloud stats
        - mesh stats
        - point-cloud PLY path
        - mesh PLY path
        - newline-joined step log

    Raises:
        gr.Error: If no image was provided, if no 3D points survive
            filtering, or if a PLY file cannot be written.
    """
    import tempfile

    if image is None:
        raise gr.Error("Please upload an image of a room.")
    logs = ["Step 1 — Loaded image."]
    image = image.convert("RGB")
    w, h = image.size

    # Depth
    logs.append("Step 2 — Estimating depth with GLPN (vinvino02/glpn-nyu)…")
    depth = estimate_depth_glpn(image)
    depth_preview = depth_vis(depth)

    # RGBD
    logs.append("Step 3 — Creating RGBD image…")
    rgbd = rgbd_from_rgb_depth(image, depth)

    # Point cloud
    logs.append("Step 4 — Back-projecting to point cloud…")
    pcd = pointcloud_from_rgbd(rgbd, w, h)
    logs.append("Step 5 — Filtering noise & estimating normals…")
    pcd_f = filter_pointcloud(pcd)
    if len(pcd_f.points) == 0:
        # Nothing to mesh or export: fail with a readable message instead of
        # tripping over an empty cloud in the meshing/export steps.
        raise gr.Error("No valid 3D points could be recovered from this image.")

    # Mesh
    logs.append("Step 6 — Poisson surface reconstruction…")
    mesh = poisson_mesh(pcd_f, rotate_up=True)

    # Downloads: gr.File outputs take file *paths*. A ("name.ply", bytes) tuple
    # cannot be post-processed by Gradio. A fresh directory per call keeps
    # concurrent requests from overwriting each other's files. It is
    # intentionally not deleted here, because Gradio reads the files after
    # this function returns.
    logs.append("Step 7 — Preparing downloads…")
    out_dir = Path(tempfile.mkdtemp(prefix="recon3d_"))
    pcd_path = _save_ply(pcd_f, out_dir / "point_cloud.ply")
    mesh_path = _save_ply(mesh, out_dir / "mesh.ply")

    # Small text stats
    pcd_stats = render_point_count(pcd_f)
    mesh_stats = render_face_count(mesh)
    logs.append("Done.")
    return (
        image,  # RGB preview
        depth_preview,  # Depth preview
        pcd_stats,  # point cloud stats
        mesh_stats,  # mesh stats
        pcd_path,  # point cloud .ply path
        mesh_path,  # mesh .ply path
        "\n".join(logs),
    )
# UI: a two-column Blocks layout.
# - Left column: input image, run button and step log.
# - Right column: previews, stats and the two download slots.
# Components appear on screen in the order they are created.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 2D → 3D (GLPN → RGBD → Point Cloud → Poisson Mesh)\nUpload a single image to reproduce your main.py workflow.")
    with gr.Row():
        with gr.Column():
            # type="pil": pipeline() receives a PIL image (or None if empty).
            inp = gr.Image(type="pil", label="Input Image")
            run = gr.Button("Reconstruct 3D", variant="primary")
            log_box = gr.Textbox(label="Log", lines=14, interactive=False)
        with gr.Column():
            rgb_out = gr.Image(label="RGB Preview", interactive=False)
            depth_out = gr.Image(label="Depth Preview (8-bit normalized)", interactive=False)
            pc_txt = gr.Markdown()
            mesh_txt = gr.Markdown()
            pc_file = gr.File(label="Download Point Cloud (.ply)")
            mesh_file = gr.File(label="Download Mesh (.ply)")
    # `outputs` is positional: keep it in the same order as the 7-tuple that
    # pipeline() returns. api_name also exposes this event as the
    # "reconstruct" API endpoint.
    run.click(
        fn=pipeline,
        inputs=[inp],
        outputs=[rgb_out, depth_out, pc_txt, mesh_txt, pc_file, mesh_file, log_box],
        api_name="reconstruct",
    )
# queue() is deliberately called with no arguments. An earlier Spaces failure
# came from passing unsupported args such as concurrency_count.
demo.queue()  # default queue settings (Gradio 4.x)
demo.launch()