IFMedTechdemo committed
Commit 1e89df9 · verified · 1 Parent(s): 94264be

Add spleen segmentation app with MONAI model

Files changed (1)
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
import gradio as gr
import spaces
import torch
import nibabel as nib
import numpy as np
from huggingface_hub import hf_hub_download
from monai.networks.nets import UNet
import tempfile
import os

# Load the model
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    global model
    if model is None:
        # Download model from HuggingFace
        model_path = hf_hub_download(
            repo_id="MONAI/example_spleen_segmentation",
            filename="models/model.pt"
        )

        # Initialize UNet architecture
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )

        # Load weights
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
    return model

@spaces.GPU
def segment_spleen(input_file):
    """Segment spleen from CT NIfTI file"""
    try:
        # Load model
        net = load_model()

        # Load NIfTI file
        img = nib.load(input_file)
        img_data = img.get_fdata()

        # Preprocessing
        img_tensor = torch.from_numpy(img_data).float().unsqueeze(0).unsqueeze(0)

        # Normalize to [0, 1] (epsilon guards against constant-intensity volumes)
        img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-8)

        # Resize to model input size (96x96x96)
        img_resized = torch.nn.functional.interpolate(
            img_tensor,
            size=(96, 96, 96),
            mode="trilinear",
            align_corners=True
        )

        # Move to device and run inference
        img_resized = img_resized.to(device)

        with torch.no_grad():
            output = net(img_resized)
            pred = torch.argmax(output, dim=1)

        # Resize back to original size
        pred_resized = torch.nn.functional.interpolate(
            pred.float().unsqueeze(0),
            size=img_data.shape,
            mode="nearest"
        )

        pred_np = pred_resized.squeeze().cpu().numpy().astype(np.uint8)

        # Save segmentation as NIfTI (mkstemp avoids the race-prone tempfile.mktemp)
        seg_img = nib.Nifti1Image(pred_np, img.affine, img.header)
        fd, output_path = tempfile.mkstemp(suffix="_segmentation.nii.gz")
        os.close(fd)
        nib.save(seg_img, output_path)

        # Create visualization (middle slice)
        mid_slice = img_data.shape[2] // 2
        img_slice = img_data[:, :, mid_slice]
        seg_slice = pred_np[:, :, mid_slice]

        # Normalize image for display
        img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min() + 1e-8) * 255

        # Create overlay
        overlay = np.stack([img_slice, img_slice, img_slice], axis=-1).astype(np.uint8)
        overlay[seg_slice == 1] = [255, 0, 0]  # Red for spleen

        return overlay, output_path, "Segmentation completed successfully!"

    except Exception as e:
        return None, None, f"Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the
[MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

**Model Info:**
- Architecture: UNet
- Input: 3D CT image (96×96×96)
- Output: Binary segmentation (spleen vs background)
- Mean Dice Score: 0.96

**Instructions:**
1. Upload a NIfTI file (.nii or .nii.gz)
2. Click "Segment Spleen"
3. View the segmentation overlay and download the result
"""
    )

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Segmentation Overlay (Middle Slice)", type="numpy")
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")

    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text]
    )

    gr.Markdown(
        """### Requirements
- MONAI
- PyTorch
- nibabel
- numpy
- huggingface_hub

### Citation
If you use this model, please cite:
```
Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training."
arXiv preprint arXiv:1811.12506 (2018).
```
"""
    )

if __name__ == "__main__":
    demo.launch()
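
This commit adds only app.py, so the Space will also need a requirements.txt covering the packages imported above. A minimal sketch is below; the package list is inferred from app.py's imports and the in-app Requirements note, and no version pins appear in this commit, so any pins you add are your own choice.

```text
# requirements.txt — sketch inferred from app.py's imports; not part of this commit
gradio
spaces
torch
monai
nibabel
numpy
huggingface_hub
```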
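For a quick local sanity check of segment_spleen outside the Space, something along these lines should work, assuming app.py is importable from the working directory, the spaces package is installed (its GPU decorator is expected to be a pass-through off-Space), and the MONAI checkpoint download succeeds; the file names and synthetic volume are illustrative only, not part of the commit.

```python
# smoke_test.py — illustrative local check; not part of this commit
import os
import tempfile

import nibabel as nib
import numpy as np

from app import segment_spleen  # builds the Gradio Blocks on import, but does not launch the server

# Write a small synthetic CT-like volume to a temporary NIfTI file.
volume = np.random.rand(64, 64, 64).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "synthetic_ct.nii.gz")
nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), path)

overlay, seg_path, status = segment_spleen(path)
print(status)                                      # "Segmentation completed successfully!" on success
print(None if overlay is None else overlay.shape)  # (64, 64, 3) RGB overlay of the middle slice
print(seg_path)                                    # path to the saved segmentation NIfTI
```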