import warnings
warnings.filterwarnings("ignore")

import json
import logging
import os
import uuid
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import torch
import torchreid
from insightface.app import FaceAnalysis
from ultralytics import YOLO
# ========== Logging Configuration ==========
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# ========== Configuration ==========
DETECTION_THRESHOLD = 0.75  # minimum YOLO confidence for a person detection

# Create output directory for Gradio
OUTPUT_DIR = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ========== Video Processing Class ==========
class VideoProcessor:
    def __init__(self):
        try:
            self.model = YOLO('detection.pt')
            self.face_app = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            self.face_app.prepare(ctx_id=0)
            self.reid_extractor = torchreid.utils.FeatureExtractor(
                model_name='osnet_x0_25',
                model_path=None,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )
            self.models_loaded = True
            logger.info("Models loaded successfully.")
        except Exception:
            logger.exception("Model loading failed.")
            self.models_loaded = False
        self.reset_tracking()
    def reset_tracking(self):
        self.known_embeddings = []
        self.known_ids = []
        self.next_global_id = 1
        self.track_to_global = {}
        self.tracking_data = {
            "metadata": {
                "total_frames": 0,
                "total_people": 0,
                "id_mapping": {}
            },
            "frames": []
        }
        logger.info("Tracking state reset.")
    def extract_embeddings(self, person_crop):
        face_embedding, body_embedding = None, None
        try:
            faces = self.face_app.get(person_crop)
            if faces:
                face_embedding = faces[0].embedding
        except Exception:
            logger.debug("Face embedding failed.")
        try:
            # torchreid's extractor expects RGB input; OpenCV crops are BGR
            body_input = cv2.resize(person_crop, (128, 256))
            body_input = cv2.cvtColor(body_input, cv2.COLOR_BGR2RGB)
            body_embedding = self.reid_extractor(body_input)[0].cpu().numpy()
        except Exception:
            logger.debug("Body embedding failed.")
        # Prefer the combined face+body vector when both are available
        if face_embedding is not None and body_embedding is not None:
            return np.concatenate((face_embedding, body_embedding)).astype(np.float32)
        elif face_embedding is not None:
            return face_embedding.astype(np.float32)
        elif body_embedding is not None:
            return body_embedding.astype(np.float32)
        return None
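
    # Note: with these defaults the face (buffalo_l) and body (osnet_x0_25)
    # embeddings are typically each 512-d, so a combined vector is 1024-d.
    # Since a person may yield a face-only, body-only, or combined vector,
    # assign_global_id below only compares embeddings of equal length.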
    def assign_global_id(self, embedding, track_id):
        if embedding is None:
            # No features available; fall back to the tracker's temporary ID
            return self.track_to_global.get(track_id, f"T{track_id}")
        match_found = False
        if self.known_embeddings:
            # Only compare against stored embeddings of the same dimensionality
            matching_embeddings = [
                (emb, gid) for emb, gid in zip(self.known_embeddings, self.known_ids)
                if emb.shape[0] == embedding.shape[0]
            ]
            if matching_embeddings:
                embs, gids = zip(*matching_embeddings)
                embs = np.array(embs)
                # Cosine similarity between the new embedding and all known ones
                sims = np.dot(embs, embedding) / (
                    np.linalg.norm(embs, axis=1) * np.linalg.norm(embedding) + 1e-6
                )
                best_match = np.argmax(sims)
                if sims[best_match] > 0.6:
                    global_id = gids[best_match]
                    match_found = True
        if not match_found:
            global_id = self.next_global_id
            self.next_global_id += 1
            self.known_embeddings.append(embedding)
            self.known_ids.append(global_id)
        if track_id is not None:
            self.track_to_global[track_id] = global_id
        return global_id
    def process_video(self, input_video_path, progress_callback=None):
        if not self.models_loaded:
            raise Exception("Models not loaded properly")
        self.reset_tracking()

        # Create output files with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]

        # Use OUTPUT_DIR instead of a temp directory so Gradio can serve the files
        output_video_path = os.path.join(OUTPUT_DIR, f"tracked_video_{timestamp}_{unique_id}.mp4")
        output_json_path = os.path.join(OUTPUT_DIR, f"tracking_data_{timestamp}_{unique_id}.json")

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise Exception("Could not open video file")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Try mp4v first for broad compatibility; fall back to XVID, then H264
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        # Verify the video writer initialized properly before falling back
        if not out.isOpened():
            logger.warning("mp4v codec failed, trying XVID")
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            output_video_path = output_video_path.replace('.mp4', '.avi')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.warning("XVID codec failed, trying H264")
            fourcc = cv2.VideoWriter_fourcc(*'H264')
            output_video_path = output_video_path.replace('.avi', '.mp4')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if progress_callback:
                progress_callback(frame_count / total_frames, f"Processing frame {frame_count}/{total_frames}")
            frame_data = {"frame": frame_count, "people": []}
            try:
                results = self.model.track(
                    frame, tracker="bytetrack.yaml", persist=True, verbose=False, conf=DETECTION_THRESHOLD
                )
                for result in results:
                    if result.boxes is not None:
                        boxes = result.boxes.xyxy.cpu().numpy()
                        confidences = result.boxes.conf.cpu().numpy()
                        track_ids = result.boxes.id.int().cpu().tolist() if result.boxes.id is not None else [None] * len(boxes)
                        for box, conf, track_id in zip(boxes, confidences, track_ids):
                            x1, y1, x2, y2 = map(int, box)
                            # Clamp to frame bounds so a slightly out-of-frame box
                            # cannot produce a wrapped (negative-index) crop
                            person_crop = frame[max(0, y1):y2, max(0, x1):x2]
                            if person_crop.size > 0:
                                embedding = self.extract_embeddings(person_crop)
                                global_id = self.assign_global_id(embedding, track_id)
                                frame_data["people"].append({
                                    "person_id": global_id,
                                    "center_x": (x1 + x2) / 2,
                                    "center_y": (y1 + y2) / 2,
                                    "confidence": float(conf),
                                    "bbox": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)}
                                })
                                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                                cv2.putText(frame, f"ID {global_id}", (x1, y1 - 10),
                                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
            except Exception:
                logger.exception(f"Error processing frame {frame_count}.")
            self.tracking_data["frames"].append(frame_data)
            out.write(frame)
        cap.release()
        out.release()
        # Verify the output file was created and has content
        if not os.path.exists(output_video_path) or os.path.getsize(output_video_path) == 0:
            raise Exception("Output video file was not created properly")

        self.tracking_data["metadata"]["total_frames"] = frame_count
        self.tracking_data["metadata"]["total_people"] = len(set(self.known_ids))
        self.tracking_data["metadata"]["id_mapping"] = {str(k): v for k, v in self.track_to_global.items()}

        # Save JSON file
        with open(output_json_path, 'w') as f:
            json.dump(self.tracking_data, f, indent=2)

        logger.info(f"Video processing completed. Saved to {output_video_path}")
        logger.info(f"Video file size: {os.path.getsize(output_video_path)} bytes")
        return output_video_path, output_json_path
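
# Illustrative shape of the saved JSON (values are examples, not real output):
# {
#   "metadata": {"total_frames": 120, "total_people": 3, "id_mapping": {"1": 1}},
#   "frames": [
#     {"frame": 1, "people": [
#       {"person_id": 1, "center_x": 320.0, "center_y": 240.0, "confidence": 0.91,
#        "bbox": {"x1": 280.0, "y1": 120.0, "x2": 360.0, "y2": 360.0}}
#     ]}
#   ]
# }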
# ========== Processor ==========
processor = VideoProcessor()
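
# Example of standalone (non-Gradio) use, assuming a local "input.mp4" exists:
#   video_path, json_path = processor.process_video("input.mp4")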
# ========== Gradio Handler ==========
def process_video_gradio(input_video, progress=gr.Progress()):
    if input_video is None:
        return None, None, "Please upload a video file."
    try:
        def progress_callback(prog, message):
            progress(prog, desc=message)

        # Process video
        output_video_path, output_json_path = processor.process_video(input_video, progress_callback)

        # Verify files exist and are accessible
        if not os.path.exists(output_video_path):
            raise Exception(f"Output video not found at {output_video_path}")
        if not os.path.exists(output_json_path):
            raise Exception(f"Output JSON not found at {output_json_path}")

        # Read tracking data for stats
        with open(output_json_path, 'r') as f:
            data = json.load(f)

        stats = f"""
**Processing Complete!**

- **Total Frames Processed:** {data['metadata']['total_frames']}
- **Total People Detected:** {data['metadata']['total_people']}
- **Unique IDs Assigned:** {len(data['metadata']['id_mapping'])}
- **Output Video Size:** {os.path.getsize(output_video_path) / (1024*1024):.1f} MB

**Output video** is ready for download.
**JSON tracking data** contains frame-by-frame detection results.
"""
        logger.info(f"Returning video path: {output_video_path}")
        logger.info(f"Video exists: {os.path.exists(output_video_path)}")
        return output_video_path, output_json_path, stats
    except Exception as e:
        logger.exception("Video processing failed.")
        return None, None, f"**Error processing video:** {str(e)}"
# ========== Gradio Interface ==========
def create_interface():
    with gr.Blocks(title="Video Person Detection & Tracking", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Video Person Detection & Tracking with ReID")
        gr.Markdown("Upload a video to detect and track people using YOLOv8, InsightFace, and ReID models for consistent person identification across frames.")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload Input Video",
                    height=400,
                    interactive=True
                )
                process_btn = gr.Button(
                    "Process Video",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Processed Video (with tracking)",
                    height=400,
                    interactive=False,
                    show_download_button=True  # Enable download button
                )
                download_json = gr.File(
                    label="Download Tracking Data (JSON)",
                    interactive=False
                )
        with gr.Row():
            status_text = gr.Markdown("Upload a video and click **'Process Video'** to start tracking people.")

        # Event handler
        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video],
            outputs=[output_video, download_json, status_text],
            show_progress=True
        )
        # Additional information
        with gr.Accordion("How it works", open=False):
            gr.Markdown("""
### **Technology Stack:**
- **YOLOv8:** Real-time person detection
- **ByteTrack:** Multi-object tracking algorithm
- **InsightFace:** Facial feature extraction for person identification
- **OSNet:** Full-body re-identification features

### **Process:**
1. **Detection:** YOLOv8 detects people in each frame
2. **Tracking:** ByteTrack assigns temporary tracking IDs
3. **Feature Extraction:** InsightFace + OSNet extract identifying features
4. **Re-identification:** Combines face and body features for consistent global IDs
5. **Output:** Generates annotated video + detailed JSON tracking data

### **Supported Formats:**
- **Input:** MP4, AVI, MOV, WEBM
- **Output:** MP4 video + JSON metadata
""")
        with gr.Accordion("Model Configuration", open=False):
            gr.Markdown(f"""
- **Detection Threshold:** {DETECTION_THRESHOLD}
- **Similarity Threshold:** 0.6 (for person re-identification)
- **Device:** {"CUDA" if torch.cuda.is_available() else "CPU"}
- **Output Directory:** {OUTPUT_DIR}
""")
        with gr.Accordion("Troubleshooting", open=False):
            gr.Markdown("""
**If the video doesn't display:**
1. Check whether the output file exists in the outputs directory
2. Try downloading the video manually
3. Ensure proper video codec support

**Common issues:**
- Large video files may take time to load
- Some browsers may not support certain video formats
- Network issues can affect video streaming
""")
    return demo
# ========== Launch ==========
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=True,
        allowed_paths=[OUTPUT_DIR]  # let Gradio serve files from the outputs directory
    )