"""Gradio Space: spleen segmentation of CT NIfTI volumes with a MONAI UNet."""

import logging
import os
import tempfile

import gradio as gr
import nibabel as nib
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureChannelFirst,
    LoadImage,
    Resize,
    ScaleIntensity,
)

logger = logging.getLogger(__name__)

# Pre-processing / inference settings of the MONAI spleen_ct_segmentation bundle
# (ScaleIntensityRanged, Spacingd, SlidingWindowInferer in its inference.json).
# NOTE(review): values taken from the published bundle config — confirm they
# match the checkpoint downloaded below.
HU_WINDOW = (-57.0, 164.0)  # CT intensity range mapped linearly to [0, 1], clipped
TARGET_SPACING = (1.5, 1.5, 2.0)  # voxel size (mm) the volume is resampled to
ROI_SIZE = (96, 96, 96)  # sliding-window patch size fed to the network
SW_BATCH_SIZE = 4  # patches evaluated per forward pass
SW_OVERLAP = 0.5  # fractional overlap between neighbouring patches

# Lazily-initialised model singleton (populated by load_model) and its device.
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    """Return the spleen UNet, downloading and initialising it on first call.

    The fully initialised network is cached in the module-level ``model``.
    The global is only assigned after the weights have loaded, so a failed
    load is retried on the next call instead of leaving a randomly
    initialised network cached.
    """
    global model
    if model is None:
        # Download model from HuggingFace
        model_path = hf_hub_download(
            repo_id="MONAI/example_spleen_segmentation", filename="model.pt"
        )
        # UNet hyper-parameters as given in the bundle's inference.json.
        net = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="batch",
        )
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted repositories (consider weights_only=True on newer torch).
        checkpoint = torch.load(model_path, map_location=device)
        # strict=False tolerates minor key mismatches, but a missing weight
        # would leave that layer randomly initialised, so report any mismatch.
        incompatible = net.load_state_dict(checkpoint, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.warning(
                "Checkpoint/model key mismatch: missing=%s unexpected=%s",
                incompatible.missing_keys,
                incompatible.unexpected_keys,
            )
        net.to(device)
        net.eval()
        model = net  # publish only once fully initialised
    return model


def _normalize_for_display(slice_2d):
    """Min-max scale a 2-D array to uint8 [0, 255]; a constant slice maps to 0."""
    lo, hi = float(slice_2d.min()), float(slice_2d.max())
    if hi <= lo:  # avoid 0/0 -> NaN on constant slices
        return np.zeros(slice_2d.shape, dtype=np.uint8)
    return ((slice_2d - lo) / (hi - lo) * 255).astype(np.uint8)


def _preprocess(volume, zooms):
    """Window, scale and resample a 3-D CT array into a (1, 1, X', Y', Z') tensor.

    Intensities are clipped to ``HU_WINDOW`` and mapped to [0, 1]. The volume
    is then resampled (trilinear) from the header voxel sizes ``zooms`` to
    ``TARGET_SPACING``. A non-positive zoom is treated as already matching the
    target on that axis.
    NOTE(review): unlike the bundle, orientation is not normalised to RAS.
    """
    lo, hi = HU_WINDOW
    x = torch.from_numpy(np.ascontiguousarray(volume, dtype=np.float32))
    x = ((x - lo) / (hi - lo)).clamp_(0.0, 1.0)[None, None]
    target_shape = tuple(
        max(1, int(round(n * (z if z > 0 else t) / t)))
        for n, z, t in zip(volume.shape, zooms, TARGET_SPACING)
    )
    if target_shape != tuple(volume.shape):
        x = torch.nn.functional.interpolate(
            x, size=target_shape, mode="trilinear", align_corners=False
        )
    return x


@spaces.GPU  # requests a GPU on ZeroGPU Spaces; no-op elsewhere
def segment_spleen(input_file):
    """Segment the spleen in an uploaded CT NIfTI volume.

    Args:
        input_file: Value of the ``gr.File`` component. This is either a file
            path, or (older Gradio) an object exposing ``.name``. It is
            ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(overlay, output_path, status)``:

        * ``overlay``: an RGB uint8 image of the middle slice along the third
          voxel axis, with spleen voxels painted red.
        * ``output_path``: a temporary ``.nii.gz`` file holding the uint8
          label map on the input grid and affine.
        * ``status``: a message for the UI.

        On failure, returns ``(None, None, "Error: ...")``.
    """
    if input_file is None:
        return None, None, "Error: please upload a NIfTI file (.nii or .nii.gz)."
    # Gradio 4 passes a path string; Gradio 3 passes a tempfile wrapper.
    path = getattr(input_file, "name", input_file)
    try:
        net = load_model()

        img = nib.load(path)
        img_data = img.get_fdata()
        # Accept (X, Y, Z, 1) volumes by dropping the singleton 4th axis.
        if img_data.ndim == 4 and img_data.shape[3] == 1:
            img_data = img_data[..., 0]
        if img_data.ndim != 3:
            raise ValueError(f"expected a 3-D volume, got shape {img_data.shape}")
        zooms = tuple(float(z) for z in img.header.get_zooms()[:3])

        x = _preprocess(img_data, zooms).to(device)
        with torch.no_grad():
            # Patch-wise inference at the training resolution instead of
            # squashing the whole volume into a single ROI-sized cube.
            logits = sliding_window_inference(
                x, ROI_SIZE, SW_BATCH_SIZE, net, overlap=SW_OVERLAP
            )
            pred = torch.argmax(logits, dim=1, keepdim=True).float()
            # Nearest-neighbour keeps labels binary when mapping back.
            pred = torch.nn.functional.interpolate(
                pred, size=img_data.shape, mode="nearest"
            )
        pred_np = pred[0, 0].cpu().numpy().astype(np.uint8)

        # Save segmentation as NIfTI with a uint8 on-disk dtype (copy the
        # header so the input image's header is not mutated).
        header = img.header.copy()
        header.set_data_dtype(np.uint8)
        seg_img = nib.Nifti1Image(pred_np, img.affine, header)
        # mkstemp creates the file atomically (mktemp is race-prone).
        fd, output_path = tempfile.mkstemp(suffix="_segmentation.nii.gz")
        os.close(fd)
        nib.save(seg_img, output_path)

        # Visualisation: middle slice with the spleen painted red.
        mid_slice = img_data.shape[2] // 2
        gray = _normalize_for_display(img_data[:, :, mid_slice])
        overlay = np.stack([gray, gray, gray], axis=-1)
        overlay[pred_np[:, :, mid_slice] == 1] = [255, 0, 0]  # Red for spleen

        return overlay, output_path, "Segmentation completed successfully!"

    except Exception as e:
        # UI boundary: log the traceback, surface the message in the status box.
        logger.exception("Spleen segmentation failed")
        return None, None, f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the [MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

**Model Info:**
- Architecture: UNet
- Input: 3D CT volume, resampled to 1.5×1.5×2.0 mm and processed in 96×96×96 sliding-window patches
- Output: Binary segmentation (spleen vs background)
- Mean Dice Score: 0.96

**Instructions:**
1. Upload a NIfTI file (.nii or .nii.gz)
2. Click **Segment Spleen**
3. View the segmentation overlay and download the result
"""
    )

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)", file_types=[".nii", ".nii.gz"]
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Overlay (Middle Slice)", type="numpy"
            )
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text],
    )

    gr.Markdown(
        """### Requirements
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub

### Citation
If you use this model, please cite:
```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training." arXiv preprint arXiv:1811.12506 (2018).
```
"""
    )


if __name__ == "__main__":
    demo.launch()