weijielyu committed
Commit fa06469 · 1 Parent(s): 59ae2c2

Update demo

Files changed (1)
app.py +383 -1
app.py CHANGED
@@ -94,4 +94,386 @@ except ImportError:
         else:
             # Build stage may not see a GPU on HF Spaces: compile a cross-arch set
             env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
-    except Excep
+    except Exception:
+        env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
+
+    # (Optional) side-step allocator+NVML quirks in restrictive containers
+    env.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1")
+
+    subprocess.check_call(
+        [sys.executable, "-m", "pip", "install",
+         "git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"],
+        env=env,
+    )
+    import diff_gaussian_rasterization  # noqa: F401
+
+
+from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
+from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
+from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
+
+# HuggingFace repository configuration
+HF_REPO_ID = "wlyu/OpenFaceLift"
+
+def download_weights_from_hf() -> Path:
+    """Download model weights from HuggingFace if not already present.
+
+    Returns:
+        Path to the downloaded repository.
+    """
+    workspace_dir = Path(__file__).parent
+
+    # Check if weights already exist locally
+    mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
+    gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
+
+    if mvdiffusion_path.exists() and gslrm_path.exists():
+        print("Using local model weights")
+        return workspace_dir
+
+    print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
+    print("This may take a few minutes on first run...")
+
+    # Download to local directory
+    snapshot_download(
+        repo_id=HF_REPO_ID,
+        local_dir=str(workspace_dir / "checkpoints"),
+        local_dir_use_symlinks=False,
+    )
+
+    print("Model weights downloaded successfully!")
+    return workspace_dir
+
+class FaceLiftPipeline:
+    """Pipeline for FaceLift 3D head generation from single images."""
+
+    def __init__(self):
+        # Download weights from HuggingFace if needed
+        workspace_dir = download_weights_from_hf()
+
+        # Setup paths
+        self.output_dir = workspace_dir / "outputs"
+        self.examples_dir = workspace_dir / "examples"
+        self.output_dir.mkdir(exist_ok=True)
+
+        # Parameters
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.image_size = 512
+        self.camera_indices = [2, 1, 0, 5, 4, 3]
+
+        # Load models (keep on CPU for ZeroGPU compatibility)
+        print("Loading models...")
+        try:
+            self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
+                str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
+                torch_dtype=torch.float16,
+            )
+            # Don't move to device or enable xformers here - done in the GPU-decorated function
+            self._models_on_gpu = False
+
+            with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
+                config = edict(yaml.safe_load(f))
+
+            module_name, class_name = config.model.class_name.rsplit(".", 1)
+            module = __import__(module_name, fromlist=[class_name])
+            ModelClass = getattr(module, class_name)
+
+            self.gs_lrm_model = ModelClass(config)
+            checkpoint = torch.load(
+                workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
+                map_location="cpu"
+            )
+            # Filter out loss_calculator weights (training-only, not needed for inference)
+            state_dict = {k: v for k, v in checkpoint["model"].items()
+                          if not k.startswith("loss_calculator.")}
+            self.gs_lrm_model.load_state_dict(state_dict)
+            # Keep on CPU initially - will move to GPU in decorated function
+
+            self.color_prompt_embedding = torch.load(
+                workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
+                map_location="cpu"
+            )
+
+            with open(workspace_dir / "utils_folder/opencv_cameras.json", "r") as f:
+                self.cameras_data = json.load(f)["frames"]
+
+            print("Models loaded successfully!")
+        except Exception as e:
+            print(f"Error loading models: {e}")
+            import traceback
+            traceback.print_exc()
+            raise
+
+    def _move_models_to_gpu(self):
+        """Move models to GPU and enable optimizations. Called within the @spaces.GPU context."""
+        if not self._models_on_gpu and torch.cuda.is_available():
+            print("Moving models to GPU...")
+            self.device = torch.device("cuda:0")
+            self.mvdiffusion_pipeline.to(self.device)
+            self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
+            self.gs_lrm_model.to(self.device)
+            self.gs_lrm_model.eval()  # Set to eval mode
+            self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
+            self._models_on_gpu = True
+            torch.cuda.empty_cache()  # Clear cache after moving models
+            print("Models on GPU, xformers enabled!")
+
+    @spaces.GPU(duration=120)
+    def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
+                         random_seed=4, num_steps=50):
+        """Generate a 3D head from a single image."""
+        try:
+            # Move models to GPU now that we're in the GPU context
+            self._move_models_to_gpu()
+            # Setup output directory
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_dir = self.output_dir / timestamp
+            output_dir.mkdir(exist_ok=True)
+
+            # Preprocess input
+            original_img = np.array(Image.open(image_path))
+            input_image = preprocess_image(original_img) if auto_crop else \
+                preprocess_image_without_cropping(original_img)
+
+            if input_image.size != (self.image_size, self.image_size):
+                input_image = input_image.resize((self.image_size, self.image_size))
+
+            input_path = output_dir / "input.png"
+            input_image.save(input_path)
+
+            # Generate multi-view images
+            generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
+            generator.manual_seed(random_seed)
+
+            result = self.mvdiffusion_pipeline(
+                input_image, None,
+                prompt_embeds=self.color_prompt_embedding,
+                height=self.image_size,
+                width=self.image_size,
+                guidance_scale=guidance_scale,
+                num_images_per_prompt=1,
+                num_inference_steps=num_steps,
+                generator=generator,
+                eta=1.0,
+            )
+
+            selected_views = result.images[:6]
+
+            # Save multi-view composite
+            multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
+            for i, view in enumerate(selected_views):
+                multiview_image.paste(view, (self.image_size * i, 0))
+
+            multiview_path = output_dir / "multiview.png"
+            multiview_image.save(multiview_path)
+
+            # Move the diffusion model to CPU to free GPU memory for GS-LRM
+            print("Moving diffusion model to CPU to free memory...")
+            self.mvdiffusion_pipeline.to("cpu")
+
+            # Delete intermediate variables to free memory
+            del result, generator
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Prepare 3D reconstruction input
+            view_arrays = [np.array(view) for view in selected_views]
+            lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
+            lrm_input = lrm_input[None].to(self.device) / 255.0
+            lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")
+
+            # Prepare camera parameters
+            selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
+            fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
+            c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
+
+            fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
+            c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
+            fxfycxcy = fxfycxcy[None].to(self.device)
+            c2w = c2w[None].to(self.device)
+
+            batch_indices = torch.stack([
+                torch.zeros(lrm_input.size(1)).long(),
+                torch.arange(lrm_input.size(1)).long(),
+            ], dim=-1)[None].to(self.device)
+
+            batch = edict({
+                "image": lrm_input,
+                "c2w": c2w,
+                "fxfycxcy": fxfycxcy,
+                "index": batch_indices,
+            })
+
+            # Ensure the GS-LRM model is on GPU
+            if next(self.gs_lrm_model.parameters()).device.type == "cpu":
+                print("Moving GS-LRM model to GPU...")
+                self.gs_lrm_model.to(self.device)
+                torch.cuda.empty_cache()
+
+            # Final memory cleanup before reconstruction
+            torch.cuda.empty_cache()
+
+            # Run 3D reconstruction
+            with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
+                result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
+
+            comp_image = result.render[0].unsqueeze(0).detach()
+            gaussians = result.gaussians[0]
+
+            # Clear CUDA cache after reconstruction
+            torch.cuda.empty_cache()
+
+            # Save filtered gaussians
+            filtered_gaussians = gaussians.apply_all_filters(
+                cam_origins=None,
+                opacity_thres=0.04,
+                scaling_thres=0.2,
+                floater_thres=0.75,
+                crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
+                nearfar_percent=(0.0001, 1.0),
+            )
+
+            ply_path = output_dir / "gaussians.ply"
+            filtered_gaussians.save_ply(str(ply_path))
+
+            # Save output image
+            comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
+            comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+            output_path = output_dir / "output.png"
+            Image.fromarray(comp_image).save(output_path)
+
+            # Generate turntable video
+            turntable_resolution = 512
+            num_turntable_views = 180
+            turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution,
+                                                num_views=num_turntable_views)
+            turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views)
+            turntable_frames = np.ascontiguousarray(turntable_frames)
+
+            turntable_path = output_dir / "turntable.mp4"
+            imageseq2video(turntable_frames, str(turntable_path), fps=30)
+
+            # Final CUDA cache clear
+            torch.cuda.empty_cache()
+
+            return str(input_path), str(multiview_path), str(output_path), \
+                str(turntable_path), str(ply_path)
+
+        except Exception as e:
+            import traceback
+            error_details = traceback.format_exc()
+            print(f"Error details:\n{error_details}")
+            raise gr.Error(f"Generation failed: {str(e)}")
+
+# -----------------------------
+# gsplat.js viewer (Option A)
+# -----------------------------
+GSPLAT_HEAD = """
+<script type="module">
+import * as SPLAT from "https://cdn.jsdelivr.net/npm/gsplat@1.2.9/+esm";
+let renderer, scene, camera, controls;
+
+function ensureViewer() {
+  if (renderer) return;
+  const container = document.getElementById("splat-container");
+  renderer = new SPLAT.WebGLRenderer();
+  container.appendChild(renderer.canvas);
+  scene = new SPLAT.Scene();
+  camera = new SPLAT.Camera();
+  controls = new SPLAT.OrbitControls(camera, renderer.canvas);
+  const loop = () => { controls.update(); renderer.render(scene, camera); requestAnimationFrame(loop); };
+  requestAnimationFrame(loop);
+}
+
+async function loadSplat(url) {
+  ensureViewer();
+  // Clear the previous scene before loading a new splat
+  scene.children.length = 0;
+  await SPLAT.Loader.LoadAsync(url, scene, () => {});
+}
+
+// Expose callable function for Gradio
+window.__load_splat__ = loadSplat;
+</script>
+"""
+
+def main():
+    """Run the FaceLift application with an embedded gsplat.js viewer and per-session files."""
+    pipeline = FaceLiftPipeline()
+
+    # Prepare examples (same as before)
+    examples = []
+    if pipeline.examples_dir.exists():
+        examples = [[str(f), True, 3.0, 4, 50] for f in sorted(pipeline.examples_dir.iterdir())
+                    if f.suffix.lower() in {".png", ".jpg", ".jpeg"}]
+
+    with gr.Blocks(head=GSPLAT_HEAD, title="FaceLift: Single Image 3D Face Reconstruction") as demo:
+        session = gr.State()
+
+        # Light GC + session init
+        def _init_session():
+            cleanup_old_sessions()
+            return new_session_id()
+
+        # After generation: copy the ply into a per-session folder and return the viewer URL
+        def _prep_viewer_url(ply_path: str, session_id: str) -> str:
+            if not ply_path or not os.path.exists(ply_path):
+                return ""
+            return copy_to_session_and_get_url(ply_path, session_id)
+
+        gr.Markdown("## FaceLift: Single Image 3D Face Reconstruction\nTurn a single portrait image into a 3D head model and preview it interactively.")
+        with gr.Row():
+            with gr.Column(scale=1):
+                in_image = gr.Image(type="filepath", label="Input Portrait Image")
+                auto_crop = gr.Checkbox(value=True, label="Auto Cropping")
+                guidance = gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale")
+                seed = gr.Number(value=4, label="Random Seed")
+                steps = gr.Slider(10, 100, 50, step=5, label="Generation Steps")
+                run_btn = gr.Button("Generate 3D Head", variant="primary")
+
+                # Examples (match input signature)
+                if examples:
+                    gr.Examples(
+                        examples=examples,
+                        inputs=[in_image, auto_crop, guidance, seed, steps],
+                        examples_per_page=8,
+                    )
+
+            with gr.Column(scale=1):
+                out_proc = gr.Image(label="Processed Input")
+                out_multi = gr.Image(label="Multi-view Generation")
+                out_recon = gr.Image(label="3D Reconstruction")
+                out_video = gr.PlayableVideo(label="Turntable Animation")
+                out_ply = gr.File(label="3D Model (.ply)")
+
+        gr.Markdown("### Interactive Gaussian Splat Viewer")
+        with gr.Row():
+            url_box = gr.Textbox(label="Scene URL (auto-filled)", interactive=False)
+        viewer = gr.HTML("<div id='splat-container' style='width:100%;height:640px'></div>")
+        reload_btn = gr.Button("Reload Viewer")
+
+        # Initialize per-browser session
+        demo.load(fn=_init_session, inputs=None, outputs=session)
+
+        # Chain: run → show outputs → prepare viewer URL → load viewer (JS)
+        run_btn.click(
+            fn=pipeline.generate_3d_head,
+            inputs=[in_image, auto_crop, guidance, seed, steps],
+            outputs=[out_proc, out_multi, out_recon, out_video, out_ply],
+        ).then(
+            fn=_prep_viewer_url,
+            inputs=[out_ply, session],
+            outputs=url_box,
+        ).then(
+            fn=None, inputs=url_box, outputs=None,
+            js="(url)=>window.__load_splat__(url)"
+        )
+
+        # Manual reload if needed
+        reload_btn.click(fn=None, inputs=url_box, outputs=None, js="(url)=>window.__load_splat__(url)")
+
+    demo.queue(max_size=10)
+    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
+
+if __name__ == "__main__":
+    main()