Tohru127 commited on
Commit
05ad083
·
verified ·
1 Parent(s): ba2bdb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -236
app.py CHANGED
@@ -1,249 +1,232 @@
1
- # Run-your-script (dynamic) — HF Spaces wrapper for main.py with user inputs
2
- import os, sys, io, time, glob, json, shlex, subprocess
 
 
3
  from pathlib import Path
4
- from typing import Optional, Tuple
 
 
 
 
 
5
 
 
6
  import gradio as gr
7
- from PIL import Image
8
 
9
- # Keep CPU runtimes stable
10
- os.environ.pop("OMP_NUM_THREADS", None)
11
- os.environ.setdefault("OMP_NUM_THREADS", "1")
12
- os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
13
- os.environ.setdefault("MKL_NUM_THREADS", "1")
14
- os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
15
- os.environ.setdefault("MPLBACKEND", "Agg")
16
-
17
- REPO_ROOT = Path(".").resolve()
18
-
19
- # ---- Defaults: adjust to match your script, or override in the UI ----
20
- DEFAULT_SCRIPT = str(REPO_ROOT / "main.py")
21
- DEFAULT_INPUT_PATH = str(REPO_ROOT / "ROOM.jpg") # where we'll save the uploaded image
22
- DEFAULT_WORKDIR = str(REPO_ROOT)
23
- DEFAULT_OUTPUT_DIR = str(REPO_ROOT / "outputs") # where your script writes results
24
-
25
- # ---------- helpers ----------
26
- def _save_image(img: Image.Image, path: str) -> str:
27
- p = Path(path); p.parent.mkdir(parents=True, exist_ok=True)
28
- img.convert("RGB").save(p, format="JPEG", quality=95)
29
- return str(p)
30
-
31
- def _pick_latest(patterns):
32
- newest = None; mt = -1
33
- for pat in patterns:
34
- for fp in glob.glob(pat):
35
- try:
36
- sz = os.path.getsize(fp)
37
- if sz <= 0: continue
38
- m = os.path.getmtime(fp)
39
- if m > mt:
40
- newest, mt = fp, m
41
- except Exception:
42
- pass
43
- return newest
44
-
45
- def _scan_outputs(output_dir: str):
46
- od = Path(output_dir)
47
- depth = _pick_latest([
48
- str(od / "depth_preview.*"),
49
- str(od / "*depth*.png"),
50
- str(od / "*depth*.jpg"),
51
- str(REPO_ROOT / "depth_preview.*"),
52
- ])
53
- pcd = _pick_latest([
54
- str(od / "point_cloud.ply"),
55
- str(od / "*.ply"),
56
- ])
57
- mesh = _pick_latest([
58
- str(od / "mesh.obj"),
59
- str(od / "*.obj"),
60
- str(od / "mesh.ply"),
61
- str(od / "*mesh*.ply"),
62
- str(od / "*.glb"),
63
- str(od / "*.gltf"),
64
- ])
65
- return depth, pcd, mesh
66
-
67
- def _compose_cli(script_path: str, base_args: str, kv_pairs: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  """
69
- base_args: free-form CLI string (e.g., "--poisson_depth 10 --out outputs")
70
- kv_pairs: JSON or 'key=value key2=value2' → becomes '--key value --key2 value2'
71
  """
72
- args = [sys.executable, script_path]
73
-
74
- # Add free-form args (if provided)
75
- if base_args and base_args.strip():
76
- args.extend(shlex.split(base_args.strip()))
77
-
78
- # Add key=value pairs
79
- if kv_pairs and kv_pairs.strip():
80
- # try JSON first
81
- as_json = None
82
- try:
83
- as_json = json.loads(kv_pairs)
84
- except Exception:
85
- pass
86
- if isinstance(as_json, dict):
87
- for k, v in as_json.items():
88
- if k.startswith("--"): args.append(k)
89
- else: args.append(f"--{k}")
90
- if v is not True and v is not None:
91
- args.append(str(v))
92
- else:
93
- # fallback: split by spaces, accept k=v tokens
94
- for token in shlex.split(kv_pairs.strip()):
95
- if "=" in token:
96
- k, v = token.split("=", 1)
97
- if k.startswith("--"): args.append(k)
98
- else: args.append(f"--{k}")
99
- args.append(v)
100
- else:
101
- # allow plain flags like --use_poisson
102
- args.append(token)
103
-
104
- return args
105
-
106
- # ---------- streaming runner ----------
107
- def _run_streaming(
108
- image,
109
- script_path,
110
- input_path,
111
- workdir,
112
- output_dir,
113
- freeform_args, # raw CLI string
114
- kv_args, # k=v pairs or JSON
115
- extra_env_json # ENV as JSON (optional)
116
- ):
117
- depth_path = None; pcd_path = None; mesh_path = None
118
- viewer_path = None
119
- log_buf = []
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if image is None:
122
- yield None, None, None, None, "Please upload an image."
123
- return
124
-
125
- # Save input where the script expects it
126
- try:
127
- saved = _save_image(image, input_path)
128
- log_buf.append(f"[app] Saved input → {saved}")
129
- except Exception as e:
130
- yield None, None, None, None, f"[Save error] {e}"
131
- return
132
-
133
- # Compose CLI
134
- try:
135
- args = _compose_cli(script_path, freeform_args, kv_args)
136
- # If the script expects a positional image path, add it here (uncomment if needed):
137
- # args.extend([saved])
138
- log_buf.append(f"[app] Running: {' '.join(shlex.quote(a) for a in args)}")
139
- except Exception as e:
140
- yield None, None, None, None, f"[Args error] {e}"
141
- return
142
-
143
- # Build environment
144
- env = os.environ.copy()
145
- if extra_env_json and extra_env_json.strip():
146
- try:
147
- env.update(json.loads(extra_env_json))
148
- except Exception as e:
149
- yield None, None, None, None, f"[ENV JSON parse error] {e}"
150
- return
151
-
152
- # Launch process, stream logs
153
- try:
154
- proc = subprocess.Popen(
155
- args, cwd=workdir, env=env,
156
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
157
- text=True, bufsize=1
158
- )
159
- except Exception as e:
160
- yield None, None, None, None, f"[Run error] {e}"
161
- return
162
-
163
- last_yield = time.time()
164
- for line in iter(proc.stdout.readline, ""):
165
- log_buf.append(line.rstrip("\n"))
166
- if time.time() - last_yield > 1.0:
167
- d, p, m = _scan_outputs(output_dir)
168
- depth_path = depth_path or d
169
- pcd_path = pcd_path or p
170
- mesh_path = mesh_path or m
171
- viewer_path = mesh_path or pcd_path
172
- yield depth_path, viewer_path, pcd_path, mesh_path, "\n".join(log_buf[-800:])
173
- last_yield = time.time()
174
-
175
- proc.wait()
176
-
177
- # Final scan
178
- d, p, m = _scan_outputs(output_dir)
179
- depth_path = depth_path or d
180
- pcd_path = pcd_path or p
181
- mesh_path = mesh_path or m
182
- viewer_path = mesh_path or pcd_path
183
- log_buf.append(f"[app] Script finished with return code {proc.returncode}")
184
-
185
- yield depth_path, viewer_path, pcd_path, mesh_path, "\n".join(log_buf[-2000:])
186
-
187
- # ---------- UI ----------
188
- with gr.Blocks(title="Run main.py — Dynamic Inputs") as demo:
189
- gr.Markdown(
190
- "## Run your `main.py` with dynamic user inputs\n"
191
- "- Upload an image (we’ll save it to the path your script expects)\n"
192
- "- Enter **CLI arguments** and/or **key=value** pairs (auto-converted to `--key value`)\n"
193
- "- We stream stdout/stderr live and show any depth/PCD/mesh files your script writes\n"
194
  )
195
 
 
 
 
 
196
  with gr.Row():
197
- with gr.Column(scale=1):
198
- img = gr.Image(type="pil", label="Upload image", interactive=True)
199
-
200
- with gr.Accordion("Script paths", open=False):
201
- script_path = gr.Textbox(value=DEFAULT_SCRIPT, label="Script path (e.g., main.py)")
202
- input_path = gr.Textbox(value=DEFAULT_INPUT_PATH, label="Save uploaded image to (path your script reads)")
203
- workdir = gr.Textbox(value=DEFAULT_WORKDIR, label="Working directory")
204
- output_dir = gr.Textbox(value=DEFAULT_OUTPUT_DIR, label="Output directory to scan")
205
-
206
- with gr.Accordion("Arguments", open=True):
207
- freeform_args = gr.Textbox(
208
- value="",
209
- placeholder="e.g., --poisson_depth 10 --out outputs",
210
- label="CLI arguments (free-form)"
211
- )
212
- kv_args = gr.Textbox(
213
- value="",
214
- placeholder='JSON or k=v (space-separated). e.g., {"poisson_depth":10, "out":"outputs"} or poisson_depth=10 out=outputs',
215
- label="Key=Value (auto --key value)"
216
- )
217
-
218
- with gr.Accordion("Environment (optional)", open=False):
219
- extra_env = gr.Textbox(
220
- value="{}",
221
- label="ENV as JSON",
222
- placeholder='e.g., {"OMP_NUM_THREADS":"1"}'
223
- )
224
-
225
- run_btn = gr.Button("Run script", variant="primary")
226
-
227
- with gr.Column(scale=2):
228
- with gr.Tabs():
229
- with gr.Tab("Depth"):
230
- depth_img = gr.Image(type="filepath", label="Depth preview (detected)")
231
- with gr.Tab("3D Reconstruction"):
232
- model3d = gr.Model3D(label="Mesh / Point Cloud (OBJ/PLY/GLB/GLTF)")
233
- with gr.Tab("Downloads"):
234
- pcd_file = gr.File(label="Point cloud (PLY)")
235
- mesh_file = gr.File(label="Mesh (OBJ/PLY/GLB/GLTF)")
236
- with gr.Tab("Logs"):
237
- logs = gr.Textbox(label="Live logs", lines=20)
238
-
239
- run_btn.click(
240
- _run_streaming,
241
- inputs=[img, script_path, input_path, workdir, output_dir, freeform_args, kv_args, extra_env],
242
- outputs=[depth_img, model3d, pcd_file, mesh_file, logs]
243
  )
244
 
245
- # Keep long jobs alive & serialized
246
- demo.queue(concurrency_count=1, max_size=8, status_update_rate=1.0)
247
-
248
- if __name__ == "__main__":
249
- demo.launch(show_error=True, server_keepalive_timeout=180)
 
1
+ import os
2
+ os.environ.setdefault("OMP_NUM_THREADS", "1") # silence libgomp spam on HF
3
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
4
+
5
  from pathlib import Path
6
+ import io
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ import torch
11
+ from transformers import GLPNForDepthEstimation, GLPNImageProcessor
12
 
13
+ import open3d as o3d
14
  import gradio as gr
 
15
 
16
+
17
+ # ----------------------------
18
+ # Device & model (load once)
19
+ # ----------------------------
20
+ DEVICE = torch.device(
21
+ "cuda" if torch.cuda.is_available()
22
+ else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
23
+ )
24
+ PROCESSOR = GLPNImageProcessor.from_pretrained("vinvino02/glpn-nyu")
25
+ MODEL = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-nyu").to(DEVICE).eval()
26
+
27
+
28
+ # ----------------------------
29
+ # Helpers (faithful to main.py logic)
30
+ # ----------------------------
31
+ def _resize_like_main(pil_img: Image.Image, cap_h: int = 480):
32
+ """Mirror your main.py: cap height at 480, then round down to multiple of 32, preserve aspect."""
33
+ new_h = min(pil_img.height, cap_h)
34
+ new_h -= (new_h % 32)
35
+ if new_h < 32:
36
+ new_h = 32
37
+ new_w = int(new_h * pil_img.width / pil_img.height)
38
+ return pil_img.resize((new_w, new_h), Image.BILINEAR), (pil_img.width, pil_img.height)
39
+
40
+
41
+ @torch.inference_mode()
42
+ def estimate_depth_glpn(pil_img: Image.Image) -> np.ndarray:
43
+ """
44
+ GLPN forward that DOES NOT rely on .post_process_depth()
45
+ (fix for your AttributeError). We upsample back to the original size manually.
46
+ Returns depth in float32 (larger = farther).
47
+ """
48
+ resized, (orig_w, orig_h) = _resize_like_main(pil_img)
49
+ inputs = PROCESSOR(images=resized, return_tensors="pt")
50
+ for k in inputs:
51
+ inputs[k] = inputs[k].to(DEVICE)
52
+
53
+ outputs = MODEL(**inputs)
54
+ pred = outputs.predicted_depth # [B, 1, h, w]
55
+ depth = pred[0, 0].float().detach().cpu().numpy() # resized size
56
+
57
+ # Resize depth back to original image size for downstream Open3D steps
58
+ depth_img = Image.fromarray(depth)
59
+ depth_full = depth_img.resize((orig_w, orig_h), Image.BILINEAR)
60
+ depth_full = np.array(depth_full).astype(np.float32)
61
+
62
+ return depth_full
63
+
64
+
65
+ def depth_vis(depth: np.ndarray) -> Image.Image:
66
+ """Normalize depth to 0..255 for a PNG preview (like your matplotlib preview)."""
67
+ d = depth.copy()
68
+ d = d - np.nanmin(d)
69
+ maxv = np.nanmax(d)
70
+ if maxv <= 0:
71
+ maxv = 1.0
72
+ d = (255.0 * d / maxv).astype(np.uint8)
73
+ return Image.fromarray(d)
74
+
75
+
76
+ def rgbd_from_rgb_depth(rgb: Image.Image, depth_f32: np.ndarray) -> o3d.geometry.RGBDImage:
77
+ """
78
+ Create Open3D RGBD using an 8-bit depth *preview* for visualization consistency
79
+ (same as your main.py normalization step).
80
+ """
81
+ rgb_np = np.array(rgb)
82
+ # match your main.py: depth to 0..255 uint8 before feeding create_from_color_and_depth
83
+ d8 = (depth_f32 * 255.0 / (depth_f32.max() + 1e-8)).astype(np.uint8)
84
+ depth_o3d = o3d.geometry.Image(d8)
85
+ color_o3d = o3d.geometry.Image(rgb_np)
86
+ rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
87
+ color_o3d, depth_o3d, convert_rgb_to_intensity=False
88
+ )
89
+ return rgbd
90
+
91
+
92
+ def pointcloud_from_rgbd(rgbd: o3d.geometry.RGBDImage, w: int, h: int) -> o3d.geometry.PointCloud:
93
  """
94
+ Reproduce your simple pinhole intrinsics (fx=fy=500, cx=w/2, cy=h/2) and back-project.
 
95
  """
