- .gitignore +4 -0
- ReadMe.md +221 -0
- app.py +371 -0
- detection.pt +3 -0
- detection.py +215 -0
- requirements.txt +13 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
.gradio
*.mp4
*.json
*.log
ReadMe.md
ADDED
@@ -0,0 +1,221 @@
# 🎥 Video Person Detection & Tracking with ReID

A computer vision application that combines YOLOv8, InsightFace, and TorchReID for robust person detection, tracking, and re-identification in videos. The application provides a user-friendly Gradio interface for easy video processing.

## 🔧 Technology Stack

- **YOLOv8**: Real-time person detection
- **ByteTrack**: Multi-object tracking algorithm
- **InsightFace**: Facial feature extraction for person identification
- **OSNet**: Full-body re-identification features
- **Gradio**: Web-based user interface

## 📋 Features

- Real-time person detection and tracking
- Consistent person re-identification across frames
- Face and body feature extraction
- Interactive web interface
- JSON export of tracking data
- Support for multiple video formats

## 🚀 Quick Start

### Prerequisites

**System Requirements:**
- Python 3.8 or higher
- CUDA-compatible GPU (recommended for better performance)
- At least 4GB RAM
- 2GB free disk space

**Platform-Specific Dependencies:**

**Linux:**
```bash
# Install the g++ compiler (required for InsightFace)
sudo apt-get update
sudo apt-get install g++ build-essential
```

**Windows:**
- Install the [Microsoft Visual C++ Redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe) (latest version)
- Ensure you have Visual Studio Build Tools or Visual Studio Community installed

**macOS:**
```bash
# Install the Xcode command line tools
xcode-select --install
```

### Installation

1. **Clone the repository:**
```bash
git clone git@gitlab.com:zebshah7851/object-detection-and-tracking.git
cd object-detection-and-tracking
```

2. **Create a virtual environment:**
```bash
python -m venv venv

# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On Linux/macOS:
source venv/bin/activate
```

3. **Install dependencies:**
```bash
pip install --upgrade pip
pip install -r requirements.txt
```

**Note:** Installation may take 10-15 minutes due to large package downloads (PyTorch, CUDA libraries, etc.).

### Model Setup

The application requires several pre-trained models:

1. **YOLOv8 Detection Model:**
   - Place your trained `detection.pt` model file in the project root directory
   - Alternatively, the app will download a default YOLOv8 model on first run

2. **InsightFace Model:**
   - The `buffalo_l` model will be automatically downloaded on first run
   - Requires ~2GB of storage space

3. **TorchReID Model:**
   - The `osnet_x0_25` model will be automatically downloaded
   - Pre-trained on the Market1501 dataset
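Once the weights are in place, all three models load in a few lines. This is a minimal sketch mirroring the initialization in `app.py`; running it once also triggers the automatic downloads described above:

```python
import torch
from ultralytics import YOLO
from insightface.app import FaceAnalysis
import torchreid

# YOLOv8 person detector (the bundled detection.pt weights)
model = YOLO("detection.pt")

# InsightFace face embedder; falls back to CPU if CUDA is unavailable
face_app = FaceAnalysis(name="buffalo_l",
                        providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
face_app.prepare(ctx_id=0)

# OSNet body re-identification extractor (weights auto-downloaded)
reid_extractor = torchreid.utils.FeatureExtractor(
    model_name="osnet_x0_25",
    model_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```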
### Running the Application

1. **Start the Gradio interface:**
```bash
python app.py
```

2. **Access the web interface:**
   - Open your browser and navigate to: `http://127.0.0.1:7860`
   - The interface will load automatically

3. **Process videos:**
   - Upload a video file (MP4, AVI, MOV, WEBM)
   - Click "🚀 Process Video"
   - Download the processed video and tracking data
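The address and port are set in the `demo.launch()` call at the bottom of `app.py`. To serve the interface on your LAN instead of localhost only, you could adjust it like this (the `0.0.0.0` bind is an assumption for LAN use, not the shipped default):

```python
demo.launch(
    share=False,
    server_name="0.0.0.0",   # listen on all interfaces instead of 127.0.0.1
    server_port=7860,
    show_error=True
)
```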
## 📁 Project Structure

```
object-detection-and-tracking/
├── app.py               # Gradio web interface
├── detection.py         # Core detection script
├── requirements.txt     # Python dependencies
├── README.md            # This file
├── outputs/             # Generated output files
├── detection.pt         # YOLOv8 person detection model
└── logs/                # Application logs
```

## 🔧 Configuration

### Model Parameters

You can adjust the following parameters in `app.py`:

```python
DETECTION_THRESHOLD = 0.75   # Person detection confidence threshold
SIMILARITY_THRESHOLD = 0.6   # Person re-identification threshold
```

Note that in the current `app.py` the similarity value is hardcoded as `0.6` inside `assign_global_id`, so it must be edited there.
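The similarity threshold governs when a new detection is merged with an already-known person: embeddings are compared by cosine similarity, and the best match wins if it clears the threshold. A minimal sketch of that matching step (the random vectors stand in for real face/body embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)
known = rng.standard_normal((5, 512)).astype(np.float32)  # five known people
query = known[2] + 0.1 * rng.standard_normal(512)         # noisy re-sighting of person 2

# Cosine similarity between the query and every known embedding
sims = known @ query / (np.linalg.norm(known, axis=1) * np.linalg.norm(query) + 1e-6)

best = int(np.argmax(sims))
if sims[best] > 0.6:  # SIMILARITY_THRESHOLD
    print(f"Matched known person {best}")
else:
    print("No match above threshold: assign a new global ID")
```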
### Performance Optimization

**For GPU acceleration:**
- Ensure CUDA is properly installed
- The application automatically detects and uses the GPU if available
- Monitor GPU memory usage for large videos

**For CPU-only systems:**
- Reduce video resolution before processing (see the sketch below)
- Process shorter video segments
- Expect longer processing times
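As an example of reducing resolution, here is a small OpenCV sketch that downscales a clip before uploading; the file names and the 640×360 target are placeholders:

```python
import cv2

cap = cv2.VideoCapture("input.mp4")            # hypothetical input path
fps = cap.get(cv2.CAP_PROP_FPS)
target = (640, 360)                            # assumed target resolution
out = cv2.VideoWriter("input_small.mp4",
                      cv2.VideoWriter_fourcc(*"mp4v"), fps, target)

while True:
    ret, frame = cap.read()
    if not ret:
        break
    out.write(cv2.resize(frame, target))       # downscale each frame

cap.release()
out.release()
```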
## 📊 Output Format

### Processed Video
- Annotated video with bounding boxes
- Consistent person IDs across frames
- Real-time tracking visualization

### JSON Tracking Data
```json
{
  "metadata": {
    "total_frames": 1500,
    "total_people": 5,
    "id_mapping": {...}
  },
  "frames": [
    {
      "frame": 1,
      "people": [
        {
          "person_id": 1,
          "center_x": 320.5,
          "center_y": 240.0,
          "confidence": 0.85,
          "bbox": {"x1": 100, "y1": 50, "x2": 200, "y2": 300}
        }
      ]
    }
  ]
}
```
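The JSON file is easy to post-process. A minimal sketch that groups detections into per-person trajectories, using the field names shown above (the input file name is a placeholder):

```python
import json
from collections import defaultdict

with open("tracking_data.json") as f:          # hypothetical output file name
    data = json.load(f)

trajectories = defaultdict(list)               # person_id -> [(frame, x, y), ...]
for frame in data["frames"]:
    for person in frame["people"]:
        trajectories[person["person_id"]].append(
            (frame["frame"], person["center_x"], person["center_y"])
        )

for pid, points in trajectories.items():
    print(f"Person {pid}: seen in {len(points)} frames")
```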
## 🐛 Troubleshooting

### Common Issues

**Installation Problems:**

1. **InsightFace installation fails:**
```bash
# Try installing specific pinned versions
pip install insightface==0.7.3
pip install onnxruntime-gpu==1.14.1
```

   If you are running Linux, install g++ first. If you are running Windows, install the latest Visual C++ Redistributable.

2. **Model download issues:**
   - Check your internet connection
   - Manually download the models if the automatic download fails
   - Ensure sufficient disk space

**Runtime Issues:**

1. **Video won't load in browser:**
   - Try downloading the output video manually
   - Check browser compatibility
   - Clear the browser cache

2. **Slow processing:**
   - Use GPU acceleration if available
   - Raise the detection threshold so fewer detections need embedding
   - Process shorter video segments

3. **High memory usage:**
   - Monitor system resources
   - Close unnecessary applications
   - Use smaller input videos
## 📝 Recommended System Requirements

- **CPU:** Intel i5 or AMD Ryzen 5 (4 cores)
- **RAM:** 8GB
- **Storage:** 5GB free space
- **GPU:** Optional, but recommended for faster processing
app.py
ADDED
@@ -0,0 +1,371 @@
import warnings
warnings.filterwarnings("ignore")

import gradio as gr
import cv2
import numpy as np
import json
import os
from datetime import datetime
from ultralytics import YOLO
from insightface.app import FaceAnalysis
import torchreid
import torch
import logging
import uuid

# ========== Logging Configuration ==========
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ========== Configuration ==========
DETECTION_THRESHOLD = 0.75

# Create output directory for Gradio
OUTPUT_DIR = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ========== Video Processing Class ==========
class VideoProcessor:
    def __init__(self):
        try:
            self.model = YOLO('detection.pt')
            self.face_app = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
            self.face_app.prepare(ctx_id=0)
            self.reid_extractor = torchreid.utils.FeatureExtractor(
                model_name='osnet_x0_25',
                model_path=None,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )
            self.models_loaded = True
            logger.info("Models loaded successfully.")
        except Exception:
            logger.exception("Model loading failed.")
            self.models_loaded = False
        self.reset_tracking()

    def reset_tracking(self):
        self.known_embeddings = []
        self.known_ids = []
        self.next_global_id = 1
        self.track_to_global = {}
        self.tracking_data = {
            "metadata": {
                "total_frames": 0,
                "total_people": 0,
                "id_mapping": {}
            },
            "frames": []
        }
        logger.info("Tracking state reset.")

    def extract_embeddings(self, person_crop):
        face_embedding, body_embedding = None, None
        try:
            faces = self.face_app.get(person_crop)
            if faces:
                face_embedding = faces[0].embedding
        except Exception:
            logger.debug("Face embedding failed.")
        try:
            body_input = cv2.resize(person_crop, (128, 256))
            body_input = cv2.cvtColor(body_input, cv2.COLOR_BGR2RGB)
            body_embedding = self.reid_extractor(body_input)[0].cpu().numpy()
        except Exception:
            logger.debug("Body embedding failed.")

        # Prefer the combined face+body embedding; fall back to whichever is available
        if face_embedding is not None and body_embedding is not None:
            return np.concatenate((face_embedding, body_embedding)).astype(np.float32)
        elif face_embedding is not None:
            return face_embedding.astype(np.float32)
        elif body_embedding is not None:
            return body_embedding.astype(np.float32)
        return None

    def assign_global_id(self, embedding, track_id):
        if embedding is None:
            return self.track_to_global.get(track_id, f"T{track_id}")
        match_found = False
        if self.known_embeddings:
            # Only compare embeddings of the same dimension
            matching_embeddings = [
                (emb, gid) for emb, gid in zip(self.known_embeddings, self.known_ids)
                if emb.shape[0] == embedding.shape[0]
            ]
            if matching_embeddings:
                embs, gids = zip(*matching_embeddings)
                embs = np.array(embs)
                # Cosine similarity against all known embeddings
                sims = np.dot(embs, embedding) / (
                    np.linalg.norm(embs, axis=1) * np.linalg.norm(embedding) + 1e-6
                )
                best_match = np.argmax(sims)
                if sims[best_match] > 0.6:
                    global_id = gids[best_match]
                    match_found = True
        if not match_found:
            global_id = self.next_global_id
            self.next_global_id += 1
            self.known_embeddings.append(embedding)
            self.known_ids.append(global_id)
        if track_id is not None:
            self.track_to_global[track_id] = global_id
        return global_id

    def process_video(self, input_video_path, progress_callback=None):
        if not self.models_loaded:
            raise Exception("Models not loaded properly")

        self.reset_tracking()

        # Create output files with a timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]

        # Use OUTPUT_DIR instead of a temp directory
        output_video_path = os.path.join(OUTPUT_DIR, f"tracked_video_{timestamp}_{unique_id}.mp4")
        output_json_path = os.path.join(OUTPUT_DIR, f"tracking_data_{timestamp}_{unique_id}.json")

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise Exception("Could not open video file")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Try mp4v first for broad compatibility; fall back to other codecs below
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        # Verify the video writer is properly initialized
        if not out.isOpened():
            logger.warning("mp4v codec failed, trying XVID")
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            output_video_path = output_video_path.replace('.mp4', '.avi')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        if not out.isOpened():
            logger.warning("XVID codec failed, trying H264")
            fourcc = cv2.VideoWriter_fourcc(*'H264')
            output_video_path = output_video_path.replace('.avi', '.mp4')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if progress_callback:
                progress_callback(frame_count / total_frames, f"Processing frame {frame_count}/{total_frames}")

            frame_data = {"frame": frame_count, "people": []}

            try:
                results = self.model.track(
                    frame, tracker="bytetrack.yaml", persist=True, verbose=False, conf=DETECTION_THRESHOLD
                )

                for result in results:
                    if result.boxes is not None:
                        boxes = result.boxes.xyxy.cpu().numpy()
                        confidences = result.boxes.conf.cpu().numpy()
                        track_ids = result.boxes.id.int().cpu().tolist() if result.boxes.id is not None else [None] * len(boxes)

                        for box, conf, track_id in zip(boxes, confidences, track_ids):
                            x1, y1, x2, y2 = map(int, box)
                            person_crop = frame[y1:y2, x1:x2]
                            if person_crop.size > 0:
                                embedding = self.extract_embeddings(person_crop)
                                global_id = self.assign_global_id(embedding, track_id)

                                frame_data["people"].append({
                                    "person_id": global_id,
                                    "center_x": (x1 + x2) / 2,
                                    "center_y": (y1 + y2) / 2,
                                    "confidence": float(conf),
                                    "bbox": {"x1": float(x1), "y1": float(y1), "x2": float(x2), "y2": float(y2)}
                                })

                                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                                cv2.putText(frame, f"ID {global_id}", (x1, y1 - 10),
                                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
            except Exception:
                logger.exception(f"Error processing frame {frame_count}.")

            self.tracking_data["frames"].append(frame_data)
            out.write(frame)

        cap.release()
        out.release()

        # Verify the output file was created and has content
        if not os.path.exists(output_video_path) or os.path.getsize(output_video_path) == 0:
            raise Exception("Output video file was not created properly")

        self.tracking_data["metadata"]["total_frames"] = frame_count
        self.tracking_data["metadata"]["total_people"] = len(set(self.known_ids))
        self.tracking_data["metadata"]["id_mapping"] = {str(k): v for k, v in self.track_to_global.items()}

        # Save the JSON file
        with open(output_json_path, 'w') as f:
            json.dump(self.tracking_data, f, indent=2)

        logger.info(f"Video processing completed. Saved to {output_video_path}")
        logger.info(f"Video file size: {os.path.getsize(output_video_path)} bytes")

        return output_video_path, output_json_path

# ========== Processor ==========
processor = VideoProcessor()

# ========== Gradio Handler ==========
def process_video_gradio(input_video, progress=gr.Progress()):
    if input_video is None:
        return None, None, "Please upload a video file."

    try:
        def progress_callback(prog, message):
            progress(prog, desc=message)

        # Process the video
        output_video_path, output_json_path = processor.process_video(input_video, progress_callback)

        # Verify the files exist and are accessible
        if not os.path.exists(output_video_path):
            raise Exception(f"Output video not found at {output_video_path}")
        if not os.path.exists(output_json_path):
            raise Exception(f"Output JSON not found at {output_json_path}")

        # Read tracking data for stats
        with open(output_json_path, 'r') as f:
            data = json.load(f)

        stats = f"""
**Processing Complete!** ✅

- **Total Frames Processed:** {data['metadata']['total_frames']}
- **Total People Detected:** {data['metadata']['total_people']}
- **Unique IDs Assigned:** {len(data['metadata']['id_mapping'])}
- **Output Video Size:** {os.path.getsize(output_video_path) / (1024*1024):.1f} MB

📹 **Output video** is ready for download
📄 **JSON tracking data** contains frame-by-frame detection results
"""

        logger.info(f"Returning video path: {output_video_path}")
        logger.info(f"Video exists: {os.path.exists(output_video_path)}")

        return output_video_path, output_json_path, stats

    except Exception as e:
        logger.exception("Video processing failed.")
        return None, None, f"❌ **Error processing video:** {str(e)}"

# ========== Gradio Interface ==========
def create_interface():
    with gr.Blocks(title="Video Person Detection & Tracking", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎥 Video Person Detection & Tracking with ReID")
        gr.Markdown("Upload a video to detect and track people using YOLOv8, InsightFace, and ReID models for consistent person identification across frames.")

        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="📂 Upload Input Video",
                    height=400,
                    interactive=True
                )
                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="🎬 Processed Video (with tracking)",
                    height=400,
                    interactive=False,
                    show_download_button=True  # Enable the download button
                )
                download_json = gr.File(
                    label="📊 Download Tracking Data (JSON)",
                    interactive=False
                )

        with gr.Row():
            status_text = gr.Markdown("📤 Upload a video and click **'Process Video'** to start tracking people.")

        # Event handler
        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video],
            outputs=[output_video, download_json, status_text],
            show_progress=True
        )

        # Additional information
        with gr.Accordion("📖 How it works", open=False):
            gr.Markdown("""
            ### 🔧 **Technology Stack:**
            - **YOLOv8:** Real-time person detection
            - **ByteTrack:** Multi-object tracking algorithm
            - **InsightFace:** Facial feature extraction for person identification
            - **OSNet:** Full-body re-identification features

            ### 📋 **Process:**
            1. **Detection:** YOLOv8 detects people in each frame
            2. **Tracking:** ByteTrack assigns temporary tracking IDs
            3. **Feature Extraction:** InsightFace + OSNet extract identifying features
            4. **Re-identification:** Combines face and body features for consistent global IDs
            5. **Output:** Generates annotated video + detailed JSON tracking data

            ### 📁 **Supported Formats:**
            - **Input:** MP4, AVI, MOV, WEBM
            - **Output:** MP4 video + JSON metadata
            """)

        with gr.Accordion("⚙️ Model Configuration", open=False):
            gr.Markdown(f"""
            - **Detection Threshold:** {DETECTION_THRESHOLD}
            - **Similarity Threshold:** 0.6 (for person re-identification)
            - **Device:** {"CUDA" if torch.cuda.is_available() else "CPU"}
            - **Output Directory:** {OUTPUT_DIR}
            """)

        with gr.Accordion("🔧 Troubleshooting", open=False):
            gr.Markdown("""
            **If the video doesn't display:**
            1. Check if the output file exists in the outputs directory
            2. Try downloading the video manually
            3. Ensure proper video codec support

            **Common issues:**
            - Large video files may take time to load
            - Some browsers may not support certain video formats
            - Network issues can affect video streaming
            """)

    return demo

# ========== Launch ==========
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="127.0.0.1",
        server_port=7860,
        show_error=True
    )
detection.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04f78656185b52201e8bb37ac0990901ccbfcb4b1455c3f514ea18adc702672c
size 40485178
detection.py
ADDED
@@ -0,0 +1,215 @@
import cv2
import numpy as np
from ultralytics import YOLO
from insightface.app import FaceAnalysis
import torchreid
import torch

# Configuration
DETECTION_THRESHOLD = 0.75  # Confidence threshold for person detection

# =============================================================================
# MODEL INITIALIZATION
# =============================================================================

# Load YOLOv8 model with ByteTrack tracker for person detection and tracking
# YOLOv8 handles object detection while ByteTrack provides consistent tracking IDs
model = YOLO('detection.pt')  # Replace with your trained model path

# Initialize InsightFace for facial feature extraction
# Uses the buffalo_l model, which provides high-quality face embeddings
face_app = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider'])
face_app.prepare(ctx_id=0)  # Prepare for GPU inference

# Initialize TorchReID for full-body person re-identification
# OSNet is a lightweight but effective model for person ReID
reid_extractor = torchreid.utils.FeatureExtractor(
    model_name='osnet_x0_25',
    model_path='osnet_x0_25_market1501.pth',  # Pre-trained on the Market1501 dataset
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# =============================================================================
# GLOBAL VARIABLES FOR PERSON RE-IDENTIFICATION
# =============================================================================

# Storage for known person embeddings and their assigned global IDs
known_embeddings = []  # List of combined face+body embeddings
known_ids = []         # Corresponding global IDs for each embedding
next_global_id = 1     # Counter for assigning new global IDs

# Mapping from ByteTrack tracker IDs to global person IDs
# This helps maintain consistency when tracker IDs change
track_to_global = {}

# =============================================================================
# VIDEO INPUT/OUTPUT SETUP
# =============================================================================

# Initialize video capture and output writer
cap = cv2.VideoCapture("demo.mp4")  # Input video file
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

# Create output video writer with the same properties as the input
out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

# =============================================================================
# MAIN PROCESSING LOOP
# =============================================================================
while True:
    ret, frame = cap.read()
    if not ret:
        break  # End of video

    # Run YOLOv8 detection with ByteTrack tracking
    # persist=True maintains tracking across frames
    results = model.track(frame, tracker="bytetrack.yaml", persist=True,
                          verbose=False, conf=DETECTION_THRESHOLD)

    # Process each detection result
    for result in results:
        # Extract bounding boxes in (x1, y1, x2, y2) format
        boxes = result.boxes.xyxy.cpu().numpy()

        # Extract tracking IDs if available
        if result.boxes.id is not None:
            track_ids = result.boxes.id.int().cpu().tolist()
        else:
            # No tracking IDs available, assign None for each detection
            track_ids = [None] * len(boxes)

        # Process each detected person
        for box, track_id in zip(boxes, track_ids):
            x1, y1, x2, y2 = map(int, box)

            # Crop the person from the frame
            person_crop = frame[y1:y2, x1:x2]

            # Initialize embedding variables
            face_embedding = None
            body_embedding = None

            # =============================================================
            # FACE EMBEDDING EXTRACTION
            # =============================================================

            # Extract the face embedding using InsightFace
            faces = face_app.get(person_crop)
            if faces:
                # Use the first detected face (most confident)
                face_embedding = faces[0].embedding

            # =============================================================
            # BODY EMBEDDING EXTRACTION
            # =============================================================

            # Extract the body embedding using TorchReID
            try:
                # TorchReID expects 128x256 RGB input
                body_input = cv2.resize(person_crop, (128, 256))
                body_input = cv2.cvtColor(body_input, cv2.COLOR_BGR2RGB)

                # Extract features and convert to numpy
                body_embedding = reid_extractor(body_input)[0].cpu().numpy()
            except Exception:
                # Handle cases where the crop is too small or invalid
                pass

            # =============================================================
            # EMBEDDING COMBINATION AND PERSON MATCHING
            # =============================================================

            # Combine face and body embeddings for a robust person representation
            embedding = None
            if face_embedding is not None and body_embedding is not None:
                # Concatenate both embeddings for maximum distinctiveness
                embedding = np.concatenate((face_embedding, body_embedding)).astype(np.float32)
            elif face_embedding is not None:
                # Use only the face embedding if body embedding failed
                embedding = face_embedding.astype(np.float32)
            elif body_embedding is not None:
                # Use only the body embedding if face detection failed
                embedding = body_embedding.astype(np.float32)

            # Assign a global ID based on embedding similarity
            if embedding is not None:
                match_found = False

                # Search for similar embeddings among known people
                if known_embeddings:
                    # Only compare embeddings of the same dimension
                    matching_embeddings = [
                        (emb, gid) for emb, gid in zip(known_embeddings, known_ids)
                        if emb.shape[0] == embedding.shape[0]
                    ]

                    if matching_embeddings:
                        embs, gids = zip(*matching_embeddings)
                        embs = np.array(embs)

                        # Calculate cosine similarity with all known embeddings
                        sims = np.dot(embs, embedding) / (
                            np.linalg.norm(embs, axis=1) * np.linalg.norm(embedding) + 1e-6
                        )

                        # Find the best match
                        best_match = np.argmax(sims)
                        if sims[best_match] > 0.6:  # Similarity threshold
                            global_id = gids[best_match]
                            match_found = True

                # If no match was found, assign a new global ID
                if not match_found:
                    global_id = next_global_id
                    next_global_id += 1
                    known_embeddings.append(embedding)
                    known_ids.append(global_id)

                # Update the tracker ID to global ID mapping
                if track_id is not None:
                    track_to_global[track_id] = global_id

                display_id = global_id

            else:
                # No usable embedding available, fall back to the tracker ID
                global_id = track_to_global.get(track_id, f"T{track_id}")
                display_id = global_id

            # =============================================================
            # VISUALIZATION
            # =============================================================

            # Draw a bounding box around the detected person
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Display the global ID above the bounding box
            cv2.putText(frame, f"ID {display_id}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

    # =============================================================================
    # OUTPUT AND DISPLAY
    # =============================================================================

    # Show the frame with tracking results
    cv2.imshow("Tracking + ReID", frame)

    # Break the loop if the 'q' key is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

    # Write the frame to the output video
    out.write(frame)

# =============================================================================
# CLEANUP
# =============================================================================

# Release video capture and writer resources
cap.release()
out.release()
cv2.destroyAllWindows()
requirements.txt
ADDED
@@ -0,0 +1,13 @@
--extra-index-url https://download.pytorch.org/whl/cu118

torch==2.4.1
torchvision==0.19.1
torchaudio==2.4.1
gradio==5.35.0
insightface==0.7.3
onnxruntime-gpu==1.14.1
torchreid==0.2.5
ultralytics==8.3.161
gdown==5.2.0
lap==0.5.12
tensorboard==2.19.0