Azgadel commited on
Commit
abe7eaf
Β·
verified Β·
1 Parent(s): c0f6727

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +494 -0
  2. best_embedding_model.pth +3 -0
  3. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
3
+
4
+ import streamlit as st
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import soundfile as sf
9
+ import torchaudio
10
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
11
+ import numpy as np
12
+ from pathlib import Path
13
+ import json
14
+ import tempfile
15
+
16
+ # ============================================================
17
+ # MODEL DEFINITION
18
+ # ============================================================
19
+
20
+ class Wav2Vec2ForSpeakerEmbedding(nn.Module):
21
+ def __init__(self, embedding_size=256):
22
+ super().__init__()
23
+ self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
24
+ for param in self.wav2vec2.parameters():
25
+ param.requires_grad = False
26
+
27
+ self.projection = nn.Sequential(
28
+ nn.Linear(768, 512),
29
+ nn.ReLU(),
30
+ nn.Dropout(0.1),
31
+ nn.Linear(512, embedding_size)
32
+ )
33
+
34
+ def forward(self, input_values):
35
+ outputs = self.wav2vec2(input_values)
36
+ hidden_states = outputs.last_hidden_state
37
+ embeddings = torch.mean(hidden_states, dim=1)
38
+ embeddings = self.projection(embeddings)
39
+ embeddings = F.normalize(embeddings, p=2, dim=1)
40
+ return embeddings
41
+
42
+
43
+ # ============================================================
44
+ # AUDIO PROCESSING
45
+ # ============================================================
46
+
47
+ def process_audio(audio_file, feature_extractor, max_length=16000*3):
48
+ """Process uploaded audio file"""
49
+ try:
50
+ # Save uploaded file temporarily
51
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
52
+ tmp_file.write(audio_file.getvalue())
53
+ tmp_path = tmp_file.name
54
+
55
+ # Load audio
56
+ waveform, sr = sf.read(tmp_path, dtype='float32')
57
+ waveform = torch.from_numpy(waveform)
58
+
59
+ # Convert to mono
60
+ if len(waveform.shape) > 1:
61
+ waveform = torch.mean(waveform, dim=-1)
62
+
63
+ # Resample to 16kHz
64
+ if sr != 16000:
65
+ resampler = torchaudio.transforms.Resample(sr, 16000)
66
+ waveform = resampler(waveform)
67
+
68
+ # Take middle chunk
69
+ if len(waveform) > max_length:
70
+ start = (len(waveform) - max_length) // 2
71
+ waveform = waveform[start:start + max_length]
72
+ elif len(waveform) < max_length:
73
+ padding = max_length - len(waveform)
74
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
75
+
76
+ # Normalize
77
+ if waveform.abs().max() > 0:
78
+ waveform = waveform / waveform.abs().max()
79
+
80
+ # Extract features
81
+ inputs = feature_extractor(
82
+ waveform.numpy(),
83
+ sampling_rate=16000,
84
+ return_tensors="pt"
85
+ )
86
+
87
+ # Cleanup
88
+ os.unlink(tmp_path)
89
+
90
+ return inputs.input_values, waveform.numpy(), sr
91
+
92
+ except Exception as e:
93
+ st.error(f"Error processing audio: {e}")
94
+ return None, None, None
95
+
96
+
97
+ def get_embedding(model, audio_file, feature_extractor, device):
98
+ """Extract embedding from audio file"""
99
+ inputs, waveform, sr = process_audio(audio_file, feature_extractor)
100
+ if inputs is None:
101
+ return None
102
+
103
+ model.eval()
104
+ with torch.no_grad():
105
+ inputs = inputs.to(device)
106
+ embedding = model(inputs)
107
+
108
+ return embedding.cpu().numpy()
109
+
110
+
111
+ # ============================================================
112
+ # ENROLLMENT DATABASE
113
+ # ============================================================
114
+
115
+ class EnrollmentDB:
116
+ def __init__(self, db_path='enrollments.json'):
117
+ self.db_path = db_path
118
+ self.load_db()
119
+
120
+ def load_db(self):
121
+ if os.path.exists(self.db_path):
122
+ with open(self.db_path, 'r') as f:
123
+ data = json.load(f)
124
+ self.enrollments = {k: np.array(v) for k, v in data.items()}
125
+ else:
126
+ self.enrollments = {}
127
+
128
+ def save_db(self):
129
+ data = {k: v.tolist() for k, v in self.enrollments.items()}
130
+ with open(self.db_path, 'w') as f:
131
+ json.dump(data, f)
132
+
133
+ def enroll(self, name, embedding):
134
+ self.enrollments[name] = embedding
135
+ self.save_db()
136
+
137
+ def verify(self, embedding, threshold=0.75):
138
+ """
139
+ Verify against all enrolled users
140
+ Returns: (best_match_name, similarity_score, is_verified)
141
+ """
142
+ if not self.enrollments:
143
+ return None, 0.0, False
144
+
145
+ best_match = None
146
+ best_score = -1.0
147
+
148
+ embedding_tensor = torch.from_numpy(embedding)
149
+
150
+ for name, enrolled_emb in self.enrollments.items():
151
+ enrolled_tensor = torch.from_numpy(enrolled_emb)
152
+ similarity = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item()
153
+
154
+ if similarity > best_score:
155
+ best_score = similarity
156
+ best_match = name
157
+
158
+ is_verified = best_score >= threshold
159
+
160
+ return best_match, best_score, is_verified
161
+
162
+ def get_all_users(self):
163
+ return list(self.enrollments.keys())
164
+
165
+ def remove_user(self, name):
166
+ if name in self.enrollments:
167
+ del self.enrollments[name]
168
+ self.save_db()
169
+ return True
170
+ return False
171
+
172
+
173
+ # ============================================================
174
+ # STREAMLIT APP
175
+ # ============================================================
176
+
177
+ @st.cache_resource
178
+ def load_model():
179
+ """Load model once and cache it"""
180
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
181
+
182
+ model = Wav2Vec2ForSpeakerEmbedding(embedding_size=256).to(device)
183
+ checkpoint = torch.load('best_embedding_model.pth', map_location=device)
184
+ model.load_state_dict(checkpoint['model_state_dict'])
185
+ model.eval()
186
+
187
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
188
+
189
+ return model, feature_extractor, device
190
+
191
+
192
+ def main():
193
+ st.set_page_config(
194
+ page_title="Voice Biometry Demo",
195
+ page_icon="🎀",
196
+ layout="wide"
197
+ )
198
+
199
+ # Custom CSS
200
+ st.markdown("""
201
+ <style>
202
+ .big-font {
203
+ font-size:20px !important;
204
+ font-weight: bold;
205
+ }
206
+ .success-box {
207
+ padding: 20px;
208
+ border-radius: 10px;
209
+ background-color: #d4edda;
210
+ border: 2px solid #28a745;
211
+ color: #155724;
212
+ }
213
+ .failure-box {
214
+ padding: 20px;
215
+ border-radius: 10px;
216
+ background-color: #f8d7da;
217
+ border: 2px solid #dc3545;
218
+ color: #721c24;
219
+ }
220
+ .info-box {
221
+ padding: 20px;
222
+ border-radius: 10px;
223
+ background-color: #d1ecf1;
224
+ border: 2px solid #17a2b8;
225
+ color: #0c5460;
226
+ }
227
+ </style>
228
+ """, unsafe_allow_html=True)
229
+
230
+ # Header
231
+ st.title("Voice Biometry System - Proof of Concept")
232
+ st.markdown("### Finetuned Wav2Vec 2.0")
233
+
234
+ # Load model
235
+ with st.spinner("Loading model..."):
236
+ model, feature_extractor, device = load_model()
237
+
238
+ # Initialize database
239
+ db = EnrollmentDB()
240
+
241
+ # Sidebar - Configuration
242
+ st.sidebar.header("βš™οΈ Configuration")
243
+ threshold = st.sidebar.slider(
244
+ "Verification Threshold",
245
+ min_value=0.5,
246
+ max_value=0.95,
247
+ value=0.75,
248
+ step=0.05,
249
+ help="Higher = more strict verification"
250
+ )
251
+
252
+ st.sidebar.markdown("---")
253
+ st.sidebar.header("πŸ“Š System Stats")
254
+ st.sidebar.metric("Enrolled Users", len(db.get_all_users()))
255
+ st.sidebar.metric("Model Accuracy", "76%")
256
+ st.sidebar.metric("AUC Score", "0.82")
257
+
258
+ # Enrolled users list
259
+ if db.get_all_users():
260
+ st.sidebar.markdown("---")
261
+ st.sidebar.header("πŸ‘₯ Enrolled Users")
262
+ for user in db.get_all_users():
263
+ col1, col2 = st.sidebar.columns([3, 1])
264
+ col1.write(f"β€’ {user}")
265
+ if col2.button("πŸ—‘οΈ", key=f"del_{user}"):
266
+ db.remove_user(user)
267
+ st.rerun()
268
+
269
+ # Main tabs
270
+ tab1, tab2, tab3 = st.tabs(["πŸ“ Enrollment", "βœ… Verification", "ℹ️ About"])
271
+
272
+ # ============================================================
273
+ # TAB 1: ENROLLMENT
274
+ # ============================================================
275
+ with tab1:
276
+ st.header("Enroll a New User")
277
+ st.markdown("Upload a voice recording to register a new user in the system.")
278
+
279
+ col1, col2 = st.columns([2, 1])
280
+
281
+ with col1:
282
+ enroll_name = st.text_input(
283
+ "User Name",
284
+ placeholder="Enter name (e.g., Abdou Diop)",
285
+ help="This name will be used to identify the speaker"
286
+ )
287
+
288
+ enroll_audio = st.file_uploader(
289
+ "Upload Voice Recording",
290
+ type=['wav', 'mp3', 'flac', 'ogg'],
291
+ help="Upload a clear voice recording (3-20 seconds recommended)",
292
+ key="enroll"
293
+ )
294
+
295
+ with col2:
296
+ st.info("""
297
+ **Enrollment Tips:**
298
+ - Use clear audio
299
+ - 3-20 seconds long
300
+ - Minimal background noise
301
+ - Normal speaking voice
302
+ """)
303
+
304
+ if st.button("🎯 Enroll User", type="primary", disabled=(not enroll_name or not enroll_audio)):
305
+ with st.spinner(f"Processing enrollment for {enroll_name}..."):
306
+ # Check if user already exists
307
+ if enroll_name in db.get_all_users():
308
+ st.warning(f"⚠️ User '{enroll_name}' already exists. Please use a different name or remove the existing user first.")
309
+ else:
310
+ # Get embedding
311
+ embedding = get_embedding(model, enroll_audio, feature_extractor, device)
312
+
313
+ if embedding is not None:
314
+ # Save enrollment
315
+ db.enroll(enroll_name, embedding)
316
+
317
+ st.markdown(f"""
318
+ <div class="success-box">
319
+ <h3>βœ… Enrollment Successful!</h3>
320
+ <p><strong>{enroll_name}</strong> has been enrolled in the system.</p>
321
+ <p>Total enrolled users: {len(db.get_all_users())}</p>
322
+ </div>
323
+ """, unsafe_allow_html=True)
324
+
325
+ #st.balloons()
326
+ else:
327
+ st.error("❌ Failed to process audio. Please try again with a different recording.")
328
+
329
+ # ============================================================
330
+ # TAB 2: VERIFICATION
331
+ # ============================================================
332
+ with tab2:
333
+ st.header("Verify User Identity")
334
+ st.markdown("Upload a voice recording to verify against enrolled users.")
335
+
336
+ if not db.get_all_users():
337
+ st.warning("⚠️ No users enrolled yet. Please enroll at least one user first.")
338
+ else:
339
+ col1, col2 = st.columns([2, 1])
340
+
341
+ with col1:
342
+ verify_audio = st.file_uploader(
343
+ "Upload Voice Recording for Verification",
344
+ type=['wav', 'mp3', 'flac', 'ogg'],
345
+ help="Upload a voice recording from a speaker you want to verify",
346
+ key="verify"
347
+ )
348
+
349
+ with col2:
350
+ st.info(f"""
351
+ **Verification Info:**
352
+ - {len(db.get_all_users())} users enrolled
353
+ - Threshold: {threshold:.2f}
354
+ - Model: Wav2Vec 2.0
355
+ """)
356
+
357
+ if st.button("πŸ” Verify Identity", type="primary", disabled=(not verify_audio)):
358
+ with st.spinner("Analyzing voice..."):
359
+ # Get embedding
360
+ embedding = get_embedding(model, verify_audio, feature_extractor, device)
361
+
362
+ if embedding is not None:
363
+ # Verify
364
+ match_name, similarity, is_verified = db.verify(embedding, threshold)
365
+
366
+ # Display results
367
+ st.markdown("---")
368
+
369
+ if is_verified:
370
+ st.markdown(f"""
371
+ <div class="success-box">
372
+ <h2>βœ… VERIFICATION SUCCESSFUL</h2>
373
+ <h3>Identified as: {match_name}</h3>
374
+ <p style="font-size: 18px;">Confidence Score: <strong>{similarity:.1%}</strong></p>
375
+ </div>
376
+ """, unsafe_allow_html=True)
377
+
378
+ st.success(f"πŸŽ‰ Welcome back, {match_name}!")
379
+
380
+ else:
381
+ st.markdown(f"""
382
+ <div class="failure-box">
383
+ <h2>❌ VERIFICATION FAILED</h2>
384
+ <p>Closest match: <strong>{match_name}</strong></p>
385
+ <p>Similarity: <strong>{similarity:.1%}</strong></p>
386
+ <p>Threshold required: <strong>{threshold:.1%}</strong></p>
387
+ <p><em>This speaker is not recognized in the system.</em></p>
388
+ </div>
389
+ """, unsafe_allow_html=True)
390
+
391
+ # Show all scores
392
+ with st.expander("πŸ“Š See detailed scores for all enrolled users"):
393
+ st.markdown("### Similarity Scores")
394
+
395
+ scores = []
396
+ embedding_tensor = torch.from_numpy(embedding)
397
+
398
+ for name, enrolled_emb in db.enrollments.items():
399
+ enrolled_tensor = torch.from_numpy(enrolled_emb)
400
+ sim = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item()
401
+ scores.append({
402
+ 'User': name,
403
+ 'Similarity': f"{sim:.1%}",
404
+ 'Status': 'βœ… Match' if sim >= threshold else '❌ No match'
405
+ })
406
+
407
+ # Sort by similarity
408
+ scores.sort(key=lambda x: x['Similarity'], reverse=True)
409
+
410
+ import pandas as pd
411
+ df = pd.DataFrame(scores)
412
+ st.dataframe(df, use_container_width=True, hide_index=True)
413
+
414
+ else:
415
+ st.error("❌ Failed to process audio. Please try again with a different recording.")
416
+
417
+ # ============================================================
418
+ # TAB 3: ABOUT
419
+ # ============================================================
420
+ with tab3:
421
+ st.header("About This System")
422
+
423
+ col1, col2 = st.columns(2)
424
+
425
+ with col1:
426
+ st.markdown("""
427
+ ### 🎯 Technology
428
+
429
+ **Model Architecture:**
430
+ - Base: Wav2Vec 2.0 (Facebook AI)
431
+ - Finetuned on 247 speakers
432
+ - 1035 voice samples (telephone quality, 8kHz)
433
+ - Embedding dimension: 256
434
+
435
+ **Training Details:**
436
+ - Loss: Supervised Contrastive Learning
437
+ - Framework: PyTorch + Transformers
438
+ - Training time: ~50 epochs
439
+ - Hardware: NVIDIA RTX 3050
440
+ """)
441
+
442
+ with col2:
443
+ st.markdown("""
444
+ ### πŸ“Š Performance Metrics
445
+
446
+ **Evaluation Results:**
447
+ - **Accuracy:** 76%
448
+ - **AUC Score:** 0.82
449
+ - **True Positive Rate:** 79%
450
+ - **False Positive Rate:** 27%
451
+
452
+ **Test Set:**
453
+ - 1000 verification pairs
454
+ - 500 same-speaker pairs
455
+ - 500 different-speaker pairs
456
+ """)
457
+
458
+ st.markdown("---")
459
+
460
+ st.markdown("""
461
+ ### πŸ”§ How It Works
462
+
463
+ 1. **Enrollment Phase:**
464
+ - User uploads voice recording
465
+ - System extracts 256-dimensional embedding
466
+ - Embedding stored in database with user name
467
+
468
+ 2. **Verification Phase:**
469
+ - Unknown voice recording uploaded
470
+ - System extracts embedding
471
+ - Computes cosine similarity with all enrolled users
472
+ - Returns match if similarity exceeds threshold
473
+
474
+ 3. **Matching Algorithm:**
475
+ - Cosine similarity between embeddings
476
+ - Range: -1 (opposite) to +1 (identical)
477
+ - Typical same-speaker: 0.75-0.95
478
+ - Typical different-speaker: 0.30-0.70
479
+ """)
480
+
481
+ st.markdown("---")
482
+
483
+ st.info("""
484
+ **Note:** This is a proof of concept system. For production deployment, consider:
485
+ - Larger training dataset (10-20 samples per speaker)
486
+ - Better base model (WavLM for noisy conditions)
487
+ - Anti-spoofing measures
488
+ - Liveness detection
489
+ - Multi-enrollment (average multiple recordings per user)
490
+ """)
491
+
492
+
493
+ if __name__ == "__main__":
494
+ main()
best_embedding_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3312a4527b3bea45dc377a9a8dacf0f8421e8a8597947338b49140c0bc2e35e4
3
+ size 379678794
requirements.txt ADDED
Binary file (176 Bytes). View file