andythebest commited on
Commit
35a3e38
·
verified ·
1 Parent(s): c0cc420

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +273 -0
main.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import os
4
+ import cv2
5
+ from ultralytics import YOLO
6
+ import shutil # Import shutil for copying files
7
+ import zipfile # Import zipfile for creating zip archives
8
+
9
def multi_model_detection(image_paths_list: list, model_paths_list: list, output_dir: str = 'detection_results', conf_threshold: float = 0.25):
    """Run YOLOv8 object detection on multiple images with multiple models.

    Every loaded model is applied to every image; detections are drawn on a
    copy of the image (one colour per model) and also written to a per-image
    text report.

    Args:
        image_paths_list (list): Paths of the images to process.
        model_paths_list (list): Paths of the model (.pt) files. When empty,
            the default 'yolov8n.pt' is used instead.
        output_dir (str): Directory for the annotated images and text files.
            Created automatically if it does not exist.
        conf_threshold (float): Only detections with confidence >= this value
            are drawn and reported.

    Returns:
        list: Paths to the annotated images.
        list: Paths to the text files with detection information.
    """
    # Make sure the output directory exists.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"已創建輸出目錄: {output_dir}")

    # --- Load all models ---
    loaded_models = []
    print("\n--- 載入模型 ---")
    # If no models are uploaded, fall back to the default yolov8n.pt.
    if not model_paths_list:
        default_model_path = 'yolov8n.pt'
        try:
            model = YOLO(default_model_path)
            loaded_models.append((default_model_path, model))
            print(f"成功載入預設模型: {default_model_path}")
        except Exception as e:
            print(f"錯誤: 無法載入預設模型 '{default_model_path}' - {e}")
            return [], []
    else:
        for model_path in model_paths_list:
            try:
                model = YOLO(model_path)
                loaded_models.append((model_path, model))  # keep (path, model) pairs
                print(f"成功載入模型: {model_path}")
            except Exception as e:
                print(f"錯誤: 無法載入模型 '{model_path}' - {e}")
                continue  # skip models that fail to load

    if not loaded_models:
        print("沒有模型成功載入,請檢查模型路徑或預設模型。")
        return [], []

    # One colour per model so overlapping results stay distinguishable.
    # Hoisted out of the per-image loop (was rebuilt for every image in the
    # original): the mapping depends only on loaded_models.
    colors = [
        (255, 0, 0),    # red
        (0, 255, 0),    # green
        (0, 0, 255),    # blue
        (255, 255, 0),  # yellow
        (255, 0, 255),  # magenta
        (0, 255, 255),  # cyan
        (128, 0, 0),    # dark red
        (0, 128, 0),    # dark green
    ]
    color_map = {}  # model file name -> BGR colour tuple
    for idx, (model_path_str, _) in enumerate(loaded_models):
        model_name = os.path.basename(model_path_str)
        color_map[model_name] = colors[idx % len(colors)]  # cycle when > 8 models

    annotated_image_paths = []
    txt_output_paths = []

    # --- Process each image ---
    print("\n--- 開始圖片辨識 ---")
    for image_path in image_paths_list:
        if not os.path.exists(image_path):
            print(f"警告: 圖片 '{image_path}' 不存在,跳過。")
            continue

        print(f"\n處理圖片: {os.path.basename(image_path)}")
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"錯誤: 無法讀取圖片 '{image_path}',跳過。")
            continue

        # Draw on a NumPy copy so the original pixels stay untouched.
        annotated_image = original_image.copy()

        # Lines accumulated for the per-image text report.
        txt_output_content = []
        txt_output_content.append(f"檔案: {os.path.basename(image_path)}\n")

        # Detections collected from every model on the current image.
        all_detections_for_image = []

        for model_path_str, model_obj in loaded_models:
            model_name = os.path.basename(model_path_str)
            print(f"  使用模型 '{model_name}' 進行辨識...")

            # device="cpu" keeps inference on CPU even when a GPU is present.
            results = model_obj(image_path, verbose=False, device="cpu")[0]

            txt_output_content.append(f"\n--- 模型: {model_name} ---")

            if results.boxes:  # any detections at all?
                for box in results.boxes:
                    conf = float(box.conf[0])
                    if conf >= conf_threshold:  # apply confidence threshold
                        x1, y1, x2, y2 = map(int, box.xyxy[0])
                        cls_id = int(box.cls[0])
                        cls_name = model_obj.names[cls_id]  # class-id -> label

                        detection_info = {
                            'model_name': model_name,
                            'class_name': cls_name,
                            'confidence': conf,
                            'bbox': (x1, y1, x2, y2)
                        }
                        all_detections_for_image.append(detection_info)

                        txt_output_content.append(f"  - {cls_name} (Conf: {conf:.2f}) [x1:{x1}, y1:{y1}, x2:{x2}, y2:{y2}]")
            else:
                # Fixed mojibake in the original message text.
                txt_output_content.append("  沒有偵測到任何物件。")

        # Draw every detection, coloured by originating model.
        for det in all_detections_for_image:
            x1, y1, x2, y2 = det['bbox']
            conf = det['confidence']
            cls_name = det['class_name']
            model_name = det['model_name']

            color = color_map.get(model_name, (200, 200, 200))  # grey fallback

            # Bounding box.
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)

            # Short model tag keeps the label readable, e.g. 'best.pt' -> 'b'.
            # Guard `if s` avoids IndexError on empty segments (e.g. '.pt').
            model_abbr = "".join([s[0] for s in model_name.split('.')[:-1] if s])
            label = f'{cls_name} {conf:.2f} ({model_abbr})'
            cv2.putText(annotated_image, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Save the annotated image.
        image_base_name = os.path.basename(image_path)
        image_name_without_ext = os.path.splitext(image_base_name)[0]
        output_image_path = os.path.join(output_dir, f"{image_name_without_ext}_detected.jpg")
        cv2.imwrite(output_image_path, annotated_image)
        annotated_image_paths.append(output_image_path)
        print(f"  結果圖片保存至: {output_image_path}")

        # Save the detection report.
        output_txt_path = os.path.join(output_dir, f"{image_name_without_ext}.txt")
        with open(output_txt_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(txt_output_content))
        txt_output_paths.append(output_txt_path)
        print(f"  辨識資訊保存至: {output_txt_path}")

    print("\n--- 所有圖片處理完成 ---")
    return annotated_image_paths, txt_output_paths
174
+
175
def create_zip_archive(files, zip_filename):
    """Bundle the given files into a single zip archive.

    Missing files are skipped with a warning instead of raising.

    Args:
        files: Iterable of file paths to include.
        zip_filename: Path of the zip archive to create.

    Returns:
        The path of the created archive (same as ``zip_filename``).
    """
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as archive:
        for path in files:
            if not os.path.exists(path):
                print(f"警告: 檔案 '{path}' 不存在,無法加入壓縮檔。")
                continue
            # Store each entry under its basename only (flat archive layout).
            archive.write(path, os.path.basename(path))
    return zip_filename
184
+
185
+
186
+ # --- Gradio Interface ---
187
def gradio_multi_model_detection(image_files, model_files, conf_threshold, output_subdir):
    """Gradio handler: run multi-model detection and package the results.

    Args:
        image_files (list): Uploaded image files from a gr.File component
            (temporary-file objects exposing ``.name``).
        model_files (list): Uploaded .pt model files, or None/empty to let the
            backend use its default model.
        conf_threshold (float): Confidence threshold forwarded to detection.
        output_subdir (str): Sub-directory name for this run's results.

    Returns:
        tuple: (gallery image paths, combined report text, status message,
        zip file path) — with None placeholders on failure.
    """
    if not image_files:
        return None, "請上傳圖片檔案。", None, None

    # Gradio File objects expose their temporary path via .name.
    image_paths = [uploaded.name for uploaded in image_files]
    model_paths = [uploaded.name for uploaded in model_files] if model_files else []

    # Each run writes into its own sub-directory of the main results folder.
    base_output_dir = 'gradio_detection_results'
    run_output_dir = os.path.join(base_output_dir, output_subdir)

    # Run the detection backend.
    annotated_images, detection_texts = multi_model_detection(
        image_paths_list=image_paths,
        model_paths_list=model_paths,
        output_dir=run_output_dir,
        conf_threshold=conf_threshold,
    )

    if not annotated_images:
        return None, "辨識失敗,請檢查輸入或模型。", None, None

    # Merge the per-image reports into one string for the textbox.
    report_parts = ["--- 辨識結果 ---\n\n"]
    for txt_path in detection_texts:
        with open(txt_path, 'r', encoding='utf-8') as report:
            report_parts.append(report.read() + "\n\n")
    combined_detection_text = "".join(report_parts)

    # Zip images and reports together for a single download.
    all_result_files = annotated_images + detection_texts
    zip_filename = os.path.join(run_output_dir, f"{output_subdir}_results.zip")
    created_zip_path = create_zip_archive(all_result_files, zip_filename)

    # The Gallery component accepts a list of image paths directly.
    return annotated_images, combined_detection_text, f"結果儲存於: {os.path.abspath(run_output_dir)}", created_zip_path
240
+
241
+
242
+ # Create the Gradio interface
243
# --- Gradio interface: inputs on the left, results on the right ---
with gr.Blocks() as demo:
    gr.Markdown("# 支援多模型YOLO物件辨識(demo)")
    gr.Markdown("上傳您的圖片和模型,並設定置信度閾值進行物件辨識。若未上傳模型,將使用預設的 yolov8n.pt 進行辨識。")

    with gr.Row():
        # Input column: image/model uploads, threshold slider, run button.
        with gr.Column():
            image_input = gr.File(label="上傳圖片", file_count="multiple", file_types=["image"])
            model_input = gr.File(label="上傳模型 (.pt)", file_count="multiple", file_types=[".pt"])
            conf_slider = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.05, label="置信度閾值")
            output_subdir_input = gr.Textbox(label="結果子目錄名稱", value="run_1", placeholder="請輸入儲存結果的子目錄名稱")
            run_button = gr.Button("開始辨識")

        # Output column: annotated images, text report, status, zip download.
        with gr.Column():
            # allow_preview=True enables click-to-zoom in the gallery.
            output_gallery = gr.Gallery(label="辨識結果圖片", height=400, allow_preview=True, object_fit="contain")
            output_text = gr.Textbox(label="辨識資訊", lines=10)
            output_status = gr.Textbox(label="狀態/儲存路徑")
            download_button = gr.File(label="下載所有結果 (.zip)", file_count="single")

    # Wire the button to the handler; outputs map 1:1 to its return tuple.
    run_button.click(
        fn=gradio_multi_model_detection,
        inputs=[image_input, model_input, conf_slider, output_subdir_input],
        outputs=[output_gallery, output_text, output_status, download_button]
    )

# Launch the app (debug=True surfaces handler errors in the console).
demo.launch(debug=True)