Fix UNet architecture: add norm='batch' parameter and use strict=False for state_dict loading
1338ad7
verified
| import gradio as gr | |
| import spaces | |
| import torch | |
| import nibabel as nib | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from monai.transforms import ( | |
| Compose, | |
| LoadImage, | |
| EnsureChannelFirst, | |
| ScaleIntensity, | |
| Resize, | |
| AsDiscrete, | |
| ) | |
| from monai.networks.nets import UNet | |
| import tempfile | |
| import os | |
# Process-wide cache for the segmentation network; stays None until
# load_model() has been called for the first time.
model = None
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model():
    """Return the spleen-segmentation UNet, building it on first use.

    On the first call the checkpoint ``model.pt`` is downloaded from the
    ``MONAI/example_spleen_segmentation`` repository, a 3D UNet is
    constructed, the weights are loaded, and the network is moved to
    ``device`` and switched to eval mode. The result is cached in the
    module-level ``model`` so later calls return immediately.

    The global cache is only populated after every step has succeeded, so a
    failed download or a corrupt checkpoint can never leave a half-initialised
    (randomly weighted) network behind for subsequent calls.

    Returns:
        The cached, ready-to-use network.

    Raises:
        Whatever the download, ``torch.load`` or ``load_state_dict`` raise;
        the cache is left untouched in that case.
    """
    global model
    if model is not None:
        return model

    # Local import: keeps the file's top-level import block unchanged.
    import warnings

    # Download (or reuse the locally cached copy of) the checkpoint.
    model_path = hf_hub_download(
        repo_id="MONAI/example_spleen_segmentation",
        filename="model.pt",
    )

    # Architecture hyper-parameters; per the original author's note these
    # mirror the bundle's inference.json (including batch normalisation).
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm="batch",
    )

    # NOTE(security): torch.load unpickles the file. The checkpoint comes from
    # a fixed upstream repository, but consider weights_only=True if the
    # installed PyTorch version and the checkpoint format allow it.
    checkpoint = torch.load(model_path, map_location=device)

    # strict=False tolerates key mismatches, but silently skipping weights
    # would yield garbage predictions -- surface any mismatch to the operator.
    load_result = net.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint does not fully match the UNet architecture "
            f"(missing keys: {list(load_result.missing_keys)}, "
            f"unexpected keys: {list(load_result.unexpected_keys)})",
            RuntimeWarning,
            stacklevel=2,
        )

    net.to(device)
    net.eval()

    # Publish only the fully initialised network.
    model = net
    return model
def _min_max_normalize(values):
    """Linearly rescale *values* (NumPy array or torch tensor) to [0, 1].

    If every element is identical the value range is zero; in that case the
    shifted (all-zero) data is returned instead of dividing by zero, which
    would fill the result with NaNs.
    """
    lowest = values.min()
    shifted = values - lowest
    span = values.max() - lowest
    return shifted / span if span > 0 else shifted


def segment_spleen(input_file):
    """Segment the spleen in a CT volume stored as a NIfTI file.

    Args:
        input_file: Path to a ``.nii`` / ``.nii.gz`` file, or an upload object
            that exposes that path as ``.name`` (Gradio passes one or the
            other depending on its version/configuration). ``None`` is
            reported as an error.

    Returns:
        A 3-tuple ``(overlay, output_path, status)``. On success ``overlay``
        is a uint8 RGB image of the middle slice along the third axis with the
        predicted spleen voxels painted red, ``output_path`` is the path of a
        temporary ``.nii.gz`` file holding the label volume, and ``status`` is
        a success message. On any failure the first two items are ``None`` and
        ``status`` starts with ``"Error: "``.
    """
    if input_file is None:
        # Nothing was uploaded: fail fast without downloading/loading the model.
        return None, None, "Error: no input file provided"

    try:
        net = load_model()

        # Accept both a plain path and an object carrying the path in `.name`.
        nifti_path = getattr(input_file, "name", input_file)
        img = nib.load(nifti_path)
        img_data = img.get_fdata()
        if img_data.ndim != 3:
            raise ValueError(
                f"expected a 3D volume, got data with shape {img_data.shape}"
            )

        # Shape (1, 1, X, Y, Z): add batch and channel axes for the network.
        img_tensor = torch.from_numpy(img_data).float().unsqueeze(0).unsqueeze(0)

        # NOTE(review): min-max scaling and resizing the whole volume to
        # 96x96x96 may differ from the preprocessing the checkpoint was
        # trained with -- confirm against the model bundle's configuration.
        img_tensor = _min_max_normalize(img_tensor)
        img_resized = torch.nn.functional.interpolate(
            img_tensor,
            size=(96, 96, 96),
            mode="trilinear",
            align_corners=True,
        )

        img_resized = img_resized.to(device)
        with torch.no_grad():
            output = net(img_resized)
        # Channel-wise argmax -> integer label map of shape (1, 96, 96, 96).
        pred = torch.argmax(output, dim=1)

        # Nearest-neighbour keeps labels discrete when scaling back up.
        pred_resized = torch.nn.functional.interpolate(
            pred.float().unsqueeze(0),
            size=img_data.shape,
            mode="nearest",
        )
        # Index the batch/channel axes explicitly: squeeze() would also drop
        # genuine size-1 spatial axes of the input volume.
        pred_np = pred_resized[0, 0].cpu().numpy().astype(np.uint8)

        # Save the label map with the source image's affine and header.
        seg_img = nib.Nifti1Image(pred_np, img.affine, img.header)
        # Reserve the output name atomically (tempfile.mktemp is deprecated
        # and race-prone); delete=False keeps the file for download.
        with tempfile.NamedTemporaryFile(
            suffix="_segmentation.nii.gz", delete=False
        ) as tmp_file:
            output_path = tmp_file.name
        nib.save(seg_img, output_path)

        # Visualise the middle slice along the third axis.
        mid_slice = img_data.shape[2] // 2
        img_slice = _min_max_normalize(img_data[:, :, mid_slice]) * 255
        seg_slice = pred_np[:, :, mid_slice]

        # Grey-scale background replicated to RGB, spleen voxels in red.
        overlay = np.stack([img_slice, img_slice, img_slice], axis=-1).astype(np.uint8)
        overlay[seg_slice == 1] = [255, 0, 0]

        return overlay, output_path, "Segmentation completed successfully!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"Error: {str(e)}"
# Static Markdown shown above the controls.
_INTRO_MARKDOWN = """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the
[MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.
**Model Info:**
- Architecture: UNet
- Input: 3D CT image (96×96×96)
- Output: Binary segmentation (spleen vs background)
- Mean Dice Score: 0.96
**Instructions:**
1. Upload a NIfTI file (.nii or .nii.gz)
2. Click Submit
3. View the segmentation overlay and download the result
"""

# Static Markdown shown below the controls.
_FOOTER_MARKDOWN = """### Requirements
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub
### Citation
If you use this model, please cite:
```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training."
arXiv preprint arXiv:1811.12506 (2018).
```
"""

# Build the page: heading and intro, a two-column row (upload + button on the
# left, results on the right), the click wiring, then the footer text.
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: input controls.
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"],
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")

        # Right column: the three outputs returned by segment_spleen.
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Overlay (Middle Slice)",
                type="numpy",
            )
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    # Clicking the button runs the segmentation and fills all three outputs.
    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text],
    )

    gr.Markdown(_FOOTER_MARKDOWN)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()