Saiky2k commited on
Commit
a6b2b6c
·
verified ·
1 Parent(s): 5295446

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +333 -0
app.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from ultralytics import YOLO
8
+ import time
9
+ import tempfile
10
+ import os
11
+ import requests
12
+ from io import BytesIO
13
+
14
+ # Tạo module depth_pro đơn giản (để thay thế module gốc)
15
+ class DepthPro:
16
+ @staticmethod
17
+ def create_model_and_transforms():
18
+ # Nhập các thư viện cần thiết ở đây để tránh lỗi khi khởi tạo
19
+ import torch
20
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
21
+
22
+ # Tải mô hình depth estimation từ Hugging Face
23
+ processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-nyu")
24
+ model = AutoModelForDepthEstimation.from_pretrained("vinvino02/glpn-nyu")
25
+
26
+ # Tạo hàm transform đơn giản
27
+ def transform(image):
28
+ return processor(images=image, return_tensors="pt").pixel_values
29
+
30
+ # Mở rộng model với phương thức infer
31
+ def infer_method(self, image, f_px=None):
32
+ with torch.no_grad():
33
+ outputs = self(image)
34
+ predicted_depth = outputs.predicted_depth
35
+
36
+ # Chuẩn hóa độ sâu
37
+ depth_min = torch.min(predicted_depth)
38
+ depth_max = torch.max(predicted_depth)
39
+ predicted_depth = (predicted_depth - depth_min) / (depth_max - depth_min)
40
+ predicted_depth = predicted_depth * 10 # Nhân với 10 để có giá trị mét hợp lý hơn
41
+
42
+ return {"depth": predicted_depth}
43
+
44
+ # Thêm phương thức infer vào model
45
+ model.infer = infer_method.__get__(model)
46
+
47
+ return model, transform
48
+
49
+ # Hàm tải mô hình YOLO từ Hugging Face
50
+ @st.cache_resource
51
+ def load_yolo_model():
52
+ # Sử dụng mô hình YOLOv8n từ Hugging Face
53
+ model = YOLO("yolov8n.pt")
54
+ return model
55
+
56
+ # Hàm tải và chuẩn bị mô hình độ sâu
57
+ @st.cache_resource
58
+ def load_depth_model():
59
+ depth_pro = DepthPro()
60
+ model, transform = depth_pro.create_model_and_transforms()
61
+ return model, transform
62
+
63
+ # Hàm xử lý video
64
+ def process_video(video_path):
65
+ # Kiểm tra CUDA
66
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
+ st.info(f"Đang sử dụng thiết bị: {device}")
68
+
69
+ # Tải mô hình YOLO
70
+ with st.spinner('Đang tải mô hình YOLO...'):
71
+ yolo_model = load_yolo_model()
72
+ if device.type == 'cuda':
73
+ yolo_model.to(device)
74
+
75
+ # Tải mô hình độ sâu
76
+ with st.spinner('Đang tải mô hình độ sâu...'):
77
+ depth_model, transform = load_depth_model()
78
+ depth_model.eval()
79
+ if device.type == 'cuda':
80
+ depth_model.to(device)
81
+
82
+ # Mở video để xử lý
83
+ cap = cv2.VideoCapture(video_path)
84
+
85
+ # Lấy thuộc tính video cho đầu ra
86
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
87
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
88
+ fps = cap.get(cv2.CAP_PROP_FPS)
89
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
90
+
91
+ # Tạo tệp tạm thời cho video đầu ra
92
+ temp_output_dir = tempfile.mkdtemp()
93
+ output_video_path = os.path.join(temp_output_dir, "person_detection_with_depth.mp4")
94
+ output_depth_path = os.path.join(temp_output_dir, "depth_colormap.mp4")
95
+
96
+ # Sử dụng codec phù hợp với môi trường Hugging Face
97
+ fourcc = cv2.VideoWriter_fourcc(*'XVID') # Thay đổi từ mp4v sang XVID cho tương thích tốt hơn
98
+ out_detection = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
99
+ out_depth = cv2.VideoWriter(output_depth_path, fourcc, fps, (width, height))
100
+
101
+ # Ước tính chiều dài tiêu cự và chuyển đổi sang tensor
102
+ focal_length_px = torch.tensor([max(width, height)], device=device)
103
+
104
+ # Hiển thị thanh tiến trình
105
+ progress_bar = st.progress(0)
106
+ progress_text = st.empty()
107
+
108
+ frame_counter = 0
109
+ start_time = time.time()
110
+
111
+ # Tạo cột để hiển thị khung video
112
+ col1, col2 = st.columns(2)
113
+ detection_placeholder = col1.empty()
114
+ depth_placeholder = col2.empty()
115
+
116
+ # Giảm kích thước frame để tăng tốc độ xử lý
117
+ target_width = 640 # Kích thước đích
118
+ scale_factor = target_width / width
119
+ target_height = int(height * scale_factor)
120
+
121
+ try:
122
+ while cap.isOpened():
123
+ ret, frame = cap.read()
124
+ if not ret:
125
+ break
126
+
127
+ frame_counter += 1
128
+
129
+ # Cập nhật tiến trình
130
+ progress = int(frame_counter / total_frames * 100)
131
+ progress_bar.progress(progress)
132
+
133
+ if frame_counter % 10 == 0: # Hiển thị tiến trình mỗi 10 khung hình
134
+ elapsed_time = time.time() - start_time
135
+ frames_left = total_frames - frame_counter
136
+ est_time_left = (elapsed_time / frame_counter) * frames_left if frame_counter > 0 else 0
137
+ progress_text.text(f"Đang xử lý khung hình {frame_counter}/{total_frames} - Thời gian còn lại: {est_time_left:.2f}s")
138
+
139
+ # Giảm kích thước khung hình để tăng tốc xử lý
140
+ if scale_factor < 1:
141
+ frame_resized = cv2.resize(frame, (target_width, target_height))
142
+ else:
143
+ frame_resized = frame
144
+
145
+ # Phát hiện YOLO
146
+ results = yolo_model(frame_resized)
147
+
148
+ person_boxes = []
149
+ for result in results:
150
+ boxes = result.boxes.xyxy.cpu().numpy()
151
+ classes = result.boxes.cls.cpu().numpy()
152
+ confs = result.boxes.conf.cpu().numpy()
153
+
154
+ for box, cls, conf in zip(boxes, classes, confs):
155
+ if result.names[int(cls)] == "person" and conf > 0.5: # Thêm ngưỡng tin cậy
156
+ if scale_factor < 1: # Điều chỉnh lại khung giới hạn nếu đã thay đổi kích thước
157
+ x1, y1, x2, y2 = map(int, [box[0]/scale_factor, box[1]/scale_factor,
158
+ box[2]/scale_factor, box[3]/scale_factor])
159
+ else:
160
+ x1, y1, x2, y2 = map(int, box[:4])
161
+ person_boxes.append((x1, y1, x2, y2))
162
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
163
+
164
+ # Chuyển đổi khung hình cho đầu vào mô hình độ sâu
165
+ rgb_frame = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
166
+ pil_image = Image.fromarray(rgb_frame)
167
+ depth_input = transform(pil_image)
168
+
169
+ if device.type == 'cuda':
170
+ depth_input = depth_input.to(device)
171
+
172
+ # Ước tính độ sâu
173
+ with torch.no_grad():
174
+ predictions = depth_model.infer(depth_input, f_px=focal_length_px)
175
+ depth = predictions["depth"] # Độ sâu theo [m]
176
+
177
+ depth_np = depth.squeeze().cpu().numpy()
178
+
179
+ # Điều chỉnh lại kích thước bản đồ độ sâu
180
+ if scale_factor < 1:
181
+ depth_np = cv2.resize(depth_np, (width, height), interpolation=cv2.INTER_LINEAR)
182
+
183
+ # Tạo bản đồ màu độ sâu
184
+ depth_np_normalized = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min())
185
+ inv_depth_np_normalized = 1 - depth_np_normalized
186
+ depth_colormap = cv2.applyColorMap((inv_depth_np_normalized * 255).astype(np.uint8), cv2.COLORMAP_TURBO)
187
+
188
+ # Thêm giá trị độ sâu cho người được phát hiện
189
+ for x1, y1, x2, y2 in person_boxes:
190
+ center_x = (x1 + x2) // 2
191
+ center_y = (y1 + y2) // 2
192
+
193
+ # Đảm bảo tọa độ nằm trong giới hạn
194
+ center_x = min(center_x, depth_np.shape[1] - 1)
195
+ center_y = min(center_y, depth_np.shape[0] - 1)
196
+
197
+ depth_value = depth_np[center_y, center_x]
198
+
199
+ text = f"Độ sâu: {depth_value:.2f} m"
200
+ font = cv2.FONT_HERSHEY_SIMPLEX
201
+ font_scale = 0.8 # Giảm kích thước font để phù hợp
202
+ font_thickness = 2
203
+ text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
204
+
205
+ text_x = x1
206
+ text_y = y1 - 10
207
+ rect_x1 = text_x - 5
208
+ rect_y1 = text_y - text_size[1] - 10
209
+ rect_x2 = text_x + text_size[0] + 5
210
+ rect_y2 = text_y + 5
211
+
212
+ cv2.rectangle(frame, (rect_x1, rect_y1), (rect_x2, rect_y2), (0, 255, 0), -1)
213
+ cv2.putText(frame, text, (text_x, text_y), font, font_scale, (0, 0, 0), font_thickness)
214
+
215
+ # Hiển thị khung hình trong Streamlit (cập nhật mỗi 5 khung hình để tránh làm chậm)
216
+ if frame_counter % 5 == 0:
217
+ detection_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), caption="Phát hiện người", use_column_width=True)
218
+ depth_placeholder.image(depth_colormap, caption="Bản đồ độ sâu", use_column_width=True)
219
+
220
+ # Ghi khung hình vào video đầu ra
221
+ out_detection.write(frame)
222
+ out_depth.write(depth_colormap)
223
+
224
+ finally:
225
+ # Giải phóng tài nguyên
226
+ cap.release()
227
+ out_detection.release()
228
+ out_depth.release()
229
+
230
+ total_time = time.time() - start_time
231
+ st.success(f"Xử lý hoàn tất! Tổng thời gian: {total_time:.2f}s")
232
+ st.success(f"FPS trung bình: {frame_counter / total_time:.2f}")
233
+
234
+ return output_video_path, output_depth_path
235
+
236
+ # Giao diện Streamlit chính
237
+ def main():
238
+ st.title("Ứng dụng Phát hiện Người và Ước tính Độ sâu")
239
+ st.write("Tải lên video để phát hiện người và hiển thị thông tin độ sâu")
240
+
241
+ # Tùy chọn video mẫu
242
+ st.sidebar.header("Tùy chọn")
243
+ use_sample = st.sidebar.checkbox("Sử dụng video mẫu")
244
+
245
+ video_path = None
246
+
247
+ if use_sample:
248
+ st.info("Đang sử dụng video mẫu...")
249
+ # URL của video mẫu (đặt URL video mẫu của bạn ở đây)
250
+ sample_video_url = "https://huggingface.co/spaces/Nupoor/SampleVideoDataset/resolve/main/pexels-richard-de-souza-1635985.mp4"
251
+
252
+ try:
253
+ # Tải video mẫu
254
+ response = requests.get(sample_video_url)
255
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
256
+ temp_file.write(response.content)
257
+ video_path = temp_file.name
258
+ temp_file.close()
259
+
260
+ st.video(video_path)
261
+ except Exception as e:
262
+ st.error(f"Không thể tải video mẫu: {e}")
263
+ video_path = None
264
+ else:
265
+ # Tải lên tệp video
266
+ uploaded_file = st.file_uploader("Chọn một tệp video", type=['mp4', 'avi', 'mov'])
267
+
268
+ if uploaded_file is not None:
269
+ # Lưu tệp đã tải lên vào thư mục tạm thời
270
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
271
+ temp_file.write(uploaded_file.read())
272
+ video_path = temp_file.name
273
+ temp_file.close()
274
+
275
+ st.video(video_path)
276
+
277
+ # Hiển thị thông tin về mô hình
278
+ st.sidebar.header("Thông tin mô hình")
279
+ st.sidebar.markdown("""
280
+ - Phát hiện người: YOLOv8n
281
+ - Ước tính độ sâu: GLPN-NYU từ HuggingFace
282
+ """)
283
+
284
+ # Thêm tùy chọn cho độ tin cậy phát hiện
285
+ confidence = st.sidebar.slider("Ngưỡng tin cậy", 0.0, 1.0, 0.5)
286
+
287
+ # Nút để bắt đầu xử lý
288
+ if video_path and st.button("Xử lý Video"):
289
+ with st.spinner("Đang xử lý video..."):
290
+ detection_video_path, depth_video_path = process_video(video_path)
291
+
292
+ # Hiển thị video đã xử lý
293
+ st.subheader("Video đã xử lý")
294
+
295
+ col1, col2 = st.columns(2)
296
+ with col1:
297
+ st.video(detection_video_path)
298
+ st.download_button(
299
+ label="Tải xuống video phát hiện",
300
+ data=open(detection_video_path, 'rb').read(),
301
+ file_name="person_detection_with_depth.mp4",
302
+ mime="video/mp4"
303
+ )
304
+
305
+ with col2:
306
+ st.video(depth_video_path)
307
+ st.download_button(
308
+ label="Tải xuống bản đồ độ sâu",
309
+ data=open(depth_video_path, 'rb').read(),
310
+ file_name="depth_colormap.mp4",
311
+ mime="video/mp4"
312
+ )
313
+
314
+ # Xóa tệp tạm thời
315
+ os.unlink(video_path)
316
+
317
+ # Tệp requirements.txt
318
+ def create_requirements():
319
+ requirements = """
320
+ streamlit
321
+ numpy
322
+ Pillow
323
+ opencv-python
324
+ torch
325
+ torchvision
326
+ transformers
327
+ ultralytics
328
+ requests
329
+ """
330
+ return requirements
331
+
332
+ if __name__ == "__main__":
333
+ main()