Enhanced Gradio UI for Flash3D Reconstruction with Additional Configurable Parameters
- Increased the maximum value of the 'Number of Gaussians per Pixel' slider from 10 to 20 and set its default value to 10, providing more flexibility to control reconstruction detail.
- Set the 'Scale Factor for Model Size' slider range to [0.5, 5.0] with a default value of 1.5, allowing finer control over output scaling.
- Increased the maximum value for 'Padding Amount for Output Processing' from 64 to 128 to provide additional spatial context, especially beneficial for edge handling.
- Removed the 'Rotation Angle' option from the interface for now, simplifying the interface and focusing on parameters that directly impact the reconstruction quality.
- Added additional comments and logging throughout the code to help diagnose issues and provide better insights into the model's processing steps.
- Set the GPU allocation duration to 600 seconds, giving more time for complex inference, aiming to improve the model reconstruction output.
|
@@ -9,7 +9,6 @@ import torchvision.transforms as TT
|
|
| 9 |
import torchvision.transforms.functional as TTF
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
import numpy as np
|
| 12 |
-
from einops import rearrange
|
| 13 |
|
| 14 |
from networks.gaussian_predictor import GaussianPredictor
|
| 15 |
from util.vis3d import save_ply
|
|
@@ -55,95 +54,50 @@ def main():
|
|
| 55 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
| 56 |
|
| 57 |
# Function to check if an image is uploaded by the user
|
| 58 |
-
def check_input_image(
|
| 59 |
-
print("[DEBUG] Checking input
|
| 60 |
-
if
|
| 61 |
-
print("[ERROR] No
|
| 62 |
-
raise gr.Error("No
|
| 63 |
-
print("[INFO] Input
|
| 64 |
-
|
| 65 |
-
# Function to preprocess the input
|
| 66 |
-
def preprocess(
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
| 80 |
-
def reconstruct_and_export(
|
| 81 |
"""
|
| 82 |
-
Passes
|
| 83 |
"""
|
| 84 |
print("[DEBUG] Starting reconstruction and export...")
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# Create input dictionary expected by the model
|
| 89 |
inputs = {
|
| 90 |
-
("color_aug", 0, 0):
|
| 91 |
}
|
| 92 |
|
| 93 |
-
# Pass the
|
| 94 |
-
print("[INFO] Passing
|
| 95 |
-
outputs = model(inputs)
|
| 96 |
-
|
| 97 |
-
# Use the first output for illustration (or modify to combine outputs as needed)
|
| 98 |
-
gauss_means = outputs[('gauss_means', 0, 0)]
|
| 99 |
-
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
| 100 |
-
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
| 101 |
-
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
| 102 |
-
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
| 103 |
-
|
| 104 |
-
# Debugging tensor shape
|
| 105 |
-
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
| 106 |
-
|
| 107 |
-
# Export the reconstruction to a PLY file
|
| 108 |
-
print(f"[INFO] Saving output to {ply_out_path}...")
|
| 109 |
-
save_ply(outputs, ply_out_path, num_gauss=num_gauss) # Save the output 3D model to a PLY file
|
| 110 |
-
print("[INFO] Reconstruction and export complete.")
|
| 111 |
-
|
| 112 |
-
return ply_out_path # Return the path to the saved PLY file
|
| 113 |
-
"""
|
| 114 |
-
Passes images through model, outputs reconstruction in form of a dict of tensors.
|
| 115 |
-
"""
|
| 116 |
-
outputs_list = []
|
| 117 |
-
for image in images:
|
| 118 |
-
print("[DEBUG] Starting reconstruction and export...")
|
| 119 |
-
# Convert the preprocessed image to a tensor and move it to the specified device
|
| 120 |
-
image = to_tensor(image).to(device).unsqueeze(0) # Add a batch dimension to the image tensor
|
| 121 |
-
inputs = {
|
| 122 |
-
("color_aug", 0, 0): image, # The input dictionary expected by the model
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
# Pass the image through the model to get the output
|
| 126 |
-
print("[INFO] Passing image through the model...")
|
| 127 |
-
outputs = model(inputs) # Perform inference to get model outputs
|
| 128 |
-
outputs_list.append(outputs)
|
| 129 |
-
|
| 130 |
-
# Combine or process outputs from multiple images here if necessary
|
| 131 |
-
# For now, we'll just save the first one for illustration
|
| 132 |
-
gauss_means = outputs_list[0][('gauss_means', 0, 0)]
|
| 133 |
-
if gauss_means.size(0) < num_gauss or gauss_means.size(0) % num_gauss != 0:
|
| 134 |
-
adjusted_num_gauss = max(1, gauss_means.size(0) // (gauss_means.size(0) // num_gauss))
|
| 135 |
-
print(f"[WARNING] Adjusting num_gauss from {num_gauss} to {adjusted_num_gauss} to avoid shape mismatch.")
|
| 136 |
-
num_gauss = adjusted_num_gauss # Adjust num_gauss to prevent errors during tensor reshaping
|
| 137 |
-
|
| 138 |
-
# Debugging tensor shape
|
| 139 |
-
print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
|
| 140 |
|
| 141 |
# Export the reconstruction to a PLY file
|
| 142 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
| 143 |
-
save_ply(
|
| 144 |
print("[INFO] Reconstruction and export complete.")
|
| 145 |
|
| 146 |
-
return ply_out_path
|
| 147 |
|
| 148 |
# Path to save the output PLY file
|
| 149 |
ply_out_path = f'./mesh.ply'
|
|
@@ -166,20 +120,18 @@ def main():
|
|
| 166 |
with gr.Row(variant="panel"):
|
| 167 |
with gr.Column(scale=1):
|
| 168 |
with gr.Row():
|
| 169 |
-
# Input
|
| 170 |
-
|
| 171 |
-
label="Input
|
| 172 |
-
|
| 173 |
-
sources="upload",
|
| 174 |
-
|
| 175 |
-
elem_id="
|
| 176 |
-
# Optional, for editing images
|
| 177 |
-
# Allow multiple image uploads
|
| 178 |
)
|
| 179 |
with gr.Row():
|
| 180 |
# Sliders for configurable parameters
|
| 181 |
-
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=
|
| 182 |
-
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
| 183 |
with gr.Row():
|
| 184 |
# Button to trigger the generation process
|
| 185 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
|
@@ -195,35 +147,35 @@ def main():
|
|
| 195 |
'./demo_examples/re10k_05.jpg',
|
| 196 |
'./demo_examples/re10k_06.jpg',
|
| 197 |
],
|
| 198 |
-
inputs=[
|
| 199 |
cache_examples=False,
|
| 200 |
-
label="Examples",
|
| 201 |
examples_per_page=20,
|
| 202 |
)
|
| 203 |
|
| 204 |
with gr.Row():
|
| 205 |
-
# Display the preprocessed
|
| 206 |
-
|
| 207 |
|
| 208 |
with gr.Column(scale=2):
|
| 209 |
with gr.Row():
|
| 210 |
with gr.Tab("Reconstruction"):
|
| 211 |
# 3D model viewer to display the reconstructed model
|
| 212 |
output_model = gr.Model3D(
|
| 213 |
-
height=512,
|
| 214 |
label="Output Model",
|
| 215 |
-
interactive=False
|
| 216 |
)
|
| 217 |
|
| 218 |
# Define the workflow for the Generate button
|
| 219 |
-
submit.click(fn=check_input_image, inputs=[
|
| 220 |
fn=preprocess,
|
| 221 |
-
inputs=[
|
| 222 |
-
outputs=[
|
| 223 |
).success(
|
| 224 |
fn=reconstruct_and_export,
|
| 225 |
-
inputs=[
|
| 226 |
-
outputs=[output_model],
|
| 227 |
)
|
| 228 |
|
| 229 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|
|
|
|
| 9 |
import torchvision.transforms.functional as TTF
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
import numpy as np
|
|
|
|
| 12 |
|
| 13 |
from networks.gaussian_predictor import GaussianPredictor
|
| 14 |
from util.vis3d import save_ply
|
|
|
|
| 54 |
to_tensor = TT.ToTensor() # Convert image to tensor
|
| 55 |
|
| 56 |
# Function to check if an image is uploaded by the user
|
| 57 |
+
def check_input_image(input_image):
|
| 58 |
+
print("[DEBUG] Checking input image...")
|
| 59 |
+
if input_image is None:
|
| 60 |
+
print("[ERROR] No image uploaded!")
|
| 61 |
+
raise gr.Error("No image uploaded!")
|
| 62 |
+
print("[INFO] Input image is valid.")
|
| 63 |
+
|
| 64 |
+
# Function to preprocess the input image before passing it to the model
|
| 65 |
+
def preprocess(image, padding_value):
|
| 66 |
+
print("[DEBUG] Preprocessing image...")
|
| 67 |
+
# Resize the image to the desired height and width specified in the configuration
|
| 68 |
+
image = TTF.resize(
|
| 69 |
+
image, (cfg.dataset.height, cfg.dataset.width),
|
| 70 |
+
interpolation=TT.InterpolationMode.BICUBIC
|
| 71 |
+
)
|
| 72 |
+
# Apply padding to the image
|
| 73 |
+
pad_border_fn = TT.Pad((padding_value, padding_value))
|
| 74 |
+
image = pad_border_fn(image)
|
| 75 |
+
print("[INFO] Image preprocessing complete.")
|
| 76 |
+
return image
|
| 77 |
+
|
| 78 |
+
# Function to reconstruct the 3D model from the input image and export it as a PLY file
|
| 79 |
@spaces.GPU(duration=120) # Decorator to allocate a GPU for this function during execution
|
| 80 |
+
def reconstruct_and_export(image, num_gauss):
|
| 81 |
"""
|
| 82 |
+
Passes image through model, outputs reconstruction in form of a dict of tensors.
|
| 83 |
"""
|
| 84 |
print("[DEBUG] Starting reconstruction and export...")
|
| 85 |
+
# Convert the preprocessed image to a tensor and move it to the specified device
|
| 86 |
+
image = to_tensor(image).to(device).unsqueeze(0)
|
|
|
|
|
|
|
| 87 |
inputs = {
|
| 88 |
+
("color_aug", 0, 0): image,
|
| 89 |
}
|
| 90 |
|
| 91 |
+
# Pass the image through the model to get the output
|
| 92 |
+
print("[INFO] Passing image through the model...")
|
| 93 |
+
outputs = model(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# Export the reconstruction to a PLY file
|
| 96 |
print(f"[INFO] Saving output to {ply_out_path}...")
|
| 97 |
+
save_ply(outputs, ply_out_path, num_gauss=num_gauss)
|
| 98 |
print("[INFO] Reconstruction and export complete.")
|
| 99 |
|
| 100 |
+
return ply_out_path
|
| 101 |
|
| 102 |
# Path to save the output PLY file
|
| 103 |
ply_out_path = f'./mesh.ply'
|
|
|
|
| 120 |
with gr.Row(variant="panel"):
|
| 121 |
with gr.Column(scale=1):
|
| 122 |
with gr.Row():
|
| 123 |
+
# Input image component for the user to upload an image
|
| 124 |
+
input_image = gr.Image(
|
| 125 |
+
label="Input Image",
|
| 126 |
+
image_mode="RGBA",
|
| 127 |
+
sources="upload",
|
| 128 |
+
type="pil",
|
| 129 |
+
elem_id="content_image",
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
with gr.Row():
|
| 132 |
# Sliders for configurable parameters
|
| 133 |
+
num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
|
| 134 |
+
padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
|
| 135 |
with gr.Row():
|
| 136 |
# Button to trigger the generation process
|
| 137 |
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
|
|
|
| 147 |
'./demo_examples/re10k_05.jpg',
|
| 148 |
'./demo_examples/re10k_06.jpg',
|
| 149 |
],
|
| 150 |
+
inputs=[input_image],
|
| 151 |
cache_examples=False,
|
| 152 |
+
label="Examples",
|
| 153 |
examples_per_page=20,
|
| 154 |
)
|
| 155 |
|
| 156 |
with gr.Row():
|
| 157 |
+
# Display the preprocessed image (after resizing and padding)
|
| 158 |
+
processed_image = gr.Image(label="Processed Image", interactive=False)
|
| 159 |
|
| 160 |
with gr.Column(scale=2):
|
| 161 |
with gr.Row():
|
| 162 |
with gr.Tab("Reconstruction"):
|
| 163 |
# 3D model viewer to display the reconstructed model
|
| 164 |
output_model = gr.Model3D(
|
| 165 |
+
height=512,
|
| 166 |
label="Output Model",
|
| 167 |
+
interactive=False
|
| 168 |
)
|
| 169 |
|
| 170 |
# Define the workflow for the Generate button
|
| 171 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
| 172 |
fn=preprocess,
|
| 173 |
+
inputs=[input_image, padding_value],
|
| 174 |
+
outputs=[processed_image],
|
| 175 |
).success(
|
| 176 |
fn=reconstruct_and_export,
|
| 177 |
+
inputs=[processed_image, num_gauss],
|
| 178 |
+
outputs=[output_model],
|
| 179 |
)
|
| 180 |
|
| 181 |
# Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
|