Spaces: Runtime error
import functools
import random
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from imutils import perspective
from PIL import Image
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor
def midpoint(ptA, ptB):
    """Return the point halfway between two 2-D points.

    Args:
        ptA: First point; only indices 0 (x) and 1 (y) are read.
        ptB: Second point; only indices 0 (x) and 1 (y) are read.

    Returns:
        ``(x, y)`` tuple of the averaged coordinates.
    """
    ax, ay = ptA[0], ptA[1]
    bx, by = ptB[0], ptB[1]
    mid_x = (ax + bx) * 0.5
    mid_y = (ay + by) * 0.5
    return (mid_x, mid_y)
# 5x5 all-ones structuring element used for the morphological opening in
# cca_analysis.
kernel1 = np.ones((5, 5), dtype=np.float32)

# NOTE(review): blur_radius and kernel_sharpening are not referenced by the
# visible code; kept so any external users keep working.
blur_radius = 0.5

# 3x3 sharpening kernel (centre 9, neighbours -1), scaled by 1/9.
kernel_sharpening = (1 / 9) * np.array(
    [
        [-1, -1, -1],
        [-1, 9, -1],
        [-1, -1, -1],
    ]
)
def _draw_component_box(canvas, contour):
    """Draw *contour*'s minimum-area rectangle and its midline cross on *canvas* in place."""
    box = cv2.boxPoints(cv2.minAreaRect(contour))
    box = perspective.order_points(np.array(box, dtype="int"))
    # One random colour per component, each channel drawn from range(150).
    color = [int(channel) for channel in np.random.choice(range(150), size=3)]
    cv2.drawContours(canvas, [box.astype("int")], 0, color, 2)

    tl, tr, br, bl = box
    top_x, top_y = midpoint(tl, tr)
    bottom_x, bottom_y = midpoint(bl, br)
    left_x, left_y = midpoint(tl, bl)
    right_x, right_y = midpoint(tr, br)

    # Filled radius-5 dots on the midpoint of each rectangle side.
    for px, py in ((top_x, top_y), (bottom_x, bottom_y), (left_x, left_y), (right_x, right_y)):
        cv2.circle(canvas, (int(px), int(py)), 5, (255, 0, 0), -1)

    # Lines joining opposite side midpoints (the rectangle's two axes).
    cv2.line(canvas, (int(top_x), int(top_y)), (int(bottom_x), int(bottom_y)), color, 2)
    cv2.line(canvas, (int(left_x), int(left_y)), (int(right_x), int(right_y)), color, 2)


def cca_analysis(image, predicted_mask, min_area=100):
    """Outline each connected component of a mask on a copy of an image.

    The mask is resized to the image size, cleaned with a morphological
    opening (``kernel1``), converted to grayscale when it has 3 channels,
    and binarised with Otsu's threshold. Every 8-connected foreground
    component whose external contour area (``cv2.contourArea``) exceeds
    ``min_area`` is drawn with its minimum-area rectangle, the midpoints of
    the rectangle's sides, and the lines joining opposite midpoints.

    Args:
        image: Image to draw on (anything ``np.array`` accepts, e.g. the
            array returned by ``cv2.imread`` or a PIL image). Not modified.
        predicted_mask: ``uint8`` mask, single-channel or 3-channel BGR
            (e.g. the output of ``to_rgb``).
        min_area: Components with contour area ``<= min_area`` are skipped.
            Defaults to 100, the previously hard-coded value.

    Returns:
        A copy of ``image`` with the annotations drawn on it.
    """
    # Copy so the caller's array is untouched and read-only inputs (such as
    # np.asarray of a PIL image) can still be drawn on by OpenCV.
    canvas = np.ascontiguousarray(np.array(image))
    height, width = canvas.shape[:2]

    # cv2.resize takes dsize as (width, height); the old code passed the
    # width twice and distorted non-square images.
    mask = cv2.resize(predicted_mask, (width, height), interpolation=cv2.INTER_AREA)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel1, iterations=1)
    gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if mask.ndim == 3 else mask
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    num_labels, labels = cv2.connectedComponents(thresh, connectivity=8)

    for label in range(1, num_labels):  # label 0 is the background
        component = (labels == label).astype("uint8") * 255
        # [-2] is the contour list under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours(component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            continue
        contour = contours[0]
        if cv2.contourArea(contour) <= min_area:
            continue
        _draw_component_box(canvas, contour)
    return canvas
def to_rgb(img):
    """Replicate a 2-D mask into a 3-channel ``uint8`` image scaled by 255.

    Args:
        img: 2-D array of mask values. A value of 1 becomes 255 in every
            channel. Values are presumably in ``[0, 1]`` (binary mask); the
            scaled result is clipped to ``[0, 255]`` instead of overflowing.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` where ``(H, W) == img.shape``.
    """
    # The previous version allocated (W, H, 3), which raised a broadcasting
    # error for every non-square mask.
    channel = np.asarray(img, dtype=np.float64)
    stacked = np.stack([channel, channel, channel], axis=-1)
    return np.uint8(np.clip(stacked * 255, 0, 255))
