import torch
import librosa
import numpy as np
import os
import webrtcvad
import wave
import contextlib
import gradio as gr

from utils.VAD_segments import *
from utils.hparam import hparam as hp
from utils.speech_embedder_net import *
from utils.evaluation import *
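
# The helpers below assume the (sample_rate, numpy_array) tuples produced by
# gr.Audio(type="numpy"), e.g. (48000, int16 samples). webrtcvad only accepts
# 16-bit mono PCM at 8, 16, 32 or 48 kHz, so read_wave keeps a float copy of the
# audio for feature extraction and a PCM byte copy for voice activity detection.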
def read_wave(audio_data):
    """Reads a (sample_rate, numpy_array) tuple and returns (float audio, PCM byte data).

    If the sample rate is not supported by webrtcvad, the audio is resampled to 16000 Hz.
    """
    sample_rate, data = audio_data
    # Ensure the audio is mono (1D)
    assert len(data.shape) == 1, "Audio data must be a 1D array"
    # Convert integer PCM to floating point in [-1, 1] if necessary
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    # Sample rates accepted by webrtcvad
    supported_sample_rates = (8000, 16000, 32000, 48000)
    # If the sample rate is not supported, resample to 16000 Hz
    if sample_rate not in supported_sample_rates:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000
    # Convert the float audio back to 16-bit PCM bytes for webrtcvad
    pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()
    return data, pcm_data
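
# VAD_chunk runs webrtcvad over 20 ms frames of the PCM audio and slices each detected
# speech region into windows of at most 0.4 s. Note the while/else: the else branch runs
# once the loop condition fails and appends the final, shorter window of the region.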
def VAD_chunk(aggressiveness, data):
    audio, byte_audio = read_wave(data)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j * hp.data.sr):int(end_j * hp.data.sr)])
            j = end_j
        else:
            speech_times.append((j, end))
            speech_segs.append(audio[int(j * hp.data.sr):int(end * hp.data.sr)])
    return speech_times, speech_segs
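
# get_embedding ties the pipeline together: VAD segmentation, merging of adjacent
# segments (concat_segs), STFT-based feature extraction (get_STFTs), a forward pass
# through the embedder network, and averaging of the per-window embeddings into one
# utterance-level speaker embedding (d-vector).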
def get_embedding(data, embedder_net, device, n_threshold=-1):
    times, segs = VAD_chunk(0, data)
    if not segs:
        print('No voice activity detected')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print('No concatenated segments')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        # print('No STFT frames')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
        embeddings = embeddings[:n_threshold, :]
    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding
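
# Load the GE2E-trained SpeechEmbedder weights from the checkpoint shipped with the
# Space; the model is kept on CPU and switched to eval mode for inference.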
model_path = "./speech_id_checkpoint/saved_02.model"

embedder_net = SpeechEmbedder()
# map_location keeps the checkpoint loadable on a CPU-only machine
embedder_net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
embedder_net.eval()
def process_audio(audio1, audio2, threshold):
    e1 = get_embedding(audio1, embedder_net, torch.device("cpu"))
    if e1 is None:
        return "No Voice Detected in file 1"
    e2 = get_embedding(audio2, embedder_net, torch.device("cpu"))
    if e2 is None:
        return "No Voice Detected in file 2"
    cosi = cosine_similarity(e1, e2)
    if cosi > threshold:
        return "Same Speaker"
    else:
        return "Different Speaker"
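
# Two recordings are declared to come from the same speaker when the cosine similarity
# of their average embeddings exceeds the chosen threshold, e.g. with the default
# threshold of 0.85 a similarity of 0.91 yields "Same Speaker" and 0.60 yields
# "Different Speaker".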
# Define the Gradio interface
def gradio_interface(audio1, audio2, threshold):
    output_text = process_audio(audio1, audio2, threshold)
    return output_text
description = """
<p>
<center>
This is an LSTM-based Speaker Embedding Model trained using <a href="https://arxiv.org/abs/1710.10467">GE2E loss</a> on the <a href="https://openslr.org/78/">Gujarati OpenSLR dataset</a>.
<img src="https://huggingface.co/spaces/1rsh/gujarati-tisv/resolve/main/img/gujarati-text.png" alt="Gujarati" width="250"/>
</center>
</p>
"""
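
# The interface records two clips from the microphone; gr.Audio with type="numpy"
# passes each to process_audio as a (sample_rate, numpy_array) tuple, and the slider
# supplies the decision threshold.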
# Create the Gradio interface with microphone inputs
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 1"),
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 2"),
        gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Threshold"),
    ],
    outputs="text",
    title="ગુજરાતી Text Independent Speaker Verification",
    description=description,
)
# Launch the interface
iface.launch(share=False)