Update app.py
Browse files
app.py
CHANGED
|
@@ -212,7 +212,6 @@ def extract_frames(input_video, session_id):
|
|
| 212 |
|
| 213 |
|
| 214 |
def update_gallery_on_video_upload(input_video, session_id):
|
| 215 |
-
|
| 216 |
if not input_video:
|
| 217 |
return None, None, None
|
| 218 |
|
|
@@ -229,6 +228,17 @@ def update_gallery_on_images_upload(input_images, session_id):
|
|
| 229 |
|
| 230 |
@spaces.GPU()
|
| 231 |
def generate_splats_from_video(video_path, session_id=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
if session_id is None:
|
| 234 |
session_id = uuid.uuid4().hex
|
|
@@ -240,7 +250,16 @@ def generate_splats_from_video(video_path, session_id=None):
|
|
| 240 |
|
| 241 |
@spaces.GPU()
|
| 242 |
def generate_splats_from_images(image_paths, session_id=None):
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
processed_image_paths = []
|
| 245 |
|
| 246 |
for file_data in image_paths:
|
|
@@ -267,12 +286,12 @@ def generate_splats_from_images(image_paths, session_id=None):
|
|
| 267 |
|
| 268 |
print("Running run_model...")
|
| 269 |
with torch.no_grad():
|
| 270 |
-
plyfile,
|
| 271 |
|
| 272 |
end_time = time.time()
|
| 273 |
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 274 |
|
| 275 |
-
return plyfile,
|
| 276 |
|
| 277 |
def cleanup(request: gr.Request):
|
| 278 |
|
|
@@ -422,14 +441,16 @@ if __name__ == "__main__":
|
|
| 422 |
fn=update_gallery_on_video_upload,
|
| 423 |
inputs=[input_video, session_state],
|
| 424 |
outputs=[reconstruction_output, target_dir_output, image_gallery],
|
|
|
|
| 425 |
)
|
| 426 |
|
| 427 |
input_images.upload(
|
| 428 |
fn=update_gallery_on_images_upload,
|
| 429 |
inputs=[input_images, session_state],
|
| 430 |
outputs=[reconstruction_output, target_dir_output, image_gallery],
|
|
|
|
| 431 |
)
|
| 432 |
|
| 433 |
demo.unload(cleanup)
|
| 434 |
demo.queue()
|
| 435 |
-
demo.launch(show_error=True, share=True)
|
|
|
|
| 212 |
|
| 213 |
|
| 214 |
def update_gallery_on_video_upload(input_video, session_id):
|
|
|
|
| 215 |
if not input_video:
|
| 216 |
return None, None, None
|
| 217 |
|
|
|
|
| 228 |
|
| 229 |
@spaces.GPU()
|
| 230 |
def generate_splats_from_video(video_path, session_id=None):
|
| 231 |
+
"""
|
| 232 |
+
Perform Gaussian Splatting from Unconstrained Views a Given Video, using a Feed-forward model.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
video_path (str): Path to the input video file on disk.
|
| 236 |
+
Returns:
|
| 237 |
+
plyfile: Path to the reconstructed 3D object from the given video.
|
| 238 |
+
rgb_vid: Path the the interpolated rgb video, increasing the frame rate using guassian splatting and interpolation of frames.
|
| 239 |
+
depth_vid: Path the the interpolated depth video, increasing the frame rate using guassian splatting and interpolation of frames.
|
| 240 |
+
image_paths: A list of paths from extracted frame from the video that is used for training Gaussian Splatting.
|
| 241 |
+
"""
|
| 242 |
|
| 243 |
if session_id is None:
|
| 244 |
session_id = uuid.uuid4().hex
|
|
|
|
| 250 |
|
| 251 |
@spaces.GPU()
|
| 252 |
def generate_splats_from_images(image_paths, session_id=None):
|
| 253 |
+
"""
|
| 254 |
+
Perform Gaussian Splatting from Unconstrained Views a Given Images , using a Feed-forward model.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
image_paths (str): Path to the input image files on disk.
|
| 258 |
+
Returns:
|
| 259 |
+
plyfile: Path to the reconstructed 3D object from the given image files.
|
| 260 |
+
rgb_vid: Path the the interpolated rgb video, increasing the frame rate using guassian splatting and interpolation of frames.
|
| 261 |
+
depth_vid: Path the the interpolated depth video, increasing the frame rate using guassian splatting and interpolation of frames.
|
| 262 |
+
"""
|
| 263 |
processed_image_paths = []
|
| 264 |
|
| 265 |
for file_data in image_paths:
|
|
|
|
| 286 |
|
| 287 |
print("Running run_model...")
|
| 288 |
with torch.no_grad():
|
| 289 |
+
plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device)
|
| 290 |
|
| 291 |
end_time = time.time()
|
| 292 |
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 293 |
|
| 294 |
+
return plyfile, rgb_vid, depth_vid
|
| 295 |
|
| 296 |
def cleanup(request: gr.Request):
|
| 297 |
|
|
|
|
| 441 |
fn=update_gallery_on_video_upload,
|
| 442 |
inputs=[input_video, session_state],
|
| 443 |
outputs=[reconstruction_output, target_dir_output, image_gallery],
|
| 444 |
+
show_api=False
|
| 445 |
)
|
| 446 |
|
| 447 |
input_images.upload(
|
| 448 |
fn=update_gallery_on_images_upload,
|
| 449 |
inputs=[input_images, session_state],
|
| 450 |
outputs=[reconstruction_output, target_dir_output, image_gallery],
|
| 451 |
+
show_api=False
|
| 452 |
)
|
| 453 |
|
| 454 |
demo.unload(cleanup)
|
| 455 |
demo.queue()
|
| 456 |
+
demo.launch(show_error=True, share=True, mcp_server=True)
|