File size: 7,583 Bytes
49c6db7
 
8786ac3
49c6db7
 
 
f0821bf
 
44c9541
49c6db7
3e66137
 
 
 
 
d874e72
3e66137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c6db7
180fb05
644bb2c
49c6db7
 
 
 
3e66137
 
 
49c6db7
 
 
 
 
 
 
 
 
 
f0821bf
49c6db7
 
 
 
 
 
 
 
f0821bf
 
 
 
 
 
 
 
 
 
1c5339e
f0821bf
 
 
 
c034f55
 
 
 
 
 
 
 
 
 
 
 
 
736285e
c034f55
 
 
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b0ba4
49c6db7
 
 
 
 
 
 
 
 
 
 
 
 
eeb07d9
ded4361
a887322
644bb2c
a887322
 
 
958ea27
 
a887322
 
49c6db7
 
 
18b4441
8a8ccfd
 
18b4441
49c6db7
958ea27
49c6db7
 
18b4441
958ea27
18b4441
 
 
4c3c584
b2eef14
 
 
 
7170f20
 
 
a13d968
7170f20
 
 
4c3c584
7170f20
49c6db7
 
2c27168
 
 
 
 
 
 
49c6db7
fd1e2f9
25641bf
 
 
 
2c27168
 
 
 
c3abe48
18b4441
 
 
 
77e473c
 
18b4441
 
c3abe48
2c27168
18b4441
 
 
 
c3abe48
18b4441
 
 
0305d39
b2eef14
 
7170f20
a3dc09f
25641bf
2c27168
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import numpy as np
import gradio as gr
import spaces
import cv2
from cellpose import models
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import os, io, base64
from PIL import Image 


# @title Data retrieval
def download_weights(timeout=60):
    """Download pretrained weight files that are not already present locally.

    Each name in ``fname`` is fetched from the matching entry in ``url`` and
    written to the current working directory. Files that already exist are
    skipped. Failures are reported with a printed message and never raise, so
    a caller can continue (e.g. fall back to other weights).

    Parameters
    ----------
    timeout : float, optional
        Seconds to wait for the server before giving up on a download
        (default 60). Without a timeout ``requests.get`` can block forever.
    """
    import os
    import requests

    fname = ['cpsam']
    url = ["https://osf.io/d7c8e/download"]

    for name, link in zip(fname, url):
        if os.path.isfile(name):
            continue  # already downloaded
        try:
            r = requests.get(link, timeout=timeout)
        except requests.RequestException:
            # Base class of ConnectionError, Timeout, TooManyRedirects, ...
            print("!!! Failed to download data !!!")
            continue
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
            continue
        with open(name, "wb") as fid:
            fid.write(r.content)

import sys

# Load the segmentation model once at import time; every request reuses it.
# NOTE(review): the UI advertises Cellpose-SAM and download_weights() fetches a
# 'cpsam' file, but this loads the "cyto3" weights — confirm which is intended.
try:
    # download_weights()  # disabled: not needed for the "cyto3" weights
    model = models.CellposeModel(gpu=True, pretrained_model="cyto3")
except Exception as e:
    # The app cannot serve anything without a model, so fail fast at startup.
    # sys.exit is used because the `exit` builtin is only meant for the
    # interactive interpreter (it is absent when Python runs with -S).
    print(f"Error loading model: {e}", file=sys.stderr)
    sys.exit(1)



            
def plot_flows(y):
    """Render a two-component flow field as an RGB uint8 image.

    ``y[0][0]`` and ``y[1][0]`` are the two flow components (presumably the
    vertical and horizontal flows — TODO confirm against the caller). Hue
    encodes the direction from ``arctan2`` of the percentile-normalised
    components. Saturation and value both encode the normalised squared
    magnitude.
    """
    comp_a = y[0][0]
    comp_b = y[1][0]

    def _centered(component):
        # Percentile-normalise, clamp to [0, 1], then shift/scale to [-1, 1].
        return (np.clip(normalize99(component), 0, 1) - 0.5) * 2

    angle = np.arctan2(_centered(comp_a), _centered(comp_b))
    hue = (angle + np.pi) / (2*np.pi)
    strength = normalize99(comp_a**2 + comp_b**2)

    planes = [plane[:, :, np.newaxis] for plane in (hue, strength, strength)]
    hsv_img = np.clip(np.concatenate(planes, axis=-1), 0.0, 1.0)
    return (hsv_to_rgb(hsv_img) * 255).astype(np.uint8)

