import librosa
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from typing import List

def _mel_filters(n_mels: int) -> torch.Tensor:
    """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=n_mels))
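
# For reference: with n_fft=400 there are 400 // 2 + 1 = 201 frequency bins,
# so _mel_filters(128) returns a (128, 201) matrix.
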
def load_audio(file_path, target_rate=16000, max_length=None):
    """
    Open an audio file and read it as a mono waveform, resampling if necessary.
    If max_length is provided, truncate the audio to that many samples.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
    audio = waveform[0]  # keep only the first channel
    # Truncate the audio if it exceeds max_length
    if max_length is not None and audio.shape[0] > max_length:
        audio = audio[:max_length]
    return audio
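
# Usage sketch ("speech.wav" is a hypothetical path):
#
#   audio = load_audio("speech.wav", max_length=16000 * 30)  # mono, 16 kHz, capped at 30 s
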
def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
    """
    Compute the log-Mel spectrogram with the specific padding used by StepAudio.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)  # already returns a tensor
        else:
            audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    # 25 ms Hann windows (n_fft=400) with a 10 ms hop (160 samples) at 16 kHz
    window = torch.hann_window(400).to(audio.device)
    stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2  # drop the final frame
    filters = _mel_filters(n_mels).to(audio.device)
    mel_spec = filters @ magnitudes
    # Whisper-style normalization: log10, floor 8 below the peak, rescale
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
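
# Shape sketch: for a 30 s clip (480000 samples at 16 kHz), padding=479 gives
# (480000 + 479) // 160 = 3002 frames once the final STFT frame is dropped,
# so the result is a (128, 3002) tensor. The 479-sample pad accounts for the
# ~2 trailing frames that compute_token_num and padding_mels subtract below.
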
def compute_token_num(max_feature_len):
    # First, audio goes through the encoder:
    # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
    # 2. conv2: kernel=3, stride=2, padding=1 -> size / 2
    # 3. avg_pooler: kernel=2, stride=2 -> size / 2
    max_feature_len = max_feature_len - 2  # remove padding
    encoder_output_dim = (max_feature_len + 1) // 2 // 2  # after conv2 and avg_pooler
    # Then through the adaptor (parameters from the config file):
    padding = 1
    kernel_size = 3  # from config: audio_encoder_config.kernel_size
    stride = 2  # from config: audio_encoder_config.adapter_stride
    adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
    return adapter_output_dim
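
# Worked example for the 30 s clip above (3002 mel frames):
#   3002 - 2 = 3000 after removing padding
#   encoder_output_dim = (3000 + 1) // 2 // 2 = 750
#   adapter_output_dim = (750 + 2 - 3) // 2 + 1 = 375 tokens
# i.e. roughly 12.5 audio tokens per second.
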
def padding_mels(data: List[torch.Tensor]):
    """Pad a list of mel features into a single batch.

    Parameters
    ----------
    data: List[Tensor], each of shape (128, T)

    Returns
    -------
    padded feats of shape (B, 128, T_max) and per-item feature lengths
    """
    assert isinstance(data, list)
    # Lengths exclude the ~2 trailing frames introduced by the 479-sample pad
    feats_lengths = torch.tensor([s.size(1) - 2 for s in data], dtype=torch.int32)
    feats = [s.t() for s in data]  # (T, 128) so pad_sequence can pad over time
    padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
    return padded_feats.transpose(1, 2), feats_lengths
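
if __name__ == "__main__":
    # Minimal smoke test (illustrative; uses synthetic audio so no files are
    # needed): batch two clips of different lengths and inspect the shapes.
    clips = [torch.randn(16000), torch.randn(32000)]  # 1 s and 2 s at 16 kHz
    mels = [log_mel_spectrogram(w, n_mels=128) for w in clips]
    feats, feat_lens = padding_mels(mels)
    print(feats.shape)   # torch.Size([2, 128, 202])
    print(feat_lens)     # tensor([100, 200], dtype=torch.int32)
    print([compute_token_num(m.size(1)) for m in mels])  # [13, 25]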