import numpy as np import gradio as gr import spaces import cv2 from cellpose import models from matplotlib.colors import hsv_to_rgb import matplotlib.pyplot as plt import os, io, base64 from PIL import Image # @title Data retrieval def download_weights(): import os, requests fname = ['cpsam'] url = ["https://osf.io/d7c8e/download"] for j in range(len(url)): if not os.path.isfile(fname[j]): try: r = requests.get(url[j]) except requests.ConnectionError: print("!!! Failed to download data !!!") else: if r.status_code != requests.codes.ok: print("!!! Failed to download data !!!") else: with open(fname[j], "wb") as fid: fid.write(r.content) try: #download_weights() model = models.CellposeModel(gpu=True, pretrained_model="cyto3") except Exception as e: print(f"Error loading model: {e}") exit(1) def plot_flows(y): Y = (np.clip(normalize99(y[0][0]),0,1) - 0.5) * 2 X = (np.clip(normalize99(y[1][0]),0,1) - 0.5) * 2 H = (np.arctan2(Y, X) + np.pi) / (2*np.pi) S = normalize99(y[0][0]**2 + y[1][0]**2) HSV = np.concatenate((H[:,:,np.newaxis], S[:,:,np.newaxis], S[:,:,np.newaxis]), axis=-1) HSV = np.clip(HSV, 0.0, 1.0) flow = (hsv_to_rgb(HSV) * 255).astype(np.uint8) return flow def plot_outlines(img, masks): outpix = [] contours, hierarchy = cv2.findContours(masks.astype(np.int32), mode=cv2.RETR_FLOODFILL, method=cv2.CHAIN_APPROX_SIMPLE) for c in range(len(contours)): pix = contours[c].astype(int).squeeze() if len(pix)>4: peri = cv2.arcLength(contours[c], True) approx = cv2.approxPolyDP(contours[c], 0.001, True)[:,0,:] outpix.append(approx) figsize = (6,6) if img.shape[0]>img.shape[1]: figsize = (6*img.shape[1]/img.shape[0], 6) else: figsize = (6, 6*img.shape[0]/img.shape[1]) fig = plt.figure(figsize=figsize, facecolor='k') ax = fig.add_axes([0.0,0.0,1,1]) ax.set_xlim([0,img.shape[1]]) ax.set_ylim([0,img.shape[0]]) ax.imshow(img[::-1], origin='upper', aspect = 'auto') if outpix is not None: for o in outpix: ax.plot(o[:,0], img.shape[0]-o[:,1], color=[1,0,0], lw=1) ax.axis('off') #bytes_image = io.BytesIO() #plt.savefig(bytes_image, format='png', facecolor=fig.get_facecolor(), edgecolor='none') #bytes_image.seek(0) #img_arr = np.frombuffer(bytes_image.getvalue(), dtype=np.uint8) #bytes_image.close() #img = cv2.imdecode(img_arr, 1) #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #del bytes_image #fig.clf() #plt.close(fig) buf = io.BytesIO() fig.savefig(buf, bbox_inches='tight') buf.seek(0) output_pil_img = Image.open(buf) return output_pil_img def plot_overlay(img, masks): img = normalize99(img.astype(np.float32).mean(axis=-1)) img -= img.min() img /= img.max() HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32) HSV[:,:,2] = np.clip(img*1.5, 0, 1.0) for n in range(int(masks.max())): ipix = (masks==n+1).nonzero() HSV[ipix[0],ipix[1],0] = np.random.rand() HSV[ipix[0],ipix[1],1] = 1.0 RGB = (hsv_to_rgb(HSV) * 255).astype(np.uint8) return RGB def normalize99(img): X = img.copy() X = (X - np.percentile(X, 1)) / (np.percentile(X, 99) - np.percentile(X, 1)) return X def image_resize(img, resize=400): ny,nx = img.shape[:2] if np.array(img.shape).max() > resize: if ny>nx: nx = int(nx/ny * resize) ny = resize else: ny = int(ny/nx * resize) nx = resize shape = (nx,ny) img = cv2.resize(img, shape) img = img.astype(np.uint8) return img @spaces.GPU(duration=10) def run_model_gpu(img): masks, flows, _ = model.eval(img, channels = [0,0]) return masks, flows #@spaces.GPU(duration=10) def cellpose_segment(img_input): img = image_resize(img_input) masks, flows = run_model_gpu(img) #masks, flows, _ = model.eval(img, channels=[0,0]) flows = flows[0] # masks = np.zeros(img.shape[:2]) # flows = np.zeros_like(img) outpix = plot_outlines(img, masks) overlay = plot_overlay(img, masks) target_size = (img_input.shape[1], img_input.shape[0]) if (target_size[0]!=img.shape[1] or target_size[1]!=img.shape[0]): # scale it back to keep the orignal size masks = cv2.resize(masks.astype('uint16'), target_size, interpolation=cv2.INTER_NEAREST).astype('uint16') #flows = cv2.resize(flows.astype('float32'), target_size).astype('uint8') #crand = .2 + .8 * np.random.rand(np.max(masks.flatten()).astype('int')+1,).astype('float32') #crand[0] = 0 overlay = Image.fromarray(overlay) flows = Image.fromarray(flows) #masks = Image.fromarray(255. * crand[masks]) pil_masks = Image.fromarray(masks.astype('int32')) pil_masks.save("masks.tiff") outpix.save("outlines.png") b1 = gr.DownloadButton(visible=True, value = "masks.tiff") b2 = gr.DownloadButton(visible=True, value = "outlines.png") return outpix, overlay, flows, b1, b2 # Gradio Interface #iface = gr.Interface( # fn=cellpose_segment, # inputs="image", # outputs=["image", "image", "image", "image"], # title="cellpose segmentation", # description="upload an image, then cellpose will segment it at a max size of 400x400 (for full functionality, 'pip install cellpose' locally)" #) def download_function(): b1 = gr.DownloadButton("Download masks as TIFF", visible=False) b2 = gr.DownloadButton("Download outline image as PNG", visible=False) return b1, b2 with gr.Blocks(title = "Hello", css=".gradio-container {background:purple;}") as demo: with gr.Row(): with gr.Column(scale=2): gr.HTML("""