JrEasy committed
Commit 0c8bba6 · verified · 1 Parent(s): 5846765

Create app.py

Files changed (1): app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
"""Judol Gradio YOLO11.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oiuTAi-cys1ydtUhSDJSRdeA02mAmZQH
"""

import os

import cv2
import gradio as gr
import imageio
from ultralytics import YOLO

model = YOLO('https://huggingface.co/JrEasy/Judol-Detection-YOLO11/resolve/main/best.pt')
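# Note: ultralytics accepts an HTTP(S) weights path here and downloads the file
# before loading (exact caching behavior may vary by ultralytics version);
# keeping a local copy of best.pt would avoid re-downloading on cold starts.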

confidence_threshold = 0.6

class_names = {
    0: "BK8",
    1: "Gate of Olympus",
    2: "Princess",
    3: "Starlight Princess",
    4: "Zeus",
}

# Box colors use OpenCV's BGR channel order.
class_colors = {
    0: (0, 255, 0),    # Green for BK8
    1: (255, 0, 0),    # Blue for Gate of Olympus
    2: (0, 0, 255),    # Red for Princess
    3: (255, 255, 0),  # Cyan for Starlight Princess
    4: (255, 0, 255),  # Magenta for Zeus
}


def format_time_ranges(timestamps, classes):
    """Group per-frame detection timestamps into 'Class = start-end' ranges."""
    if not timestamps:
        return ""

    # Bucket timestamps by class name.
    class_timestamps = {}
    for timestamp, class_id in zip(timestamps, classes):
        class_name = class_names.get(class_id, 'Unknown')
        if class_name not in class_timestamps:
            class_timestamps[class_name] = []
        class_timestamps[class_name].append(timestamp)

    formatted_ranges = []
    for class_name, class_times in class_timestamps.items():
        class_times = sorted(class_times)
        ranges = []
        start = class_times[0]
        for i in range(1, len(class_times)):
            # A gap of more than 1 s starts a new range.
            if class_times[i] - class_times[i - 1] > 1:
                ranges.append(f"{int(start)}-{int(class_times[i - 1])}")
                start = class_times[i]
        ranges.append(f"{int(start)}-{int(class_times[-1])}")
        formatted_ranges.append(f"{class_name} = {', '.join(ranges)}")

    return ", ".join(formatted_ranges)
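# Example: detections no more than 1 s apart are merged into one range, so
#   format_time_ranges([1.0, 2.0, 3.0, 10.0], [0, 0, 0, 0])
# returns "BK8 = 1-3, 10-10" (class 0 maps to "BK8" in class_names above).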


def process_video(input_video):
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print("Error: Could not open input video.")
        return None, ""

    fps = cap.get(cv2.CAP_PROP_FPS)

    # Define the output video path in the current directory
    output_video_path = os.path.join(os.getcwd(), "processed_video.mp4")

    writer = imageio.get_writer(output_video_path, fps=fps, codec="h264")

    frame_count = 0
    timestamps = []
    classes_detected = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        timestamp = frame_count / fps
        frame_count += 1

        # Resize the frame to 640x640 before passing to the model
        resized_frame = cv2.resize(frame, (640, 640))

        # Convert to grayscale and stack back to 3 channels for the model
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])

        results = model.predict(input_frame)

        for result in results:
            for box in result.boxes:
                if box.conf[0] >= confidence_threshold:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                    class_id = int(box.cls[0])
                    class_name = class_names.get(class_id, f"Class {class_id}")
                    color = class_colors.get(class_id, (0, 255, 0))
                    cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
                    text = f'{class_name}, Conf: {box.conf[0]:.2f}'
                    # Place the label above the box, or below it near the top edge
                    text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
                    cv2.putText(resized_frame, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                    timestamps.append(timestamp)
                    classes_detected.append(class_id)

        # imageio expects RGB frames
        writer.append_data(cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB))

    cap.release()
    writer.close()

    formatted_time_ranges = format_time_ranges(timestamps, classes_detected)

    print(f"Processed video saved at: {output_video_path}")

    return output_video_path, formatted_time_ranges
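# Example usage outside Gradio (with "sample.mp4" as a hypothetical local clip):
#   path, ranges = process_video("sample.mp4")
#   print(ranges)  # e.g. "Zeus = 4-7" if Zeus is detected from seconds 4 to 7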


def process_image(input_image):
    # Convert image from RGB to BGR for OpenCV processing
    bgr_frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # Resize the frame to 640x640 before passing to the model
    resized_frame = cv2.resize(bgr_frame, (640, 640))

    # Convert to grayscale and create a 3-channel grayscale image
    gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
    input_frame = cv2.merge([gray_frame, gray_frame, gray_frame])

    results = model.predict(input_frame)

    detections_log = []
    classes_detected = []

    for result in results:
        for box in result.boxes:
            if box.conf[0] >= confidence_threshold:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                class_id = int(box.cls[0])
                class_name = class_names.get(class_id, f"Class {class_id}")
                color = class_colors.get(class_id, (0, 255, 0))  # Default green color

                cv2.rectangle(resized_frame, (x1, y1), (x2, y2), color, 2)
                text = f'{class_name}, Conf: {box.conf[0]:.2f}'
                text_position = (x1, y1 - 10 if y1 > 20 else y1 + 20)
                cv2.putText(resized_frame, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                detections_log.append({
                    "class": class_name,
                    "confidence": box.conf[0]
                })
                classes_detected.append(class_id)

    # Count occurrences of each class detected
    class_count = {class_names.get(cls, f"Class {cls}"): classes_detected.count(cls) for cls in set(classes_detected)}

    # Format the detections as 'Class = Count' pairs
    formatted_log = ", ".join([f"{class_name} = {count}" for class_name, count in class_count.items()])

    # Convert the output frame back to RGB
    output_image = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
    return output_image, formatted_log
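# Example usage outside Gradio (input must be an RGB array, as Gradio supplies):
#   annotated, log = process_image(rgb_array)
#   print(log)  # e.g. "BK8 = 2, Zeus = 1" (per-class detection counts)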


with gr.Blocks() as app:
    gr.Markdown("## Judol Detection using YOLOv11")

    with gr.Tab("Video Detection"):
        with gr.Row():
            input_video = gr.Video(label="Upload a video")
            output_video = gr.Video(label="Processed Video")
            detections_log = gr.Textbox(label="Detections Log", lines=10)

        input_video.change(
            fn=lambda input_video: process_video(input_video) if input_video else (None, ""),
            inputs=input_video,
            outputs=[output_video, detections_log],
        )

    with gr.Tab("Image Detection"):
        with gr.Row():
            input_image = gr.Image(label="Upload an image")
            output_image = gr.Image(label="Processed Image")
            image_detections_log = gr.Textbox(label="Detections Log", lines=10)

        input_image.change(
            fn=process_image,
            inputs=input_image,
            outputs=[output_image, image_detections_log],
        )

app.launch()
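# Note: launch() serves the app in place (locally or on the hosting Space);
# passing share=True would additionally create a temporary public Gradio link.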