File size: 9,072 Bytes
49c6db7
 
8786ac3
49c6db7
 
 
f0821bf
 
44c9541
e16c5c4
3e66137
 
 
 
 
d874e72
3e66137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c6db7
c994db7
 
49c6db7
 
 
 
3e66137
 
 
49c6db7
 
 
 
 
 
 
 
 
 
f0821bf
49c6db7
 
 
 
 
 
 
 
f0821bf
 
 
 
 
 
 
 
 
 
1c5339e
f0821bf
 
 
 
c034f55
 
 
 
 
 
 
 
 
 
 
 
 
736285e
c034f55
ea7f537
 
 
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b0ba4
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
eeb07d9
ded4361
a887322
c994db7
a887322
 
c08aa90
 
c994db7
c08aa90
 
 
 
c994db7
c08aa90
 
 
 
c994db7
c08aa90
 
a887322
21fd97a
8742230
 
c08aa90
 
 
 
 
 
 
 
 
 
 
 
 
a887322
49c6db7
 
 
18b4441
8a8ccfd
 
18b4441
49c6db7
958ea27
49c6db7
 
18b4441
958ea27
18b4441
 
 
4c3c584
eecd9f2
 
7902217
 
 
 
 
 
 
b2eef14
e16c5c4
 
b2eef14
e16c5c4
 
7170f20
e16c5c4
 
 
 
7170f20
e16c5c4
 
4c3c584
7170f20
49c6db7
 
2c27168
 
 
 
 
 
 
49c6db7
fd1e2f9
25641bf
 
 
 
2c27168
 
 
 
c3abe48
18b4441
c994db7
18b4441
 
ea35654
1fe72e5
 
21fd97a
1fe72e5
 
18b4441
e16c5c4
c3abe48
2c27168
3c73591
 
18b4441
 
c3abe48
7902217
 
 
0305d39
b2eef14
 
1fe72e5
a3dc09f
25641bf
2c27168
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image 
from cellpose.io import imread, imsave

# @title Data retrieval
def download_weights():
    """Download the Cellpose-SAM weight file(s) into the working directory.

    Each entry of ``fname`` is fetched from the matching entry of ``url`` only
    when no file of that name exists yet. Failures are reported with a printed
    message rather than raised; the subsequent model construction is what
    surfaces a missing file to the caller.
    """
    import os, requests

    fname = ['cpsam']
    url = ["https://osf.io/d7c8e/download"]

    for name, link in zip(fname, url):
        if os.path.isfile(name):
            continue
        tmp_name = name + ".part"
        try:
            # Stream in chunks instead of buffering the whole response, and
            # write to a temporary file first: an interrupted write straight to
            # `name` would leave a truncated file that the isfile() check above
            # would then accept as a finished download on every later start.
            # The timeout bounds the connect and each read, so a stalled
            # server cannot hang startup indefinitely.
            with requests.get(link, stream=True, timeout=60) as r:
                if r.status_code != requests.codes.ok:
                    print("!!! Failed to download data !!!")
                    continue
                with open(tmp_name, "wb") as fid:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        fid.write(chunk)
            os.replace(tmp_name, name)
        except (requests.RequestException, OSError):
            # RequestException also covers timeouts, which the previous
            # ConnectionError-only handler let escape.
            print("!!! Failed to download data !!!")
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

try:
    download_weights()
    model = models.CellposeModel(gpu=True, pretrained_model="cpsam")
except Exception as e:
    # The app cannot serve any request without a model, so stop here.
    # SystemExit with a message prints it to stderr and exits with status 1.
    # It replaces the site-provided `exit()` helper, which is meant for the
    # interactive shell and does not exist when Python runs with -S.
    raise SystemExit(f"Error loading model: {e}") from e



            
