import gradio as gr
import spaces
import torch
import nibabel as nib
import numpy as np
from huggingface_hub import hf_hub_download
from monai.networks.nets import UNet
import tempfile
import os

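# Global model handle (loaded lazily) and the device used for inference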
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    global model
    if model is None:
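        # Download the pretrained weights from the Hugging Face Hub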
        model_path = hf_hub_download(
            repo_id="MONAI/example_spleen_segmentation",
            filename="model.pt"
        )

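        # Build the 3D UNet; the configuration mirrors the MONAI spleen segmentation example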
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="batch",
        )

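        # Load the weights and put the model on the target device in eval mode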
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)
        model.to(device)
        model.eval()
    return model


@spaces.GPU  # assumes a ZeroGPU Space; requests a GPU for the duration of this call
def segment_spleen(input_file):
    """Segment spleen from CT NIfTI file"""
    try:
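        # Load (or reuse) the cached segmentation model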
        net = load_model()

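        # Read the NIfTI volume; keep the original affine/header for saving the result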
        img = nib.load(input_file)
        img_data = img.get_fdata()

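        # Add batch and channel dimensions: (1, 1, X, Y, Z)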
        img_tensor = torch.from_numpy(img_data).float().unsqueeze(0).unsqueeze(0)

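        # Min-max scale intensities to [0, 1]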
        img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min())

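        # Resample to the 96x96x96 grid the model expects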
        img_resized = torch.nn.functional.interpolate(
            img_tensor,
            size=(96, 96, 96),
            mode="trilinear",
            align_corners=True
        )

        img_resized = img_resized.to(device)

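        # Run inference and take the per-voxel argmax (0 = background, 1 = spleen)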
        with torch.no_grad():
            output = net(img_resized)
            pred = torch.argmax(output, dim=1)

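        # Resample the label map back to the original volume shape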
        pred_resized = torch.nn.functional.interpolate(
            pred.float().unsqueeze(0),
            size=img_data.shape,
            mode="nearest"
        )

        pred_np = pred_resized.squeeze().cpu().numpy().astype(np.uint8)

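        # Save the segmentation as NIfTI, reusing the input's affine and header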
        seg_img = nib.Nifti1Image(pred_np, img.affine, img.header)
        output_path = os.path.join(tempfile.mkdtemp(), "segmentation.nii.gz")
        nib.save(seg_img, output_path)

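        # Build a preview: grayscale middle axial slice with spleen voxels highlighted in red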
        mid_slice = img_data.shape[2] // 2
        img_slice = img_data[:, :, mid_slice]
        seg_slice = pred_np[:, :, mid_slice]

        img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min()) * 255

        overlay = np.stack([img_slice, img_slice, img_slice], axis=-1).astype(np.uint8)
        overlay[seg_slice == 1] = [255, 0, 0]

        return overlay, output_path, "Segmentation completed successfully!"

    except Exception as e:
        return None, None, f"Error: {str(e)}"


with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the
[MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

**Model Info:**
- Architecture: UNet
- Input: 3D CT volume (resampled to 96×96×96)
- Output: binary segmentation (spleen vs. background)
- Mean Dice score: 0.96

**Instructions:**
1. Upload a NIfTI file (.nii or .nii.gz)
2. Click **Segment Spleen**
3. View the segmentation overlay and download the result
"""
    )

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Segmentation Overlay (Middle Slice)", type="numpy")
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text]
    )

    gr.Markdown(
        """### Requirements
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub

### Citation
If you use this model, please cite:
```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training."
arXiv preprint arXiv:1811.12506 (2018).
```
"""
    )


if __name__ == "__main__":
    demo.launch()