# Sample inputs shipped with the Space; offered as Gradio examples.
image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]

# Hub ids of the selectable Segformer checkpoints (the first is the default).
model_path = [
    "deprem-ml/deprem_satellite_semantic_whu",
]
def visualize_instance_seg_mask(mask):
    """Colour a 2-D label mask with one random colour per distinct label.

    Three ``random.randint(0, 255)`` draws are made per label, in ascending
    label order, so output is reproducible under ``random.seed``.

    Args:
        mask: 2-D array of labels.

    Returns:
        ``float64`` array of shape ``(H, W, 3)`` with values in ``[0, 1]``.

    Raises:
        ValueError: If ``mask`` is not 2-D.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D mask, got shape {mask.shape}")

    labels, inverse = np.unique(mask, return_inverse=True)
    # Same draw order as building an (r, g, b) tuple per label, so seeded
    # results match the former per-pixel implementation.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)
    # Vectorised lookup replaces the O(H*W) Python loop. reshape covers both
    # NumPy 1.x (flat inverse) and 2.x (input-shaped inverse).
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
@functools.lru_cache(maxsize=4)
def _load_segformer(model_id):
    """Load (once per model id) the Segformer model and its image processor."""
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    try:
        processor = SegformerImageProcessor.from_pretrained(model_id)
    except OSError:
        # No preprocessor config on the Hub: use library defaults, which is
        # effectively what the old positional `SegformerImageProcessor(model_id)`
        # call produced (it passed the id as `do_resize`).
        processor = SegformerImageProcessor()
    return model, processor


def Segformer_Segmentation(image_path, model_id, postpro="Connected Components Labelling"):
    """Segment an image file with Segformer and write a visualisation PNG.

    Args:
        image_path: Path of the input image (read with ``cv2.imread``).
        model_id: Hugging Face Hub id of a Segformer segmentation model.
        postpro: ``"Connected Components Labelling"`` draws per-component
            boxes on the input image via ``cca_analysis`` (assumes a binary
            0/1 label map — TODO confirm for multi-class models); any other
            value renders the label map with random colours via
            ``visualize_instance_seg_mask``.

    Returns:
        ``(image_path, output_path)``: the unchanged input path and the path
        of a freshly created PNG holding the result.

    Raises:
        ValueError: If ``image_path`` cannot be read as an image.
        OSError: If the output PNG cannot be written.
    """
    test_image = cv2.imread(image_path)
    if test_image is None:
        raise ValueError(f"Could not read image file: {image_path!r}")

    model, processor = _load_segformer(model_id)

    # cv2.imread yields BGR; the image processor normalises in RGB order.
    rgb_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
    inputs = processor(images=rgb_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Upsample logits to the input resolution so the mask aligns with it.
    height, width = test_image.shape[:2]
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[(height, width)]
    )[0]
    result = result.cpu().numpy()

    if postpro == "Connected Components Labelling":
        result = cca_analysis(test_image, to_rgb(result))
    else:
        result = visualize_instance_seg_mask(result)
        # imwrite expects 8-bit data; round rather than truncate.
        result = np.clip(np.rint(result * 255), 0, 255).astype(np.uint8)

    # A unique file per call, so concurrent requests don't overwrite each other.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        output_save = tmp.name
    if not cv2.imwrite(output_save, result):
        raise OSError(f"Could not write output image: {output_save!r}")
    return image_path, output_save
# Each example row must supply one value per input component:
# (image path, model id, post-processing mode). The old rows omitted the
# third value, so cached examples called Segformer_Segmentation without it.
examples = [
    [example_image, model_path[0], "Connected Components Labelling"]
    for example_image in image_list
]
title = "Deprem ML - Segformer Semantic Segmentation"

app = gr.Blocks()
with app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
    with gr.Row():
        with gr.Column():
            # Input image, delivered to the callback as a file path
            # (named input_video historically; it holds a still image).
            input_video = gr.Image(type='filepath')
            model_id = gr.Dropdown(value=model_path[0], choices=model_path, label="Model Name")
            cca = gr.Dropdown(
                value="Connected Components Labelling",
                choices=["Connected Components Labelling", "No Post Process"],
                label="Post Process",
            )
            input_video_button = gr.Button(value="Predict")
        with gr.Column():
            output_orijinal_image = gr.Image(type='filepath')
        with gr.Column():
            output_mask_image = gr.Image(type='filepath')
    gr.Examples(
        examples,
        inputs=[input_video, model_id, cca],
        outputs=[output_orijinal_image, output_mask_image],
        fn=Segformer_Segmentation,
        cache_examples=True,
    )
    input_video_button.click(
        Segformer_Segmentation,
        inputs=[input_video, model_id, cca],
        outputs=[output_orijinal_image, output_mask_image],
    )
app.launch(debug=True)