def plot_flows(y):
    """Render a two-component flow field as an RGB image.

    ``y[0][0]`` is read as the Y component and ``y[1][0]`` as the X component.
    The flow angle sets the hue. The percentile-normalized squared magnitude
    sets both saturation and value.

    Returns:
        uint8 RGB array.
    """
    dy, dx = y[0][0], y[1][0]

    # Percentile-normalize each component, clip to [0, 1], then map to [-1, 1].
    unit_y = 2 * (np.clip(normalize99(dy), 0, 1) - 0.5)
    unit_x = 2 * (np.clip(normalize99(dx), 0, 1) - 0.5)

    hue = (np.arctan2(unit_y, unit_x) + np.pi) / (2 * np.pi)
    strength = normalize99(dy ** 2 + dx ** 2)

    planes = [p[:, :, np.newaxis] for p in (hue, strength, strength)]
    hsv = np.clip(np.concatenate(planes, axis=-1), 0.0, 1.0)

    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)

def plot_outlines(img, masks):
    """Draw the outline of every mask label in red over ``img``.

    Args:
        img: image array whose first two axes are (height, width).
        masks: integer label image with the same height and width as ``img``.

    Returns:
        PIL image holding the PNG rendering of the matplotlib figure.
    """
    # RETR_FLOODFILL accepts a CV_32SC1 label image, hence the int32 cast.
    contours, _ = cv2.findContours(masks.astype(np.int32),
                                   mode=cv2.RETR_FLOODFILL,
                                   method=cv2.CHAIN_APPROX_SIMPLE)
    # Skip contours with 4 or fewer points, as the original length check did.
    outpix = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4
    ]

    height, width = img.shape[:2]
    # Longest side is 6 inches; keep the image's aspect ratio.
    if height > width:
        figsize = (6 * width / height, 6)
    else:
        figsize = (6, 6 * height / width)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, width])
        ax.set_ylim([0, height])
        # The image is flipped vertically and the outline y-coordinates are
        # mirrored to match, so both line up on the bottom-up y axis set above.
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for poly in outpix:
            ax.plot(poly[:, 0], height - poly[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure it creates alive until it is closed. This
        # runs on each button click of a long-lived server, so without the
        # close the figures accumulate in memory.
        plt.close(fig)

    buf.seek(0)
    pil_img = Image.open(buf)
    pil_img.load()  # decode now rather than lazily from the buffer
    return pil_img

def plot_overlay(img, masks):
    """Tint each labelled region of ``masks`` with a random hue over a gray image.

    Args:
        img: 2-D (H, W) or 3-D (H, W, C) image array. 3-D input is averaged over
            its last axis to obtain a single intensity channel.
        masks: label image with the same height and width as ``img``. 0 is
            background, and each label 1..max(masks) gets one random hue drawn
            from NumPy's global random state.

    Returns:
        uint8 RGB array of shape (H, W, 3).
    """
    gray = img.astype(np.float32)
    # Only collapse channels when there are channels. Averaging a 2-D image
    # over axis -1 would reduce it to 1-D and break the (H, W, 3) HSV
    # allocation below.
    if gray.ndim == 3:
        gray = gray.mean(axis=-1)
    gray = normalize99(gray)
    gray = gray - gray.min()
    peak = gray.max()
    if peak > 0:  # a flat image would otherwise divide by zero
        gray = gray / peak

    HSV = np.zeros((gray.shape[0], gray.shape[1], 3), np.float32)
    HSV[:, :, 2] = np.clip(gray * 1.5, 0, 1.0)

    labels = np.asarray(masks).astype(np.int64)
    n_labels = int(labels.max()) if labels.size else 0
    if n_labels > 0:
        # Draw one hue per label in label order, the same sequence a per-label
        # loop of np.random.rand() calls would produce. Apply them in a single
        # vectorized lookup instead of rescanning the full mask for every label.
        hues = np.random.rand(n_labels)
        fg = labels > 0
        HSV[fg, 0] = hues[labels[fg] - 1]
        HSV[fg, 1] = 1.0
    return (hsv_to_rgb(HSV) * 255).astype(np.uint8)

def normalize99(img):
    """Linearly rescale ``img`` so its 1st percentile maps to 0 and 99th to 1.

    Values outside that percentile range are not clipped. The input array is
    not modified.

    Args:
        img: numeric array of any shape.

    Returns:
        Floating-point array of the same shape. If the 1st and 99th percentiles
        are equal (e.g. a constant image), an all-zero array is returned instead
        of dividing by zero.
    """
    X = np.asarray(img)
    lo, hi = np.percentile(X, [1, 99])
    shifted = X - lo  # a new array, so the caller's data is untouched
    span = hi - lo
    if span == 0:
        return np.zeros_like(shifted)
    return shifted / span

def image_resize(img, resize=400):
    """Shrink ``img`` so its longer side is at most ``resize`` px; return uint8.

    Args:
        img: 2-D (H, W) or 3-D (H, W, C) image array.
        resize: maximum side length. It is converted with ``int()`` because
            callers may pass a float (gr.Number's default output) and
            cv2.resize needs integer sizes.

    Returns:
        uint8 array. uint8 input, and integer input already within [0, 255],
        is cast directly. Any other input (floats, or integers outside [0, 255]
        such as 16-bit data) is first min-max rescaled to 0-255, because a
        plain cast would wrap or truncate its values.
    """
    resize = int(resize)
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        # Keep the aspect ratio, and never let the short side round down to 0.
        if ny > nx:
            nx = max(1, int(nx / ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny / nx * resize))
            nx = resize
        img = cv2.resize(img, (nx, ny))  # cv2 expects (width, height)

    if img.dtype != np.uint8 and img.size:
        lo, hi = float(img.min()), float(img.max())
        in_range = np.issubdtype(img.dtype, np.integer) and lo >= 0 and hi <= 255
        if not in_range:
            if hi > lo:
                # Multiply before dividing so the maximum maps exactly to 255.
                img = (img.astype(np.float64) - lo) * 255.0 / (hi - lo)
            else:
                img = np.zeros(img.shape)
    return img.astype(np.uint8)

    
def _run_model(img):
    """Run the loaded Cellpose-SAM model on ``img``; return (masks, flows)."""
    masks, flows, _ = model.eval(img)
    return masks, flows


# The wrappers below are identical apart from the GPU time budget (seconds)
# requested through the `spaces.GPU` decorator. The decorator is fixed per
# function, so each budget needs its own wrapper.

@spaces.GPU(duration=10)
def run_model_gpu(img):
    """Segment ``img`` under a 10 s GPU allocation."""
    return _run_model(img)


@spaces.GPU(duration=60)
def run_model_gpu60(img):
    """Segment ``img`` under a 60 s GPU allocation."""
    return _run_model(img)


@spaces.GPU(duration=240)
def run_model_gpu240(img):
    """Segment ``img`` under a 240 s GPU allocation."""
    return _run_model(img)


@spaces.GPU(duration=1000)
def run_model_gpu1000(img):
    """Segment ``img`` under a 1000 s GPU allocation."""
    return _run_model(img)

#@spaces.GPU(duration=10)
def cellpose_segment(img_pil, resize = 400):
    """Segment an uploaded image with Cellpose-SAM and build the UI outputs.

    Args:
        img_pil: despite the name, a file path. The input component uses
            type="filepath".
        resize: maximum side length the image is shrunk to before
            segmentation. It may arrive as a float or None from gr.Number.

    Returns:
        ``(outlines, overlay, flows, masks_button, outlines_button)``: three
        PIL images at the processed resolution, plus two visible
        DownloadButtons. The first points at the masks TIFF (saved at the
        input's resolution), the second at the outlines PNG.

    Raises:
        gr.Error: if no image was provided or ``resize`` is not positive.
        ValueError: if the processed image's largest dimension is 20,000 or
            more.
    """
    if img_pil is None:
        raise gr.Error("Please upload an image first.")
    max_size = 400 if resize is None else int(resize)
    if max_size < 1:
        raise gr.Error("max resize must be a positive number.")

    img_input = imread(img_pil)
    img = image_resize(img_input, resize=max_size)

    # Pick the GPU time budget from the size actually sent to the model.
    max_dim = np.max(img.shape)
    if max_dim < 1000:
        masks, flows = run_model_gpu(img)
    elif max_dim < 5000:
        masks, flows = run_model_gpu60(img)
    elif max_dim < 20000:
        masks, flows = run_model_gpu240(img)
    else:
        raise ValueError("Image size must be less than 20,000")
    # Presumably the RGB flow rendering from model.eval. It is converted with
    # Image.fromarray below.
    flows = flows[0]

    outpix = plot_outlines(img, masks)
    overlay = plot_overlay(img, masks)

    # Save the masks at the input's resolution. Nearest-neighbour keeps the
    # integer labels intact.
    target_size = (img_input.shape[1], img_input.shape[0])
    if target_size != (img.shape[1], img.shape[0]):
        masks = cv2.resize(masks.astype('uint16'), target_size,
                           interpolation=cv2.INTER_NEAREST).astype('uint16')

    Ly, Lx = img.shape[:2]
    outpix = outpix.resize((Lx, Ly), resample=Image.BICUBIC)
    overlay = Image.fromarray(overlay).resize((Lx, Ly), resample=Image.BICUBIC)
    flows = Image.fromarray(flows).resize((Lx, Ly), resample=Image.BICUBIC)

    # Outputs are written next to the uploaded file.
    stem = os.path.splitext(img_pil)[0]
    fname_out = stem + "_outlines.png"
    fname_masks = stem + "_masks.tif"
    imsave(fname_masks, masks)
    outpix.save(fname_out)

    b1 = gr.DownloadButton(visible=True, value=fname_masks)
    b2 = gr.DownloadButton(visible=True, value=fname_out)
    return outpix, overlay, flows, b1, b2

# Gradio Interface
#iface = gr.Interface(
#    fn=cellpose_segment, 
#    inputs="image", 
#    outputs=["image", "image", "image", "image"],
#    title="cellpose segmentation",
#    description="upload an image, then cellpose will segment it at a max size of 400x400 (for full functionality, 'pip install cellpose' locally)"
#)

def download_function():
    """Return a pair of hidden download buttons (masks TIFF, outlines PNG).

    No active event handler in this file calls it.
    """
    hidden = {"visible": False}
    masks_button = gr.DownloadButton("Download masks as TIFF", **hidden)
    outlines_button = gr.DownloadButton("Download outline image as PNG", **hidden)
    return masks_button, outlines_button

with gr.Blocks(title = "Cellpose-SAM",
               css=".gradio-container {background:purple;}") as demo:

    with gr.Row():
        # Left column: instructions, input image, controls and download buttons.
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular segmentation</div>""")
            gr.HTML("""<h4 style="color:white;">You may need to refresh/login for 5 minutes of free GPU compute/day. </h4>""")
            gr.HTML("""<h4 style="color:white;">"pip install cellpose" for full functionality. </h4>""")

            input_image = gr.Image(label = "Input image", type = "filepath")

            with gr.Row():
                # gr.Number returns a float by default. precision=0 makes it an
                # int, which image_resize ultimately passes to cv2.resize as a
                # target size.
                resize = gr.Number(label = 'max resize', value = 400, precision = 0)
                send_btn = gr.Button("Run Cellpose-SAM")

            with gr.Row():
                # Hidden until cellpose_segment returns visible replacements.
                down_btn = gr.DownloadButton("Download masks (TIF)", visible=False)
                down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)

            gr.HTML("""<a style="color:white;" href="https://github.com/MouseLand/cellpose" target="_blank">github page for cellpose</a>""")
            gr.HTML("""<a style="color:white;" href="https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1" target="_blank">Cellpose-SAM paper</a>""")

        # Right column: segmentation results.
        with gr.Column(scale=2):
            img_outlines = gr.Image(label = "Outlines", type = "pil", format = 'png')
            img_overlay = gr.Image(label = "Overlay", type = "pil", format = 'png')
            flows = gr.Image(label = "Cellpose flows", type = "pil", format = 'png')

    send_btn.click(fn=cellpose_segment, inputs=[input_image, resize], outputs=[img_outlines, img_overlay, flows, down_btn, down_btn2])


if __name__ == "__main__":
    demo.launch()