File size: 5,304 Bytes
1e89df9 3cc5be7 1e89df9 1338ad7 1e89df9 1338ad7 1e89df9 1338ad7 1e89df9 1338ad7 1e89df9 ac616fc 1e89df9 1338ad7 1e89df9 3cc5be7 1e89df9 3cc5be7 1e89df9 3cc5be7 1e89df9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import gradio as gr
import spaces
import torch
import nibabel as nib
import numpy as np
from huggingface_hub import hf_hub_download
from monai.transforms import (
Compose,
LoadImage,
EnsureChannelFirst,
ScaleIntensity,
Resize,
AsDiscrete,
)
from monai.networks.nets import UNet
import tempfile
import os
# Lazily-initialised network cache; stays None until load_model() first runs.
model = None
# Prefer the GPU whenever PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model():
    """Return the spleen-segmentation UNet, creating it on first use.

    On the first call the weights file ``model.pt`` is downloaded from the
    ``MONAI/example_spleen_segmentation`` Hugging Face repo and loaded into a
    freshly built 3D UNet. The network is then moved to ``device`` and put in
    eval mode. The result is cached in the module-level ``model`` global, so
    later calls return the same object without downloading again.

    The global is assigned only after the weights have loaded successfully.
    If the download or weight loading raises, the exception propagates and
    the next call retries. It never returns a half-initialised network.

    Returns:
        torch.nn.Module: the UNet, on ``device`` and in eval mode.
    """
    global model
    if model is not None:
        return model

    import warnings

    model_path = hf_hub_download(
        repo_id="MONAI/example_spleen_segmentation",
        filename="model.pt",
    )

    # Hyper-parameters must match the checkpoint; the author took them from
    # the repo's inference.json (including batch normalisation).
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm="batch",
    )

    # NOTE(review): torch.load unpickles the downloaded file. Consider passing
    # weights_only=True (the default since PyTorch 2.6) if the installed
    # PyTorch supports it and the checkpoint is a plain state dict.
    checkpoint = torch.load(model_path, map_location=device)

    # strict=False tolerates key mismatches, but with a wrongly laid-out
    # checkpoint it would load nothing at all. Report any mismatch instead of
    # silently running with partly random weights.
    incompatible = net.load_state_dict(checkpoint, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            "Spleen UNet checkpoint mismatch: "
            f"{len(incompatible.missing_keys)} missing keys "
            f"(e.g. {incompatible.missing_keys[:5]}), "
            f"{len(incompatible.unexpected_keys)} unexpected keys "
            f"(e.g. {incompatible.unexpected_keys[:5]})",
            RuntimeWarning,
            stacklevel=2,
        )

    net.to(device)
    net.eval()
    model = net  # publish only a fully initialised network
    return model
def _min_max_normalize(values, scale=1.0):
    """Linearly rescale ``values`` so its minimum maps to 0 and its maximum to ``scale``.

    A constant array (max == min) maps to all zeros. A plain
    ``(x - min) / (max - min)`` would instead divide by zero and yield NaNs.

    Args:
        values: NumPy array of any shape.
        scale: value that the array maximum is mapped to.

    Returns:
        np.ndarray: array with the same shape as ``values``.
    """
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values, dtype=np.float64)
    return (values - lo) / span * scale


def _load_volume(input_file):
    """Load the uploaded NIfTI file and return ``(nibabel_image, 3D float array)``.

    Raises:
        ValueError: if the image data is not 3D, or 4D with a singleton last axis.
    """
    # Depending on the Gradio version, gr.File yields a path string or a
    # tempfile-like object whose ``.name`` is the path.
    path = getattr(input_file, "name", input_file)
    img = nib.load(path)
    volume = img.get_fdata()
    # Accept X×Y×Z×1 files by dropping the trailing singleton axis.
    if volume.ndim == 4 and volume.shape[3] == 1:
        volume = volume[..., 0]
    if volume.ndim != 3:
        raise ValueError(f"Expected a 3D volume, got array of shape {volume.shape}")
    return img, volume


