JrEasy committed
Commit c4a1c35 · verified · 1 Parent(s): e4cf8a9

Upload app.py

Files changed (1)
  app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
+ # -*- coding: utf-8 -*-
+ """Judol Gradio YOLO11.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1oiuTAi-cys1ydtUhSDJSRdeA02mAmZQH
+ """
+
+ import os
+
+ import cv2
+ import gradio as gr
+ import imageio
+ from ultralytics import YOLO
+
+ # Load the trained detection weights from the Hugging Face Hub.
+ model = YOLO('https://huggingface.co/JrEasy/Judol-Detection-YOLO11/resolve/main/best.pt')
+
+ # Minimum confidence required before a detection is drawn and logged.
+ confidence_threshold = 0.6
+
+ # Model class IDs mapped to human-readable names.
+ class_names = {
+     0: "BK8",
+     1: "Gate of Olympus",
+     2: "Princess",
+     3: "Starlight Princess",
+     4: "Zeus",
+ }
+
+ # Per-class box colors in BGR order (frames are handled by OpenCV).
+ class_colors = {
+     0: (0, 255, 0),    # Green for BK8
+     1: (255, 0, 0),    # Blue for Gate of Olympus
+     2: (0, 0, 255),    # Red for Princess
+     3: (255, 255, 0),  # Cyan for Starlight Princess
+     4: (255, 0, 255),  # Magenta for Zeus
+ }
+
+ def format_time_ranges(timestamps, classes):
+     """Group detection timestamps by class into 'start-end' second ranges."""
+     if not timestamps:
+         return ""
+
+     # Bucket the timestamps by class name.
+     class_timestamps = {}
+     for timestamp, class_id in zip(timestamps, classes):
+         class_name = class_names.get(class_id, 'Unknown')
+         class_timestamps.setdefault(class_name, []).append(timestamp)
+
+     formatted_ranges = []
+     for class_name, class_times in class_timestamps.items():
+         class_times = sorted(class_times)
+         ranges = []
+         start = class_times[0]
+         # Timestamps at most one second apart are merged into a single range.
+         for i in range(1, len(class_times)):
+             if class_times[i] - class_times[i - 1] > 1:
+                 ranges.append(f"{int(start)}-{int(class_times[i - 1])}")
+                 start = class_times[i]
+         ranges.append(f"{int(start)}-{int(class_times[-1])}")
+         formatted_ranges.append(f"{class_name} = {', '.join(ranges)}")
+
+     return ", ".join(formatted_ranges)
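+
+ # Illustrative note (added, not executed): timestamps [3.0, 3.5, 4.0, 9.0]
+ # with classes [4, 4, 4, 0] format to "Zeus = 3-4, BK8 = 9-9".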
+
+
+ def process_video(input_video):
+     """Run detection on each frame of a video and write an annotated copy."""
+     cap = cv2.VideoCapture(input_video)
+     if not cap.isOpened():
+         print("Error: Could not open input video.")
+         return None, ""
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Write the output video to the current working directory.
+     output_video_path = os.path.join(os.getcwd(), "processed_video.mp4")
+     writer = imageio.get_writer(output_video_path, fps=fps, codec="h264")
+
+     frame_count = 0
+     timestamps = []
+     classes_detected = []
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         timestamp = frame_count / fps
+         frame_count += 1
+
+         # Resize the frame to 640x640 before passing it to the model.
+         resized_frame = cv2.resize(frame, (640, 640))
+
+         # Convert to grayscale and stack back to 3 channels before inference.
+         gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
+         input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])
+
+         results = model.predict(input_frame)
+
+         # Draw every detection above the confidence threshold.
+         for result in results:
+             for box in result.boxes:
+                 if box.conf[0] >= confidence_threshold:
+                     x1, y1, x2, y2 = map(int, box.xyxy[0])
+                     class_id = int(box.cls[0])
+                     class_name = class_names.get(class_id, f"Class {class_id}")
+                     color = class_colors.get(class_id, (0, 255, 0))
+                     cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
+                     text = f'{class_name}, Conf: {box.conf[0]:.2f}'
+                     text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
+                     cv2.putText(resized_frame, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+                     timestamps.append(timestamp)
+                     classes_detected.append(class_id)
+
+         # Resize the frame back to its original size for the output video.
+         output_frame = cv2.resize(resized_frame, (frame_width, frame_height))
+
+         # imageio expects RGB, while OpenCV frames are BGR.
+         writer.append_data(cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB))
+
+     cap.release()
+     writer.close()
+
+     formatted_time_ranges = format_time_ranges(timestamps, classes_detected)
+
+     print(f"Processed video saved at: {output_video_path}")
+     return output_video_path, formatted_time_ranges
+
+
+ def process_image(input_image):
+     """Run detection on a single image and return the annotated result."""
+     # Gradio supplies an RGB array; convert to BGR for OpenCV processing.
+     bgr_frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
+
+     # Resize the frame to 640x640 before passing it to the model.
+     resized_frame = cv2.resize(bgr_frame, (640, 640))
+
+     # Convert to grayscale and stack back to 3 channels before inference.
+     gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
+     input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])
+
+     results = model.predict(input_frame)
+
+     detections_log = []
+     classes_detected = []
+
+     # Draw every detection above the confidence threshold.
+     for result in results:
+         for box in result.boxes:
+             if box.conf[0] >= confidence_threshold:
+                 x1, y1, x2, y2 = map(int, box.xyxy[0])
+                 class_id = int(box.cls[0])
+                 class_name = class_names.get(class_id, f"Class {class_id}")
+                 color = class_colors.get(class_id, (0, 255, 0))  # Default to green
+
+                 cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
+                 text = f'{class_name}, Conf: {box.conf[0]:.2f}'
+                 text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
+                 cv2.putText(resized_frame, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+                 detections_log.append({
+                     "class": class_name,
+                     "confidence": float(box.conf[0]),
+                 })
+                 classes_detected.append(class_id)
+
+     # Count occurrences of each detected class.
+     class_count = {class_names.get(cls, f"Class {cls}"): classes_detected.count(cls) for cls in set(classes_detected)}
+
+     # Format the detections as 'Class = Count' pairs.
+     formatted_log = ", ".join([f"{class_name} = {count}" for class_name, count in class_count.items()])
+
+     # Convert the annotated frame back to RGB for display in Gradio.
+     output_image = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
+     return output_image, formatted_log
+
+
+ # Build the Gradio interface: one tab for video, one for images.
+ with gr.Blocks() as app:
+     gr.Markdown("## Judol Detection using YOLO11")
+
+     with gr.Tab("Video Detection"):
+         with gr.Row():
+             input_video = gr.Video(label="Upload a video")
+             output_video = gr.Video(label="Processed Video")
+         detections_log = gr.Textbox(label="Detections Log", lines=10)
+
+         # Process the video whenever a new file is uploaded.
+         input_video.change(
+             fn=lambda input_video: process_video(input_video) if input_video else (None, ""),
+             inputs=input_video,
+             outputs=[output_video, detections_log],
+         )
+
+     with gr.Tab("Image Detection"):
+         with gr.Row():
+             input_image = gr.Image(label="Upload an image")
+             output_image = gr.Image(label="Processed Image")
+         image_detections_log = gr.Textbox(label="Detections Log", lines=10)
+
+         # Process the image whenever a new file is uploaded.
+         input_image.change(
+             fn=process_image,
+             inputs=input_image,
+             outputs=[output_image, image_detections_log],
+         )
+
+ app.launch()
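
For quick local testing without the browser UI, a minimal sketch of calling process_image directly. Assumptions: the functions above are available in the current session (importing app.py as committed would also call app.launch()), and sample.jpg is a hypothetical local file, not part of this commit.

    import cv2

    # Gradio's gr.Image hands process_image an RGB array, so mirror that:
    # cv2.imread returns BGR, convert it to RGB first.
    rgb = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)

    annotated, log = process_image(rgb)
    print(log)  # e.g. "Zeus = 2, BK8 = 1"

    # The annotated result is RGB at 640x640; convert back to BGR to save.
    cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))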