| import gradio as gr | |
| import spaces | |
| import os | |
| import shutil | |
| os.environ['SPCONV_ALGO'] = 'native' | |
| from huggingface_hub import hf_hub_download | |
| from app_utils import ( | |
| generate_parts, | |
| prepare_models, | |
| process_image, | |
| apply_merge, | |
| DEFAULT_SIZE_TH, | |
| TMP_ROOT, | |
| ) | |
# Bundled demo inputs for gr.Examples. Each row supplies, in order, the
# values for the Examples ``inputs`` components: image path, minimum segment
# size in pixels, merge-groups string ("a,b;c,d"), and generation seed.
EXAMPLES = [
    [f"assets/example_data/{image_name}", min_size, merge_groups, seed]
    for image_name, min_size, merge_groups, seed in (
        ("knight.png", 1800, "6,0,26,20,7;13,1,22,11,12,2,21,27,3,24,23;5,18;4,17;19,16,14,25,28", 42),
        ("car.png", 2000, "12,10,2,11;1,7", 42),
        ("warhammer.png", 1800, "7,1,0,8", 0),
        ("snake.png", 3000, "2,3;0,1;4,5,6,7", 42),
        ("Batman.png", 1800, "4,5", 42),
        ("robot1.jpeg", 1600, "0,5;10,14,3;1,12,2;13,11,4;7,15", 42),
        ("astronaut.png", 2000, "0,4,6;1,8,9,7;2,5", 42),
        ("crossbow.jpg", 2000, "2,9;10,12,0,7,11,8,13;4,3", 42),
        ("robot.jpg", 1600, "7,19;15,0;6,18", 42),
        ("robot_dog.jpg", 1000, "21,9;2,12,10,15,17;11,7;1,0;13,19;4,16", 0),
        ("crossbow.jpg", 1600, "9,2;10,15,13;7,14,8,11;0,12,16;5,3,1", 42),
        ("robot.jpg", 1800, "1,2,3,5,4,16,17;11,7,19;10,14;18,6,0,15;13,9;12,8", 0),
        # NOTE(review): segment 11 appears in two groups here ("11,7" and
        # "1,0,21,9,11") -- confirm this is intended.
        ("robot_dog.jpg", 1000, "2,12,10,15,17,8,3,5,13,19,6,14;11,7;1,0,21,9,11;4,16", 0),
    )
]
# Markdown rendered at the top of the demo page. Blank lines separate the
# blocks: without them Markdown folds consecutive lines into one paragraph,
# merging the "Quick Start" and "Custom Image Processing" entries.
HEADER = """
# OmniPart: Part-Aware 3D Generation with Semantic Decoupling and Structural Cohesion

🔮 Generate **part-aware 3D content** from a single 2D image with **2D mask control**.

## How to Use

**🚀 Quick Start**: Select an example below and click **"▶️ Run Example"**

**📋 Custom Image Processing**:

1. **Upload Image** - Select your image file
2. **Click "Segment Image"** - Get initial 2D segmentation
3. **Merge Segments** - Enter merge groups like `0,1;3,4` and click **"Apply Merge"** (we recommend keeping **2-15 parts**)
4. **Click "Generate 3D Model"** - Create the final 3D results
"""
def start_session(req: gr.Request):
    """Create this browser session's scratch directory under ``TMP_ROOT``.

    The directory is named after the Gradio session hash. Creating it is a
    no-op when it already exists.
    """
    os.makedirs(os.path.join(TMP_ROOT, str(req.session_hash)), exist_ok=True)
def end_session(req: gr.Request):
    """Delete this browser session's scratch directory under ``TMP_ROOT``.

    Runs from the page-unload hook. A directory that no longer exists (or
    never got created for this session) is treated as already cleaned up, so
    the handler does not raise ``FileNotFoundError``. Any other error, such as
    a permission failure, still propagates.
    """
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing left to remove; cleanup is effectively done.
        pass