def _predict_labels(net, volume, roi_size=(96, 96, 96)):
    """Run ``net`` on ``volume`` resampled to ``roi_size``; return labels at ``volume``'s shape.

    Args:
        net: segmentation network; its output is argmax-ed over dim 1
            (presumably per-class scores — confirm against the model).
        volume: 3D NumPy array.
        roi_size: spatial size the volume is resampled to before inference.

    Returns:
        np.ndarray: uint8 label map with ``volume.shape``.
    """
    # NOTE(review): min-max scaling plus a whole-volume resize differs from the
    # usual MONAI spleen pipeline (HU windowing, resampling to a fixed spacing,
    # sliding-window inference). Verify against the repo's inference.json.
    normalized = _min_max_normalize(np.asarray(volume, dtype=np.float32))
    # (X, Y, Z) -> (1, 1, X, Y, Z): batch and channel axes for interpolate/UNet.
    tensor = torch.from_numpy(normalized).float()[None, None]
    resized = torch.nn.functional.interpolate(
        tensor,
        size=roi_size,
        mode="trilinear",
        align_corners=True,
    ).to(device)
    with torch.no_grad():
        scores = net(resized)
    # keepdim keeps the channel axis so the upsampling below sees 5D input.
    labels = torch.argmax(scores, dim=1, keepdim=True).float()
    restored = torch.nn.functional.interpolate(labels, size=volume.shape, mode="nearest")
    # Index rather than squeeze() so singleton spatial axes are preserved.
    return restored[0, 0].cpu().numpy().astype(np.uint8)


def _save_segmentation(labels, reference):
    """Save ``labels`` as NIfTI with ``reference``'s affine/header and return the file path."""
    seg_img = nib.Nifti1Image(labels, reference.affine, reference.header)
    # The copied header would otherwise keep the CT's on-disk dtype; store the
    # label map as uint8.
    seg_img.set_data_dtype(np.uint8)
    # mkstemp creates the file atomically; tempfile.mktemp is deprecated and racy.
    fd, output_path = tempfile.mkstemp(suffix="_segmentation.nii.gz")
    os.close(fd)
    nib.save(seg_img, output_path)
    return output_path


def _make_overlay(volume, labels):
    """Render the middle slice along the third axis as RGB uint8, with label 1 painted red."""
    mid_slice = volume.shape[2] // 2
    gray = _min_max_normalize(volume[:, :, mid_slice], scale=255.0).astype(np.uint8)
    overlay = np.stack([gray, gray, gray], axis=-1)
    overlay[labels[:, :, mid_slice] == 1] = [255, 0, 0]  # red for spleen
    return overlay


def segment_spleen(input_file):
    """Segment the spleen in an uploaded CT NIfTI file.

    Args:
        input_file: value of the ``gr.File`` component. This is a path string
            or a tempfile-like object exposing ``.name``, or ``None`` when
            nothing was uploaded.

    Returns:
        tuple: ``(overlay, output_path, status)``.
            - ``overlay`` is an RGB uint8 image of the middle slice.
            - ``output_path`` is the saved segmentation NIfTI.
            - ``status`` is a message for the UI.
        On failure the first two entries are ``None`` and ``status`` starts
        with ``"Error: "``.
    """
    # NOTE(review): `spaces` is imported but unused. If this runs on a ZeroGPU
    # Space, this function presumably needs the @spaces.GPU decorator — confirm.
    if input_file is None:
        return None, None, "Error: Please upload a NIfTI file."
    try:
        net = load_model()
        img, volume = _load_volume(input_file)
        labels = _predict_labels(net, volume)
        output_path = _save_segmentation(labels, img)
        overlay = _make_overlay(volume, labels)
        return overlay, output_path, "Segmentation completed successfully!"
    except Exception as e:  # UI boundary: report the error instead of crashing the app
        import traceback

        traceback.print_exc()  # keep the full trace in the server log
        return None, None, f"Error: {str(e)}"
# Build the Gradio interface: upload + button on the left, results on the right.
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the
[MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

**Model Info:**

- Architecture: UNet
- Input: 3D CT volume (any size; resampled to 96×96×96 for inference)
- Output: Binary segmentation (spleen vs background)
- Mean Dice Score: 0.96

**Instructions:**

1. Upload a NIfTI file (.nii or .nii.gz)
2. Click **Segment Spleen**
3. View the segmentation overlay and download the result
"""
    )

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"],
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Overlay (Middle Slice)", type="numpy")
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    # Output order must match segment_spleen's (overlay, path, status) return tuple.
    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text],
    )

    gr.Markdown(
        """### Requirements

- gradio
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub

### Citation

If you use this model, please cite:

```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training."
arXiv preprint arXiv:1811.12506 (2018).
```
"""
    )
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()