Spaces:
Runtime error
Runtime error
isLinXu
commited on
Commit
·
95a8e8d
1
Parent(s):
f5d2cd0
update files
Browse files- app.py +103 -0
- requirements.txt +20 -0
app.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
os.system("pip install super-gradients~=3.2.0")
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
from super_gradients.training import models
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
warnings.filterwarnings("ignore")
|
| 13 |
+
class YOLOX_WebUI:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.download_test_img()
|
| 16 |
+
|
| 17 |
+
def download_test_img(self):
|
| 18 |
+
# Images
|
| 19 |
+
torch.hub.download_url_to_file(
|
| 20 |
+
'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
|
| 21 |
+
'bus.jpg')
|
| 22 |
+
torch.hub.download_url_to_file(
|
| 23 |
+
'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
|
| 24 |
+
'dogs.jpg')
|
| 25 |
+
torch.hub.download_url_to_file(
|
| 26 |
+
'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
|
| 27 |
+
'zidane.jpg')
|
| 28 |
+
|
| 29 |
+
def predict(self, image_path,conf, iou, line_width, device, model_type, model_path):
|
| 30 |
+
self.device = device
|
| 31 |
+
if model_type == "yolox_n":
|
| 32 |
+
self.model = models.get("yolox_n", pretrained_weights="coco").to(self.device)
|
| 33 |
+
elif model_type == "yolox_s":
|
| 34 |
+
self.model = models.get("yolox_s", pretrained_weights="coco").to(self.device)
|
| 35 |
+
elif model_type == "yolox_m":
|
| 36 |
+
self.model = models.get("yolox_m", pretrained_weights="coco").to(self.device)
|
| 37 |
+
elif model_type == "yolox_t":
|
| 38 |
+
self.model = models.get("yolox_t", pretrained_weights="coco").to(self.device)
|
| 39 |
+
elif model_type == "yolox_l":
|
| 40 |
+
self.model = models.get("yolox_l", pretrained_weights="coco").to(self.device)
|
| 41 |
+
else:
|
| 42 |
+
self.model = models.get(model_path, pretrained_weights="coco").to(self.device)
|
| 43 |
+
if model_type not in ["yolox_n", "yolox_s", "yolox_m", "yolox_l","yolox_t"]:
|
| 44 |
+
self.model = models.get(model_path, pretrained_weights="coco").to(self.device)
|
| 45 |
+
|
| 46 |
+
results = self.model.predict(image_path)
|
| 47 |
+
|
| 48 |
+
# get image data and bbox information
|
| 49 |
+
image = results._images_prediction_lst[0].image
|
| 50 |
+
class_names = results._images_prediction_lst[0].class_names
|
| 51 |
+
prediction = results._images_prediction_lst[0].prediction
|
| 52 |
+
bboxes_xyxy = prediction.bboxes_xyxy
|
| 53 |
+
labels = prediction.labels
|
| 54 |
+
confidences = prediction.confidence
|
| 55 |
+
# draw rectangles and label names
|
| 56 |
+
for bbox, label, confidence in zip(bboxes_xyxy, labels, confidences):
|
| 57 |
+
color = tuple(np.random.randint(0, 255, 3).tolist())
|
| 58 |
+
x1, y1, x2, y2 = bbox.astype(int)
|
| 59 |
+
if confidence > conf:
|
| 60 |
+
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
| 61 |
+
cla_name = class_names[int(label)]
|
| 62 |
+
label_name = f"{cla_name}: {confidence:.2f}"
|
| 63 |
+
cv2.putText(image, label_name, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 64 |
+
|
| 65 |
+
return image
|
| 66 |
+
|
| 67 |
+
if __name__ == '__main__':
|
| 68 |
+
# Instantiate YOLO_NAS_WebUI class
|
| 69 |
+
detector = YOLOX_WebUI()
|
| 70 |
+
examples = [
|
| 71 |
+
['bus.jpg', 0.25, 0.45, 2, "cpu", "yolox_n", "yolox_n.pt"],
|
| 72 |
+
['dogs.jpg', 0.25, 0.45, 2, "cpu", "yolox_n", "yolox_n.pt"],
|
| 73 |
+
['zidane.jpg', 0.25, 0.45, 2, "cpu", "yolox_n", "yolox_n.pt"]
|
| 74 |
+
]
|
| 75 |
+
# Define Gradio interface
|
| 76 |
+
iface = gr.Interface(
|
| 77 |
+
fn=detector.predict,
|
| 78 |
+
inputs=["image",
|
| 79 |
+
gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.25,
|
| 80 |
+
label="Confidence Threshold"),
|
| 81 |
+
gr.inputs.Slider(minimum=0, maximum=1, step=0.01, default=0.45,
|
| 82 |
+
label="IoU Threshold"),
|
| 83 |
+
gr.inputs.Number(default=2, label="Line Width"),
|
| 84 |
+
gr.inputs.Radio(["cpu", "cuda"], label="Device", default="cpu"),
|
| 85 |
+
gr.inputs.Radio(["yolox_n", "yolox_s", "yolox_m", "yolox_l","yolox_t"],
|
| 86 |
+
label="Model Type", default="yolo_nas_s"),
|
| 87 |
+
gr.inputs.Textbox(default="yolox_n.pt", label="Model Path")],
|
| 88 |
+
outputs="image",
|
| 89 |
+
title="YOLOX_WebUI Object Detector",
|
| 90 |
+
description="Detect objects in an image using YOLOX model.",
|
| 91 |
+
theme="default", examples=examples,
|
| 92 |
+
layout="vertical",
|
| 93 |
+
allow_flagging=False,
|
| 94 |
+
analytics_enabled=True,
|
| 95 |
+
server_port=None,
|
| 96 |
+
server_name=None,
|
| 97 |
+
server_protocol=None,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Run Gradio interface
|
| 101 |
+
iface.launch(share=True)
|
| 102 |
+
|
| 103 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wget~=3.2
|
| 2 |
+
opencv-python~=4.6.0.66
|
| 3 |
+
numpy~=1.23.0
|
| 4 |
+
torch~=1.13.1
|
| 5 |
+
torchvision~=0.14.1
|
| 6 |
+
pillow~=9.4.0
|
| 7 |
+
gradio~=3.42.0
|
| 8 |
+
ultralytics~=8.0.169
|
| 9 |
+
pyyaml~=6.0
|
| 10 |
+
wandb~=0.13.11
|
| 11 |
+
tqdm~=4.65.0
|
| 12 |
+
matplotlib~=3.7.1
|
| 13 |
+
pandas~=2.0.0
|
| 14 |
+
seaborn~=0.12.2
|
| 15 |
+
requests~=2.31.0
|
| 16 |
+
psutil~=5.9.4
|
| 17 |
+
thop~=0.1.1-2209072238
|
| 18 |
+
timm~=0.9.2
|
| 19 |
+
super-gradients~=3.2.0
|
| 20 |
+
openmim
|