增加对置信度和类别的选择
Browse files
app.py
CHANGED
|
@@ -17,20 +17,16 @@ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
|
|
| 17 |
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
|
| 18 |
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
|
| 19 |
|
| 20 |
-
# Text style settings (style / size / color / weight).
# NOTE(review): presumably a font dict for drawing labels; nothing in the
# visible code reads it, and this commit removes it -- confirm before reuse.
fdic = {
    "style": "italic",
    "size": 24,
    "color": "yellow",
    "weight": "bold",
}
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Per-label color lookup: a label that is not yet a key here gets a color
# picked from COLORS (presumably stored back here so it stays stable -- the
# storing line is outside the visible part of the file).
label_color_dict = {}
|
| 30 |
|
| 31 |
def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on an image and return its output.

    Args:
        in_pil_img: The PIL image to analyse.

    Returns:
        Whatever ``detector`` returns for the image, unmodified.
    """
    results = detector(in_pil_img)

    return results
|
| 35 |
|
| 36 |
|
|
@@ -61,6 +57,8 @@ def get_annotated_image(in_pil_img):
|
|
| 61 |
score = round(prediction['score'] * 100, 1)
|
| 62 |
if score < threshold:
|
| 63 |
continue # 过滤掉低置信度的预测结果
|
|
|
|
|
|
|
| 64 |
|
| 65 |
if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
|
| 66 |
color = choice(COLORS)
|
|
@@ -105,7 +103,7 @@ def process_video(input_video_path):
|
|
| 105 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 106 |
output_video_filename = f"output_{timestamp}.mp4"
|
| 107 |
output_video_path = os.path.join(output_dir, output_video_filename)
|
| 108 |
-
|
| 109 |
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
| 110 |
|
| 111 |
while True:
|
|
@@ -115,15 +113,15 @@ def process_video(input_video_path):
|
|
| 115 |
|
| 116 |
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 117 |
pil_image = Image.fromarray(rgb_frame)
|
| 118 |
-
|
| 119 |
annotated_frame = get_annotated_image(pil_image)
|
| 120 |
bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
| 121 |
-
|
| 122 |
# 确保帧的尺寸与视频输出一致
|
| 123 |
if bgr_frame.shape[:2] != (height, width):
|
| 124 |
bgr_frame = cv2.resize(bgr_frame, (width, height))
|
| 125 |
|
| 126 |
-
|
| 127 |
out.write(bgr_frame)
|
| 128 |
|
| 129 |
cap.release()
|
|
@@ -132,9 +130,31 @@ def process_video(input_video_path):
|
|
| 132 |
# 返回输出视频路径给 Gradio
|
| 133 |
return output_video_path
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
|
| 136 |
gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
with gr.Row():
|
| 139 |
input_video = gr.Video(label="输入视频")
|
| 140 |
detect_button = gr.Button("开始检测", variant="primary")
|
|
|
|
| 17 |
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
|
| 18 |
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
|
| 22 |
+
# Confidence threshold on a 0-100 percent scale; predictions whose rounded
# percentage score is below it are skipped.
threshold = 90
# Labels that are allowed to be displayed; any other label is skipped.
label_list = ["person", "car", "truck"]

# Per-label color lookup: a label that is not yet a key here gets a color
# picked from COLORS.
label_color_dict = {}
|
| 26 |
|
| 27 |
def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on an image and return its output.

    The raw detection results are also printed to stdout for debugging.

    Args:
        in_pil_img: The PIL image to analyse.

    Returns:
        Whatever ``detector`` returns for the image, unmodified.
    """
    results = detector(in_pil_img)
    print(f"检测结果:{results}")
    return results
|
| 31 |
|
| 32 |
|
|
|
|
| 57 |
score = round(prediction['score'] * 100, 1)
|
| 58 |
if score < threshold:
|
| 59 |
continue # 过滤掉低置信度的预测结果
|
| 60 |
+
if label not in label_list:
|
| 61 |
+
continue # 过滤掉不在允许显示的label列表中的预测结果
|
| 62 |
|
| 63 |
if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
|
| 64 |
color = choice(COLORS)
|
|
|
|
| 103 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 104 |
output_video_filename = f"output_{timestamp}.mp4"
|
| 105 |
output_video_path = os.path.join(output_dir, output_video_filename)
|
| 106 |
+
print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
|
| 107 |
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
| 108 |
|
| 109 |
while True:
|
|
|
|
| 113 |
|
| 114 |
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 115 |
pil_image = Image.fromarray(rgb_frame)
|
| 116 |
+
print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
|
| 117 |
annotated_frame = get_annotated_image(pil_image)
|
| 118 |
bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
| 119 |
+
print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
|
| 120 |
# 确保帧的尺寸与视频输出一致
|
| 121 |
if bgr_frame.shape[:2] != (height, width):
|
| 122 |
bgr_frame = cv2.resize(bgr_frame, (width, height))
|
| 123 |
|
| 124 |
+
print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
|
| 125 |
out.write(bgr_frame)
|
| 126 |
|
| 127 |
cap.release()
|
|
|
|
| 130 |
# 返回输出视频路径给 Gradio
|
| 131 |
return output_video_path
|
| 132 |
|
| 133 |
+
def change_threshold(value):
    """Set the module-level confidence threshold and report the new value.

    Args:
        value: New threshold, on the same 0-100 percent scale that the
            rounded prediction scores are compared against.

    Returns:
        A status message containing the threshold that is now in effect.
    """
    # Rebind the module-level name so later score comparisons see the change.
    global threshold
    threshold = value
    return f"当前置信度阈值为{threshold}%"
|
| 137 |
+
|
| 138 |
+
def update_labels(selected_labels):
    """Replace the module-level list of labels allowed to be displayed.

    Args:
        selected_labels: The labels currently selected by the user.

    Returns:
        ``selected_labels`` itself, unchanged.
    """
    # Update label_list to match the user's selection.
    global label_list
    label_list = selected_labels
    return selected_labels
|
| 143 |
+
|
| 144 |
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
|
| 145 |
gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
|
| 146 |
|
| 147 |
+
# 设置置信度阈值
|
| 148 |
+
threshold_slider = gr.Slider(minimum=0, maximum=100, value=threshold, step=1, label="置信度阈值")
|
| 149 |
+
textbox = gr.Textbox(value=f"当前置信度阈值为{threshold}%", label="置信度显示")
|
| 150 |
+
# 绑定滑块变化事件到change_threshold函数,同时设置输出为textbox
|
| 151 |
+
threshold_slider.change(fn=change_threshold, inputs=[threshold_slider], outputs=[textbox])
|
| 152 |
+
|
| 153 |
+
# 设置允许显示的label列表
|
| 154 |
+
label_checkboxes = gr.CheckboxGroup(choices=label_list, value=label_list, label="检测目标")
|
| 155 |
+
# 允许修改label_list
|
| 156 |
+
label_checkboxes.change(fn=update_labels, inputs=[label_checkboxes], outputs=[label_checkboxes])
|
| 157 |
+
|
| 158 |
with gr.Row():
|
| 159 |
input_video = gr.Video(label="输入视频")
|
| 160 |
detect_button = gr.Button("开始检测", variant="primary")
|