96
+ K = o3d.camera.PinholeCameraIntrinsic()
97
+ K.set_intrinsics(w, h, 500.0, 500.0, w / 2.0, h / 2.0)
98
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, K)
99
+ return pcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+
102
+ def filter_pointcloud(pcd: o3d.geometry.PointCloud):
103
+ """
104
+ Statistical outlier removal ~ your 'noise removal' step. Tuned conservatively.
105
+ """
106
+ if len(pcd.points) == 0:
107
+ return pcd
108
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
109
+ pcd_f = pcd.select_by_index(ind)
110
+ pcd_f.estimate_normals(
111
+ search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.05, max_nn=30)
112
+ )
113
+ return pcd_f
114
+
115
+
116
+ def poisson_mesh(pcd: o3d.geometry.PointCloud, rotate_up=True) -> o3d.geometry.TriangleMesh:
117
+ if len(pcd.points) == 0:
118
+ return o3d.geometry.TriangleMesh()
119
+ mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
120
+ pcd, depth=10, n_threads=1
121
+ )
122
+ # Flip like your main.py (rotate X by pi)
123
+ if rotate_up:
124
+ R = mesh.get_rotation_matrix_from_xyz((np.pi, 0.0, 0.0))
125
+ mesh.rotate(R, center=(0, 0, 0))
126
+ mesh.compute_vertex_normals()
127
+ return mesh
128
+
129
+
130
+ def o3d_to_ply_bytes(geom: o3d.geometry.Geometry) -> bytes:
131
+ """Serialize an Open3D geometry to .ply bytes (so Gradio can offer a download)."""
132
+ tmp = Path("tmp_out.ply")
133
+ if isinstance(geom, o3d.geometry.PointCloud):
134
+ o3d.io.write_point_cloud(str(tmp), geom)
135
+ else:
136
+ o3d.io.write_triangle_mesh(str(tmp), geom)
137
+ data = tmp.read_bytes()
138
+ tmp.unlink(missing_ok=True)
139
+ return data
140
+
141
+
142
+ def render_point_count(pcd: o3d.geometry.PointCloud) -> str:
143
+ return f"Points: {len(pcd.points):,}"
144
+
145
+
146
+ def render_face_count(mesh: o3d.geometry.TriangleMesh) -> str:
147
+ return f"Vertices: {len(mesh.vertices):,} | Triangles: {len(mesh.triangles):,}"
148
+
149
+
150
+ # ----------------------------
151
+ # Gradio pipeline
152
+ # ----------------------------
153
+ def pipeline(image: Image.Image):
154
+ logs = []
155
  if image is None:
