import gradio as gr
import spaces
import torch
import nibabel as nib
import numpy as np
from huggingface_hub import hf_hub_download
from monai.networks.nets import UNet
import tempfile

# Load the model
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    global model
    if model is None:
        # Download model from HuggingFace
        model_path = hf_hub_download(
            repo_id="MONAI/example_spleen_segmentation",
            filename="model.pt"
        )
        
        # Initialize UNet architecture with exact parameters from inference.json
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="batch",  # Added: batch normalization as specified in inference.json
        )
        
        # Load weights with strict=False to handle minor key mismatches
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)
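        # Optional sanity check (hypothetical) that strict=False did not silently skip
        # most of the weights:
        #   missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        #   print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")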
        model.to(device)
        model.eval()
    return model


@spaces.GPU  # Request a GPU on ZeroGPU Spaces (assumed deployment target); a no-op on other hardware
def segment_spleen(input_file):
    """Segment the spleen from a CT scan in NIfTI format."""
    if input_file is None:
        return None, None, "Error: please upload a NIfTI file (.nii or .nii.gz) first."
    try:
        # Load model
        net = load_model()
        
        # Load NIfTI file
        img = nib.load(input_file)
        img_data = img.get_fdata()
        
        # Preprocessing
        img_tensor = torch.from_numpy(img_data).float().unsqueeze(0).unsqueeze(0)
        
        # Min-max normalize intensities (epsilon guards against a constant volume)
        img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min() + 1e-8)
        
        # Resize to model input size (96x96x96)
        img_resized = torch.nn.functional.interpolate(
            img_tensor,
            size=(96, 96, 96),
            mode="trilinear",
            align_corners=True
        )
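        # NOTE (assumption about the reference pipeline): the MONAI bundle's own
        # inference.json typically applies CT intensity windowing and sliding-window
        # inference over 96x96x96 patches instead of resizing the whole volume; the
        # global resize used here is a simplification and may reduce accuracy.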
        
        # Move to device and run inference
        img_resized = img_resized.to(device)
        
        with torch.no_grad():
            output = net(img_resized)
            pred = torch.argmax(output, dim=1)
        
        # Resize back to original size
        pred_resized = torch.nn.functional.interpolate(
            pred.float().unsqueeze(0),
            size=img_data.shape,
            mode="nearest"
        )
        
        pred_np = pred_resized.squeeze().cpu().numpy().astype(np.uint8)
        
        # Save segmentation as NIfTI
        seg_img = nib.Nifti1Image(pred_np, img.affine, img.header)
        # Use a named temporary file (tempfile.mktemp is deprecated and race-prone)
        with tempfile.NamedTemporaryFile(suffix="_segmentation.nii.gz", delete=False) as tmp:
            output_path = tmp.name
        nib.save(seg_img, output_path)
        
        # Create visualization (middle slice)
        mid_slice = img_data.shape[2] // 2
        img_slice = img_data[:, :, mid_slice]
        seg_slice = pred_np[:, :, mid_slice]
        
        # Normalize image to 0-255 for display (epsilon guards against a constant slice)
        img_slice = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min() + 1e-8) * 255
        
        # Create overlay
        overlay = np.stack([img_slice, img_slice, img_slice], axis=-1).astype(np.uint8)
        overlay[seg_slice == 1] = [255, 0, 0]  # Red for spleen
        
        return overlay, output_path, "Segmentation completed successfully!"
        
    except Exception as e:
        return None, None, f"Error: {str(e)}"

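# Example of invoking the segmentation function directly, e.g. for a quick local test
# (hypothetical usage; assumes a file named "sample_ct.nii.gz" exists next to this script):
#   overlay, seg_path, status = segment_spleen("sample_ct.nii.gz")
#   print(status, seg_path)
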
# Create Gradio interface
with gr.Blocks(title="Spleen Segmentation") as demo:
    gr.Markdown("# 🏥 CT Spleen Segmentation")
    gr.Markdown(
        """Upload a CT scan in NIfTI format (.nii or .nii.gz) to segment the spleen using the 
        [MONAI/example_spleen_segmentation](https://huggingface.co/MONAI/example_spleen_segmentation) model.

        **Model Info:**
        - Architecture: UNet
        - Input: 3D CT image (96×96×96)
        - Output: Binary segmentation (spleen vs background)
        - Mean Dice Score: 0.96

        **Instructions:**
        1. Upload a NIfTI file (.nii or .nii.gz)
        2. Click Submit
        3. View the segmentation overlay and download the result
        """
    )
    
    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Upload CT Scan (NIfTI format)",
                file_types=[".nii", ".nii.gz"]
            )
            submit_btn = gr.Button("Segment Spleen", variant="primary")
        
        with gr.Column():
            output_image = gr.Image(label="Segmentation Overlay (Middle Slice)", type="numpy")
            output_file = gr.File(label="Download Segmentation")
            status_text = gr.Textbox(label="Status")
    
    submit_btn.click(
        fn=segment_spleen,
        inputs=[input_file],
        outputs=[output_image, output_file, status_text]
    )
    
    gr.Markdown(
        """### Requirements
        - MONAI
        - PyTorch
        - nibabel
        - numpy
        - huggingface_hub
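
        A minimal `requirements.txt` for this Space might look like the sketch below
        (package names are inferred from the imports above; versions are unpinned and
        should be adjusted as needed):
        ```
        gradio
        spaces
        monai
        torch
        nibabel
        numpy
        huggingface_hub
        ```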

        ### Citation
        If you use this model, please cite:
        ```
        Xia, Yingda, et al. "3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training." 
        arXiv preprint arXiv:1811.12506 (2018).
        ```
        """
    )

if __name__ == "__main__":
    demo.launch()