# UI layout: a left input/control column and a right results column, plus
# the event wiring between the buttons and the app_utils pipeline steps.
with gr.Blocks(title="OmniPart") as demo:
    gr.Markdown(HEADER)
    # Per-session pipeline state, passed between process_image, apply_merge,
    # and generate_parts.
    state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<div style='text-align: center'>\n\n## Input\n\n</div>")
            input_image = gr.Image(label="Upload Image", type="filepath", height=250, width=250)
            with gr.Row():
                segment_btn = gr.Button("Segment Image", variant="primary", size="lg")
                run_example_btn = gr.Button("▶️ Run Example", variant="secondary", size="lg")
            size_threshold = gr.Slider(
                minimum=600,
                maximum=4000,
                value=DEFAULT_SIZE_TH,
                step=200,
                label="Minimum Segment Size (pixels)",
                info="Segments smaller than this will be ignored"
            )

            gr.Markdown("### Merge Controls")
            merge_input = gr.Textbox(
                label="Merge Groups",
                placeholder="0,1;3,4",
                lines=2,
                info="Specify which segments to merge (e.g., '0,1;3,4' merges segments 0&1 together and 3&4 together)"
            )
            merge_btn = gr.Button("Apply Merge", variant="primary", size="lg")

            gr.Markdown("### 3D Generation Controls")
            seed_slider = gr.Slider(
                minimum=0,
                maximum=10000,
                value=42,
                step=1,
                label="Generation Seed",
                info="Random seed for 3D model generation"
            )
            cfg_slider = gr.Slider(
                minimum=0.0,
                maximum=15.0,
                value=7.5,
                step=0.5,
                label="CFG Strength",
                info="Classifier-Free Guidance strength"
            )
            generate_mesh_btn = gr.Button("Generate 3D Model", variant="secondary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("<div style='text-align: center'>\n\n## Results Display\n\n</div>")
            with gr.Row():
                initial_seg = gr.Image(label="Init Seg", height=220, width=220)
                pre_merge_vis = gr.Image(label="Pre-merge", height=220, width=220)
                merged_seg = gr.Image(label="Merged Seg", height=220, width=220)
            with gr.Row():
                bbox_mesh = gr.Model3D(label="Bounding Boxes", height=350)
                whole_mesh = gr.Model3D(label="Combined Parts", height=350)
                exploded_mesh = gr.Model3D(label="Exploded Parts", height=350)
            with gr.Row():
                combined_gs = gr.Model3D(label="Combined 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
                exploded_gs = gr.Model3D(label="Exploded 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
            with gr.Row():
                # Clicking an example only fills the inputs; "Run Example"
                # executes the pipeline.
                examples = gr.Examples(
                    examples=EXAMPLES,
                    inputs=[input_image, size_threshold, merge_input, seed_slider],
                    cache_examples=False,
                )

    # Create the per-session scratch directory on page load and remove it on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # Individual pipeline steps, each driven by its own button.
    segment_btn.click(
        process_image,
        inputs=[input_image, size_threshold],
        outputs=[initial_seg, pre_merge_vis, state]
    )
    merge_btn.click(
        apply_merge,
        inputs=[merge_input, state],
        outputs=[merged_seg, state]
    )
    generate_mesh_btn.click(
        generate_parts,
        inputs=[state, seed_slider, cfg_slider],
        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
    )

    # Full pipeline: segment -> merge -> generate. Use .success() rather than
    # .then() so a failing step (e.g. no image uploaded) stops the chain
    # instead of running apply_merge/generate_parts on stale or empty state.
    run_example_btn.click(
        fn=process_image,
        inputs=[input_image, size_threshold],
        outputs=[initial_seg, pre_merge_vis, state]
    ).success(
        fn=apply_merge,
        inputs=[merge_input, state],
        outputs=[merged_seg, state]
    ).success(
        fn=generate_parts,
        inputs=[state, seed_slider, cfg_slider],
        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
    )
if __name__ == "__main__":
    # Download the three auxiliary checkpoints from the Hub into ./ckpt, in the
    # order prepare_models expects them, then start the app.
    os.makedirs("ckpt", exist_ok=True)
    sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path = (
        hf_hub_download(repo_id="omnipart/OmniPart_modules", filename=ckpt_name, local_dir="ckpt")
        for ckpt_name in ("sam_vit_h_4b8939.pth", "partfield_encoder.ckpt", "bbox_gen.ckpt")
    )
    prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path)
    demo.launch()