import torch
import torchaudio
from typing import Callable, List
import torch.nn.functional as F
import warnings

languages = ['ru', 'en', 'de', 'es']

class OnnxWrapper():

    def __init__(self, path, force_onnx_cpu=False):
        # lazy imports so that numpy/onnxruntime are only required for the ONNX path;
        # the global declaration must come before the import to be valid syntax
        global np
        import numpy as np
        import onnxruntime

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiple of 16000)")

        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._h = np.zeros((2, batch_size, 64)).astype('float32')
        self._c = np.zeros((2, batch_size, 64)).astype('float32')
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        x, sr = self._validate_input(x, sr)
        batch_size = x.shape[0]

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, self._h, self._c = ort_outs
        else:
            raise ValueError(f"Unsupported sampling rate: {sr}, supported rates: {self.sample_rates}")

        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.tensor(out)
        return out

    def audio_forward(self, x, sr: int, num_samples: int = 512):
        outs = []
        x, sr = self._validate_input(x, sr)

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        self.reset_states(x.shape[0])
        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i + num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()
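
# Usage sketch (not part of the original file): running the ONNX wrapper over a
# whole file with audio_forward. 'silero_vad.onnx' and 'example.wav' are
# placeholder names, not files shipped with this module.
#
# vad = OnnxWrapper('silero_vad.onnx', force_onnx_cpu=True)
# wav = read_audio('example.wav', sampling_rate=16000)
# probs = vad.audio_forward(wav, sr=16000, num_samples=512)   # shape: (1, num_windows)
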
class Validator():
    def __init__(self, url, force_onnx_cpu):
        self.onnx = url.endswith('.onnx')
        torch.hub.download_url_to_file(url, 'inf.model')
        if self.onnx:
            import onnxruntime
            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
            else:
                self.model = onnxruntime.InferenceSession('inf.model')
        else:
            self.model = init_jit_model(model_path='inf.model')

    def __call__(self, inputs: torch.Tensor):
        with torch.no_grad():
            if self.onnx:
                ort_inputs = {'input': inputs.cpu().numpy()}
                outs = self.model.run(None, ort_inputs)
                outs = [torch.Tensor(x) for x in outs]
            else:
                outs = self.model(inputs)

        return outs

def read_audio(path: str,
               sampling_rate: int = 16000):

    wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sampling_rate:
        transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                   new_freq=sampling_rate)
        wav = transform(wav)
        sr = sampling_rate

    assert sr == sampling_rate
    return wav.squeeze(0)


def save_audio(path: str,
               tensor: torch.Tensor,
               sampling_rate: int = 16000):
    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
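
# Usage sketch (not part of the original file): resample-on-read and write back
# as 16-bit PCM. 'input.wav' / 'output.wav' are placeholder file names.
#
# wav = read_audio('input.wav', sampling_rate=16000)    # mono float tensor at 16 kHz
# save_audio('output.wav', wav, sampling_rate=16000)
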
def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model

def make_visualization(probs, step):
    import pandas as pd
    pd.DataFrame({'probs': probs},
                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
                                                                   kind='area',
                                                                   ylim=[0, 1.05],
                                                                   xlim=[0, len(probs) * step],
                                                                   xlabel='seconds',
                                                                   ylabel='speech probability',
                                                                   colormap='tab20')

def get_speech_timestamps(audio: torch.Tensor,
                          model,
                          threshold: float = 0.5,
                          sampling_rate: int = 16000,
                          min_speech_duration_ms: int = 250,
                          max_speech_duration_s: float = float('inf'),
                          min_silence_duration_ms: int = 100,
                          window_size_samples: int = 512,
                          speech_pad_ms: int = 30,
                          return_seconds: bool = False,
                          visualize_probs: bool = False,
                          progress_tracking_callback: Callable[[float], None] = None):

    """
    This method is used for splitting long audios into speech chunks using silero VAD

    Parameters
    ----------
    audio: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, other types are cast to torch.Tensor if possible

    model: preloaded .jit silero VAD model

    threshold: float (default - 0.5)
        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    sampling_rate: int (default - 16000)
        Currently silero VAD models support 8000 and 16000 sample rates

    min_speech_duration_ms: int (default - 250 milliseconds)
        Final speech chunks shorter than min_speech_duration_ms are thrown out

    max_speech_duration_s: float (default - inf)
        Maximum duration of speech chunks in seconds
        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100 ms (if any), to prevent aggressive cutting.
        Otherwise, they will be split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: int (default - 100 milliseconds)
        At the end of each speech chunk wait for min_silence_duration_ms before separating it

    window_size_samples: int (default - 512 samples)
        Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
        Values other than these may affect model performance!!

    speech_pad_ms: int (default - 30 milliseconds)
        Final speech chunks are padded by speech_pad_ms on each side

    return_seconds: bool (default - False)
        whether to return timestamps in seconds (default - samples)

    visualize_probs: bool (default - False)
        whether to draw the speech probability plot or not

    progress_tracking_callback: Callable[[float], None] (default - None)
        callback function taking progress in percents as an argument

    Returns
    ----------
    speeches: list of dicts
        list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
    """

    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except:
            raise TypeError("Audio cannot be cast to tensor. Cast it manually")

    if len(audio.shape) > 1:
        for i in range(len(audio.shape)):  # trying to squeeze empty dimensions
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
    else:
        step = 1

    if sampling_rate == 8000 and window_size_samples > 768:
        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
    if window_size_samples not in [256, 512, 768, 1024, 1536]:
        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')

    model.reset_states()
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob = model(chunk, sampling_rate).item()
        speech_probs.append(speech_prob)
        # calculate progress and send it to the callback function
        progress = current_start_sample + window_size_samples
        if progress > audio_length_samples:
            progress = audio_length_samples
        progress_percent = (progress / audio_length_samples) * 100
        if progress_tracking_callback:
            progress_tracking_callback(progress_percent)

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15
    temp_end = 0  # to save potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = window_size_samples * i
            continue

        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
            if prev_end:
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                    triggered = False
                else:
                    current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech['end'] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech['end'] = temp_end
                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i+1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        for speech_dict in speeches:
            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
    elif step > 1:
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step

    if visualize_probs:
        make_visualization(speech_probs, window_size_samples / sampling_rate)

    return speeches
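
# Usage sketch (not part of the original file): a minimal offline VAD pipeline.
# 'silero_vad.jit' and 'example.wav' are placeholder file names, not files
# shipped with this module.
#
# model = init_jit_model('silero_vad.jit')
# wav = read_audio('example.wav', sampling_rate=16000)
# speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
# print(speech_timestamps)   # e.g. [{'start': 8000, 'end': 32000}, ...]
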

def get_number_ts(wav: torch.Tensor,
                  model,
                  model_stride=8,
                  hop_length=160,
                  sample_rate=16000):
    wav = torch.unsqueeze(wav, dim=0)
    perframe_logits = model(wav)[0]
    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
    extended_preds = []
    for i in perframe_preds:
        extended_preds.extend([i.item()] * model_stride)
    # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
    triggered = False
    timings = []
    cur_timing = {}
    for i, pred in enumerate(extended_preds):
        if pred == 1:
            if not triggered:
                cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
                triggered = True
        elif pred == 0:
            if triggered:
                cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
                timings.append(cur_timing)
                cur_timing = {}
                triggered = False
    if cur_timing:
        # use the number of samples in the waveform, not len(wav), which after unsqueeze is the batch size
        cur_timing['end'] = int(wav.shape[-1] / (sample_rate / 1000))
        timings.append(cur_timing)
    return timings

def get_language(wav: torch.Tensor,
                 model):
    wav = torch.unsqueeze(wav, dim=0)
    lang_logits = model(wav)[2]
    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
    assert lang_pred < len(languages)
    return languages[lang_pred]


def get_language_and_group(wav: torch.Tensor,
                           model,
                           lang_dict: dict,
                           lang_group_dict: dict,
                           top_n=1):
    wav = torch.unsqueeze(wav, dim=0)
    lang_logits, lang_group_logits = model(wav)

    softm = torch.softmax(lang_logits, dim=1).squeeze()
    softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()

    srtd = torch.argsort(softm, descending=True)
    srtd_group = torch.argsort(softm_group, descending=True)

    outs = []
    outs_group = []
    for i in range(top_n):
        prob = round(softm[srtd[i]].item(), 2)
        prob_group = round(softm_group[srtd_group[i]].item(), 2)
        outs.append((lang_dict[str(srtd[i].item())], prob))
        outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))

    return outs, outs_group

class VADIterator:
    def __init__(self,
                 model,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30
                 ):

        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 100 milliseconds)
            At the end of each speech chunk wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """

        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')

        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):
        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    def __call__(self, x, return_seconds=False):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)
        """

        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = self.current_sample - self.speech_pad_samples
            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples
                self.temp_end = 0
                self.triggered = False
                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}

        return None
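
# Usage sketch (not part of the original file): pseudo-streaming with VADIterator,
# feeding fixed-size windows from a pre-loaded file. The file names are placeholders;
# 512 samples per window follows the 16 kHz recommendation in get_speech_timestamps.
#
# model = init_jit_model('silero_vad.jit')
# vad_iterator = VADIterator(model, sampling_rate=16000)
# wav = read_audio('example.wav', sampling_rate=16000)
# for i in range(0, len(wav), 512):
#     chunk = wav[i: i + 512]
#     if len(chunk) < 512:
#         break
#     speech_dict = vad_iterator(chunk, return_seconds=True)
#     if speech_dict:
#         print(speech_dict)          # {'start': ...} or {'end': ...}
# vad_iterator.reset_states()         # reset model states between separate files
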
def collect_chunks(tss: List[dict],
                   wav: torch.Tensor):
    chunks = []
    for i in tss:
        chunks.append(wav[i['start']: i['end']])
    return torch.cat(chunks)


def drop_chunks(tss: List[dict],
                wav: torch.Tensor):
    chunks = []
    cur_start = 0
    for i in tss:
        chunks.append((wav[cur_start: i['start']]))
        cur_start = i['end']
    return torch.cat(chunks)
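
# Usage sketch (not part of the original file): keeping only the detected speech.
# Assumes sample-based timestamps (return_seconds=False) and the same placeholder
# file names as the earlier sketches.
#
# model = init_jit_model('silero_vad.jit')
# wav = read_audio('example.wav', sampling_rate=16000)
# speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
# speech_only = collect_chunks(speech_timestamps, wav)
# save_audio('only_speech.wav', speech_only, sampling_rate=16000)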