def plot_outlines(img, masks):
    """Draw the outline of every labelled region in red over *img*.

    Parameters
    ----------
    img : np.ndarray
        Image to draw on. ``img.shape[:2]`` is used as (rows, cols) and the
        pixels are shown with ``imshow``.
    masks : np.ndarray
        Label image with the same rows/cols as *img* (assumed 0 = background,
        1..N = objects — TODO confirm against the cellpose output).

    Returns
    -------
    PIL.Image.Image
        The rendered figure (black background, tight bounding box), decoded
        from an in-memory PNG.
    """
    # int32 keeps labels above 255 distinct; OpenCV accepts a CV_32SC1 label
    # image when mode is RETR_FLOODFILL.
    contours, _hierarchy = cv2.findContours(
        masks.astype(np.int32),
        mode=cv2.RETR_FLOODFILL,
        method=cv2.CHAIN_APPROX_SIMPLE,
    )
    # Skip contours with 4 or fewer points; approximate the rest as polylines.
    outlines = [
        cv2.approxPolyDP(contour, 0.001, True)[:, 0, :]
        for contour in contours
        if len(contour) > 4
    ]

    rows, cols = img.shape[:2]
    # The longer side of the figure is 6 inches; keep the image aspect ratio.
    if rows > cols:
        figsize = (6*cols/rows, 6)
    else:
        figsize = (6, 6*rows/cols)

    fig = plt.figure(figsize=figsize, facecolor='k')
    try:
        ax = fig.add_axes([0.0, 0.0, 1, 1])
        ax.set_xlim([0, cols])
        ax.set_ylim([0, rows])
        # The image is flipped vertically, so outline y-coordinates are
        # mirrored (rows - y) to stay aligned with it.
        ax.imshow(img[::-1], origin='upper', aspect='auto')
        for outline in outlines:
            ax.plot(outline[:, 0], rows - outline[:, 1], color=[1, 0, 0], lw=1)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running server leaks one figure per request.
        plt.close(fig)

    buf.seek(0)
    output_pil_img = Image.open(buf)
    output_pil_img.load()  # decode now rather than lazily on first access
    return output_pil_img

def plot_overlay(img, masks):
    """Colour each labelled region of *masks* with a random hue over *img*.

    Parameters
    ----------
    img : np.ndarray
        (H, W) grayscale or (H, W, C) image. Channels are averaged into one
        intensity plane, which is percentile-normalised and rescaled to
        [0, 1]. That plane becomes the HSV value, brightened 1.5x and clipped.
    masks : np.ndarray
        Integer label image of shape (H, W). Label 0 is left uncoloured.
        Labels 1..max each get one hue from ``np.random``, so colours differ
        between calls.

    Returns
    -------
    np.ndarray
        uint8 RGB image of shape (H, W, 3).
    """
    img = img.astype(np.float32)
    if img.ndim == 3:
        img = img.mean(axis=-1)  # collapse channels to one intensity plane
    img = normalize99(img)
    img -= img.min()
    img_max = img.max()
    if img_max > 0:  # a flat image would otherwise become 0/0 = NaN
        img /= img_max

    hsv = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
    hsv[:, :, 2] = np.clip(img*1.5, 0, 1.0)

    # One random hue per label, written to all labelled pixels in one pass
    # instead of a full-image comparison per label (O(H*W), not O(N*H*W)).
    n_labels = int(masks.max())
    hues = np.random.rand(n_labels)
    labels = masks.astype(np.int64)
    fg = (labels >= 1) & (labels <= n_labels)
    hsv[fg, 0] = hues[labels[fg] - 1]
    hsv[fg, 1] = 1.0

    return (hsv_to_rgb(hsv) * 255).astype(np.uint8)

def normalize99(img, lower=1, upper=99):
    """Linearly rescale *img* so its ``lower``/``upper`` percentiles map to 0/1.

    Values outside that percentile range are not clipped, so the output can
    fall below 0 or exceed 1.

    Parameters
    ----------
    img : array_like
        Input data of any shape. It is not modified.
    lower, upper : float, optional
        Percentiles (0-100) mapped to 0 and 1. The defaults (1, 99) give the
        original 1st/99th-percentile behaviour.

    Returns
    -------
    np.ndarray
        New floating-point array with the shape of *img*. If the two
        percentiles coincide (e.g. a constant image), the result is all zeros
        instead of NaN/inf.
    """
    x = np.asarray(img)
    # One call computes both percentiles (the original computed the lower
    # one twice).
    lo, hi = np.percentile(x, [lower, upper])
    out = x - lo  # new array; the caller's data is never touched
    span = hi - lo
    if span == 0:
        out[...] = 0  # degenerate range: avoid 0/0 = NaN
    else:
        out /= span
    return out

def image_resize(img, resize=400):
    """Downscale *img* so its longest spatial side is at most *resize* pixels.

    The aspect ratio is preserved. The shorter side is truncated to an int
    but never allowed to reach 0, which ``cv2.resize`` rejects. Images
    already within the limit are not resampled. Only the first two axes
    (rows, cols) are considered, never the channel axis. The result is
    always cast to uint8.
    """
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        if ny > nx:
            nx = max(1, int(nx/ny * resize))
            ny = resize
        else:
            ny = max(1, int(ny/nx * resize))
            nx = resize
        img = cv2.resize(img, (nx, ny))  # OpenCV dsize is (width, height)
    # NOTE(review): plain cast, no rescaling or clipping — non-8-bit inputs
    # (e.g. floats in [0, 1] or uint16) are truncated/wrapped; confirm that
    # callers always pass 8-bit images.
    return img.astype(np.uint8)

    