156
+ raise gr.Error("Please upload an image of a room.")
157
+
158
+ logs.append("Step 1 — Loaded image.")
159
+ image = image.convert("RGB")
160
+ w, h = image.size
161
+
162
+ # Depth
163
+ logs.append("Step 2 Estimating depth with GLPN (vinvino02/glpn-nyu)…")
164
+ depth = estimate_depth_glpn(image)
165
+ depth_preview = depth_vis(depth)
166
+
167
+ # RGBD
168
+ logs.append("Step 3 — Creating RGBD image…")
169
+ rgbd = rgbd_from_rgb_depth(image, depth)
170
+
171
+ # Point cloud
172
+ logs.append("Step 4 Back-projecting to point cloud…")
173
+ pcd = pointcloud_from_rgbd(rgbd, w, h)
174
+
175
+ logs.append("Step 5 — Filtering noise & estimating normals…")
176
+ pcd_f = filter_pointcloud(pcd)
177
+
178
+ # Mesh
179
+ logs.append("Step 6 Poisson surface reconstruction…")
180
+ mesh = poisson_mesh(pcd_f, rotate_up=True)
181
+
182
+ # Prepare downloads
183
+ logs.append("Step 7 Preparing downloads…")
184
+ pcd_bytes = o3d_to_ply_bytes(pcd_f)
185
+ mesh_bytes = o3d_to_ply_bytes(mesh)
186
+
187
+ # Small text stats
188
+ pcd_stats = render_point_count(pcd_f)
189
+ mesh_stats = render_face_count(mesh)
190
+
191
+ logs.append("Done.")
192
+
193
+ return (
194
+ image, # RGB preview
195
+ depth_preview, # Depth preview
196
+ pcd_stats, # point cloud stats
197
+ mesh_stats, # mesh stats
198
+ ("point_cloud.ply", pcd_bytes),
199
+ ("mesh.ply", mesh_bytes),
200
+ "\n".join(logs),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  )
202
 
