Create app.py
app.py ADDED
@@ -0,0 +1,132 @@
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIPSeg model and processor (text-prompted coarse segmentation)
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Load SAM model and processor (promptable fine segmentation)
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Cache of [pixel_values, image_embedding] for the most recently processed image
cache_data = None

# Prompts to segment the damaged area and the car
prompts = ['damaged', 'car']
damage_threshold = 0.4
vehicle_threshold = 0.5

def bbox_normalization(bbox, width, height):
    # CLIPSeg predicts on a 352x352 grid; rescale the box to the original image size
    height_coeff = height / 352
    width_coeff = width / 352
    normalized_bbox = [int(bbox[0] * width_coeff), int(bbox[1] * height_coeff),
                       int(bbox[2] * width_coeff), int(bbox[3] * height_coeff)]
    return normalized_bbox

def bbox_area(bbox):
    area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
    return area

def segment_to_bbox(segment_indices):
    # Collect the coordinates of every pixel assigned to the target class (value 1)
    # and return their bounding box as [x_min, y_min, x_max, y_max]
    x_points = []
    y_points = []
    for y, list_val in enumerate(segment_indices):
        for x, val in enumerate(list_val):
            if val == 1:
                x_points.append(x)
                y_points.append(y)
    if not x_points:
        # Nothing crossed the threshold
        return [0, 0, 0, 0]
    return [np.min(x_points), np.min(y_points), np.max(x_points), np.max(y_points)]

def clipseg_prediction(image):
    height, width = image.shape[0], image.shape[1]
    inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt").to(device)
    # Predict one heatmap per prompt
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    # Threshold the per-prompt heatmaps to classify whether the image contains a vehicle / damage
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()

    # Initialize a dummy "unlabeled" row filled with the threshold value
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle

    # Get the top class index for each pixel (1 = above threshold)
    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))

    # Bounding-box creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)

    # Vehicle check: only report damage when a vehicle region larger than the damage region is found
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox) > 0:
        return True, bbox_normalization(damage_bbox, width, height)
    else:
        return False, []

@torch.no_grad()
def forward_pass(image_input: np.ndarray, bbox: List[int]) -> np.ndarray:
    global cache_data
    image_input = Image.fromarray(image_input)
    # The CLIPSeg stage yields a bounding box, so prompt SAM with a box
    inputs = processor(image_input, input_boxes=[[bbox]], return_tensors="pt").to(device)
    # Re-use the cached image embedding when the same image is segmented again
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        pixels = inputs["pixel_values"]
        cache_data = [pixels, embedding]
    del inputs["pixel_values"]

    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)

    return masks

def main_func(image_input):
    # Stage 1: CLIPSeg locates the vehicle and a candidate damage region
    classification, bbox = clipseg_prediction(image_input)
    if classification:
        # Stage 2: SAM refines the damage region into pixel masks
        masks = forward_pass(image_input, bbox)

        # Overlay the first predicted mask on the input image
        final_mask = masks[:, :, 0]
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
    else:
        # No vehicle/damage detected: return the input unchanged
        return Image.fromarray(image_input)

def reset_data():
    global cache_data
    cache_data = None

with gr.Blocks() as demo:
    gr.Markdown("# Vehicle Damage Detection Demo")
    gr.Markdown("""This app uses the CLIPSeg and SAM models to segment the damaged area of a vehicle in an image.""")
    with gr.Row():
        image_input = gr.Image()
        image_output = gr.Image()

    image_button = gr.Button("Segment Image", variant='primary')

    image_button.click(main_func, inputs=image_input, outputs=image_output)
    image_input.upload(reset_data)

demo.launch()
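For the Space to build, a requirements.txt alongside app.py will likely also be needed; it is not part of this commit, so the following is only a minimal sketch derived from the imports above (standard PyPI package names, no versions pinned because none are given here):

# requirements.txt (sketch; pin versions as appropriate)
gradio
torch
transformers
numpy
Pillow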