Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| import json | |
| from torchvision import transforms | |
| from torchvision.ops import nms | |
| from PIL import Image, ImageDraw, ImageFont | |
# Locations of the serialized detector and of its class-index -> name table.
TORCHSCRIPT_PATH = "res/screenrecognition-web350k-vins.torchscript"
LABELS_PATH = "res/class_map_vins_manual.json"

# The TorchScript model is loaded once at import time and reused by every request.
model = torch.jit.load(TORCHSCRIPT_PATH)

# JSON is UTF-8 by specification; be explicit so the platform's default
# locale encoding cannot mis-decode the label names.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    # Keys are looked up as strings (see predict: str(int(label))).
    idx2Label = json.load(f)["idx2Label"]

# Converts the incoming PIL image into the tensor form fed to the model.
img_transforms = transforms.ToTensor()
def inter_class_nms(boxes, scores, labels, iou_threshold=0.5):
    """Run non-maximum suppression over all detections, ignoring class labels.

    The labels are not consulted when deciding which boxes overlap, so a box
    can be suppressed by a higher-scoring box of a different class.

    Args:
        boxes: Box tensor passed straight to ``torchvision.ops.nms``.
        scores: Per-box scores, indexed in parallel with ``boxes``.
        labels: Per-box class indices, indexed in parallel with ``boxes``.
        iou_threshold: IoU threshold handed to ``nms``.

    Returns:
        Dict with keys ``'boxes'``, ``'scores'`` and ``'labels'`` holding only
        the entries selected by ``nms``.
    """
    # Indices of the detections that survive suppression.
    kept = nms(boxes, scores, iou_threshold)
    return {
        'boxes': boxes[kept],
        'scores': scores[kept],
        'labels': labels[kept],
    }
def predict(img, conf_thresh=0.4):
    """Detect UI elements in a screenshot and draw them on a copy of it.

    Args:
        img: PIL image of the screenshot. ``None`` is tolerated (e.g. the
            form was submitted without an image).
        conf_thresh: Only detections whose score is strictly greater than
            this value are drawn.

    Returns:
        A copy of ``img`` annotated with a red rectangle and a
        ``"<label> <score>"`` tag per drawn detection, or ``None`` when
        ``img`` is ``None``.
    """
    # Nothing to annotate; avoid crashing inside the transform.
    if img is None:
        return None

    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass (same outputs, less memory and time).
    with torch.no_grad():
        # The model returns a pair; only the second element (per-image
        # detection dicts) is used here.
        _, pred = model([img_transforms(img)])
        detections = inter_class_nms(
            pred[0]['boxes'], pred[0]['scores'], pred[0]['labels']
        )

    out_img = img.copy()
    draw = ImageDraw.Draw(out_img)
    font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)

    for box, conf_score, label in zip(
        detections['boxes'], detections['scores'], detections['labels']
    ):
        # Compare on the tensor itself so thresholding matches the model's
        # float precision exactly.
        if not conf_score > conf_thresh:
            continue

        # Pixel coordinates are truncated to ints for drawing.
        x1, y1, x2, y2 = (int(v) for v in box.tolist())
        draw.rectangle([x1, y1, x2, y2], outline='red', width=3)

        text = f"{idx2Label[str(int(label))]} {float(conf_score):.2f}"
        # Filled background behind the caption keeps it readable on any image.
        draw.rectangle(draw.textbbox((x1, y1), text, font=font), fill="red")
        draw.text((x1, y1), text, font=font, fill="black")

    return out_img
# Example rows offered in the UI: [screenshot path, confidence threshold].
example_imgs = [
    [example_path, 0.4]
    for example_path in (
        "res/example.jpg",
        "res/screenlane-snapchat-profile.jpg",
        "res/screenlane-snapchat-settings.jpg",
        "res/example_pair1.jpg",
        "res/example_pair2.jpg",
    )
]

# Web UI: a screenshot plus a threshold slider in, the annotated screenshot out.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.4),
    ],
    # NOTE(review): `.style()` is deprecated in newer Gradio releases in favor
    # of constructor arguments - confirm against the pinned Gradio version
    # before changing it.
    outputs=gr.Image(type="pil", label="Annotated Screenshot").style(height=600),
    examples=example_imgs,
)

# Starts the Gradio server as soon as this module is executed.
interface.launch()