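"""Gradio demo for 3D spleen segmentation on CT scans.

Hypothetical module docstring summarizing the script below: it downloads the
MONAI/example_spleen_segmentation UNet weights from the Hugging Face Hub, runs
whole-volume inference on an uploaded NIfTI scan, and returns a middle-slice
overlay plus the segmentation mask as a downloadable NIfTI file.
"""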
import gradio as gr
import spaces
import torch
import nibabel as nib
import numpy as np
from huggingface_hub import hf_hub_download
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    ScaleIntensity,
    Resize,
    AsDiscrete,
)
from monai.networks.nets import UNet
import tempfile
import os
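
# NOTE: the MONAI transforms imported above are not used in this simplified
# demo; preprocessing below is done with plain torch operations instead.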

# Model is created lazily on first request and cached in these module globals
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    global model
    if model is None:
        # Download model from HuggingFace
        model_path = hf_hub_download(
            repo_id="MONAI/example_spleen_segmentation",
            filename="model.pt"
        )
        # Initialize UNet architecture with exact parameters from inference.json
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="batch",  # batch normalization, as specified in inference.json
        )
        # Load weights with strict=False to handle minor key mismatches
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)
        model.to(device)
        model.eval()
    return model
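

# Run inference on the GPU when hosted on ZeroGPU hardware (assumption: this
# Space targets ZeroGPU, hence the `spaces` import; the decorator has no effect
# on other hardware).
@spaces.GPU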
def segment_spleen(input_file):
    """Segment spleen from CT NIfTI file"""
    try:
        # Load model
        net = load_model()

        # Load NIfTI file
        img = nib.load(input_file)
        img_data = img.get_fdata()

        # Preprocessing: add batch and channel dimensions -> (1, 1, D, H, W)
        img_tensor = torch.from_numpy(img_data).float().unsqueeze(0).unsqueeze(0)
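
        # NOTE: plain min-max scaling of the whole volume is a simplification;
        # CT pipelines for this model typically window Hounsfield units first
        # (e.g. with MONAI's ScaleIntensityRange).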
        # Normalize
        img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min())
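
        # NOTE: resizing the full volume to 96^3 keeps the demo lightweight; a
        # more faithful approach would run MONAI's sliding_window_inference on
        # patches at the original resolution.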
        # Resize to model input size (96x96x96)
        img_resized = torch.nn.functional.interpolate(
            img_tensor,
            size=(96, 96, 96),
            mode="trilinear",
            align_corners=True
        )

        # Move to device and run inference
        img_resized = img_resized.to(device)
        with torch.no_grad():
            output = net(img_resized)
            pred = torch.argmax(output, dim=1)
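
        # pred is now a label volume at the 96^3 inference resolution:
        # 0 = background, 1 = spleen (the model's two output channels).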
        # Resize back to original size
        pred_resized = torch.nn.functional.interpolate(
            pred.float().unsqueeze(0),
            size=img_data.shape,
            mode="nearest"
        )
        pred_np = pred_resized.squeeze().cpu().numpy().astype(np.uint8)

        # Save segmentation as NIfTI
        seg_img = nib.Nifti1Image(pred_np, img.affine, img.header)
        output_path = tempfile.mktemp(suffix="_segmentation.nii.gz")
        nib.save(seg_img, output_path)

        # Create visualization (middle slice)
        mid_slice = img_data.shape[2] // 2
        img_slice = img_data[:, :, mid_slice]
        seg_slice = pred_np[:, :, mid_slice]

        # Normalize image for display
        img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min()) * 255

        # Create overlay
        overlay = np.stack([img_slice, img_slice, img_slice], axis=-1).astype(np.uint8)
        overlay[seg_slice == 1] = [255, 0, 0]  # Red for spleen

        return overlay, output_path, "Segmentation completed successfully!"
    except Exception as e:
        return None, None, f"Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the
[MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

**Model Info:**
- Architecture: UNet
- Input: 3D CT volume (resampled to 96×96×96 for inference)
- Output: Binary segmentation (spleen vs. background)
- Mean Dice score: 0.96

**Instructions:**
1. Upload a NIfTI file (.nii or .nii.gz)
2. Click "Segment Spleen"
3. View the segmentation overlay and download the result
"""
    )

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Overlay (Middle Slice)", type="numpy")
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text]
    )
    gr.Markdown(
        """### Requirements
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub

### Citation
If you use this model, please cite:
```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training."
arXiv preprint arXiv:1811.12506 (2018).
```
"""
    )
if __name__ == "__main__":
    demo.launch()