@spaces.GPU(duration=10)
def run_model_gpu(img):
    """Run the module-level Cellpose ``model`` on *img* in a ``spaces.GPU`` slot.

    The decorator requests the GPU for ``duration=10``. ``model.eval`` must
    return exactly three values. The first two (masks, flows) are returned
    and the third is discarded.
    """
    # channels=[0, 0]: per the cellpose docs this selects grayscale input with
    # no second channel — confirm for the installed cellpose version.
    segmentation, flow_outputs, _unused = model.eval(img, channels = [0,0])
    return segmentation, flow_outputs

#@spaces.GPU(duration=10)
def cellpose_segment(img_input):
    """Segment an uploaded image and build the outputs shown in the UI.

    The image is downscaled with ``image_resize`` before inference. The masks
    are resized back to the input size for the TIFF download. The displayed
    outline, overlay and flow images stay at the inference size.

    Parameters
    ----------
    img_input : np.ndarray or None
        Image from the ``gr.Image`` input. ``None`` means no image has been
        uploaded yet.

    Returns
    -------
    tuple
        (outline PIL image, overlay RGB array, flow image, masks download
        button, outlines download button).

    Raises
    ------
    gr.Error
        If no image was provided; Gradio shows the message to the user.
    """
    import tempfile

    if img_input is None:
        raise gr.Error("Please upload an image before running Cellpose.")

    img = image_resize(img_input)
    masks, flows = run_model_gpu(img)
    # Presumably the RGB flow visualisation from cellpose — TODO confirm.
    flows = flows[0]

    outlines_img = plot_outlines(img, masks)
    overlay = plot_overlay(img, masks)

    # OpenCV sizes are (width, height).
    target_size = (img_input.shape[1], img_input.shape[0])
    if target_size != (img.shape[1], img.shape[0]):
        # Scale the masks back to the original size; nearest-neighbour keeps
        # the label values intact.
        masks = cv2.resize(masks.astype('uint16'), target_size,
                           interpolation=cv2.INTER_NEAREST).astype('uint16')

    # A fresh directory per request, so concurrent users never overwrite
    # each other's downloads (fixed names in the CWD would).
    out_dir = tempfile.mkdtemp(prefix="cellpose_")
    masks_path = os.path.join(out_dir, "masks.tiff")
    outlines_path = os.path.join(out_dir, "outlines.png")
    Image.fromarray(masks.astype('int32')).save(masks_path)
    outlines_img.save(outlines_path)

    b1 = gr.DownloadButton(visible=True, value=masks_path)
    b2 = gr.DownloadButton(visible=True, value=outlines_path)

    return outlines_img, overlay, flows, b1, b2

# Gradio Interface
#iface = gr.Interface(
#    fn=cellpose_segment, 
#    inputs="image", 
#    outputs=["image", "image", "image", "image"],
#    title="cellpose segmentation",
#    description="upload an image, then cellpose will segment it at a max size of 400x400 (for full functionality, 'pip install cellpose' locally)"
#)

def download_function():
    """Build the masks (TIFF) and outlines (PNG) download buttons, both hidden.

    Returns the two buttons in that order.
    """
    captions = ("Download masks as TIFF", "Download outline image as PNG")
    masks_button, outlines_button = (
        gr.DownloadButton(caption, visible=False) for caption in captions
    )
    return masks_button, outlines_button

# Two-column UI: input, run button and download buttons on the left; result
# images on the right.
with gr.Blocks(title = "Cellpose-SAM",
               css=".gradio-container {background:purple;}") as demo:

    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:white;">Cellpose-SAM for cellular segmentation</div>""")
            gr.HTML("""<h4 style="color:white;">You may need to refresh/login for 5 minutes of free GPU compute time/day. </h4>""")
            gr.HTML("""<h4 style="color:white;">"pip install cellpose" for full functionality. </h4>""")

            input_image = gr.Image(label = "Input image", type = "numpy")
            send_btn = gr.Button("Run Cellpose-SAM")
            with gr.Row():
                # Hidden until a segmentation has produced files to download.
                down_btn = gr.DownloadButton("Download masks (TIFF)", visible=False)
                down_btn2 = gr.DownloadButton("Download outlines (PNG)", visible=False)

            # White link text, matching the other link, for contrast on the purple background.
            gr.HTML("""<li><a style="color:white;" href="https://github.com/MouseLand/cellpose" target="_blank">github page for cellpose</a>""")
            # TODO(review): labelled as the paper but links to the GitHub repo;
            # replace with the publication URL.
            gr.HTML("""<li><a style="color:white;" href="https://github.com/MouseLand/cellpose" target="_blank">Cellpose-SAM paper</a>""")

        with gr.Column(scale=2):
            img_outlines = gr.Image(label = "Outlines", type = "pil")
            img_overlay = gr.Image(label = "Overlay", type = "numpy")
            flows = gr.Image(label = "Cellpose flows", type = "numpy")

    # cellpose_segment returns exactly these five values, in this order.
    send_btn.click(fn=cellpose_segment, inputs=[input_image],
                   outputs=[img_outlines, img_overlay, flows, down_btn, down_btn2])


# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()