203
+
204
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
205
+ gr.Markdown("# 2D → 3D (GLPN → RGBD → Point Cloud → Poisson Mesh)\nUpload a single image to reproduce your main.py workflow.")
206
+
207
  with gr.Row():
208
+ with gr.Column():
209
+ inp = gr.Image(type="pil", label="Input Image")
210
+ run = gr.Button("Reconstruct 3D", variant="primary")
211
+ log_box = gr.Textbox(label="Log", lines=14, interactive=False)
212
+
213
+ with gr.Column():
214
+ rgb_out = gr.Image(label="RGB Preview", interactive=False)
215
+ depth_out = gr.Image(label="Depth Preview (8-bit normalized)", interactive=False)
216
+
217
+ pc_txt = gr.Markdown()
218
+ mesh_txt = gr.Markdown()
219
+
220
+ pc_file = gr.File(label="Download Point Cloud (.ply)")
221
+ mesh_file = gr.File(label="Download Mesh (.ply)")
222
+
223
+ run.click(
224
+ fn=pipeline,
225
+ inputs=[inp],
226
+ outputs=[rgb_out, depth_out, pc_txt, mesh_txt, pc_file, mesh_file, log_box],
227
+ api_name="reconstruct",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
 
230
+ # IMPORTANT: older Spaces error came from using unsupported args like concurrency_count.
231
+ demo.queue() # default queue works across Gradio 4.x
232
+ demo.launch()