Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
# Size of a float in bytes (per the original note "size of a float").
# NOTE(review): not referenced anywhere in the visible part of this file.
sz_float = 4

# Fudge factor for normalization. Note that the literal 10e-8 equals 1e-7,
# not 1e-8. NOTE(review): not referenced in the visible part of this file.
epsilon = 10e-8
class AugmentMelSTFT(nn.Module):
    """Log-mel spectrogram front end with optional training-time augmentation.

    The module optionally applies a pre-emphasis filter, computes one power
    STFT per configured ``(n_fft, win_length)`` pair, projects each onto a
    Kaldi-style mel filterbank, takes the log, and stacks the results along
    a new dimension 1. In training mode the mel band edges are randomly
    jittered and frequency/time masking is applied.

    Adapted from:
    https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
    Similar config to the spectrograms used in AST:
    https://github.com/YuanGongND/ast

    Args:
        n_mels: Number of mel bins requested from the filterbank.
        sr: Sample rate handed to the filterbank construction.
        win_length: STFT window length, or a list/tuple of lengths for
            several resolutions (``n_fft`` must then be a list/tuple of the
            same length). Defaults to ``n_fft`` when ``None``.
        hopsize: STFT hop length.
        n_fft: FFT size, or a list/tuple of sizes (see ``win_length``).
        freqm: Mask parameter for ``FrequencyMasking``; ``0`` disables it.
        timem: Mask parameter for ``TimeMasking``; ``0`` disables it.
        htk: Stored on the instance; not used by the code in this class.
        fmin: Lower frequency edge of the mel filterbank.
        fmax: Upper frequency edge of the mel filterbank. When ``None`` it
            is set to ``sr // 2 - fmax_aug_range // 2``.
        norm: Stored on the instance; not used by the code in this class.
        fmin_aug_range: Training only: ``fmin`` is raised by a random
            integer in ``[0, fmin_aug_range)``. Must be ``>= 1``; ``1``
            means no augmentation.
        fmax_aug_range: Training only: ``fmax`` is shifted by
            ``fmax_aug_range // 2 - r`` with ``r`` a random integer in
            ``[0, fmax_aug_range)``. Must be ``>= 1``; ``1`` means no
            augmentation.
        fast_norm: If true, the output is rescaled as ``(mels + 4.5) / 5``.
        preamp: If true, a two-tap pre-emphasis filter ``[-0.97, 1]`` is
            applied to the waveform first.
        padding: ``"center"`` lets ``torch.stft`` center the frames;
            ``"same"`` reflect-pads the signal manually and disables
            centering.
        periodic_window: Forwarded to ``torch.hann_window(periodic=...)``.

    Raises:
        AssertionError: If the list forms of ``win_length`` / ``n_fft`` are
            inconsistent, or an augmentation range is smaller than 1.
        ValueError: If ``padding`` is neither ``"center"`` nor ``"same"``.
    """

    def __init__(
        self,
        n_mels=128,
        sr=32000,
        win_length=None,
        hopsize=320,
        n_fft=1024,
        freqm=0,
        timem=0,
        htk=False,
        fmin=0.0,
        fmax=None,
        norm=1,
        fmin_aug_range=1,
        fmax_aug_range=1,
        fast_norm=False,
        preamp=True,
        padding="center",
        periodic_window=True,
    ):
        super().__init__()

        if win_length is None:
            win_length = n_fft

        # Normalize to parallel sequences so forward() can always zip them.
        if isinstance(win_length, (list, tuple)):
            assert isinstance(n_fft, (list, tuple))
            assert len(win_length) == len(n_fft)
        else:
            win_length = [win_length]
            n_fft = [n_fft]

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.htk = htk
        self.fmin = fmin
        if fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
        self.fmax = fmax
        self.norm = norm
        self.hopsize = hopsize
        self.preamp = preamp

        # One Hann window per resolution, looked up by name in forward().
        # Non-persistent: they are not written to the state dict.
        for win_l in self.win_length:
            self.register_buffer(
                f"window_{win_l}",
                torch.hann_window(win_l, periodic=periodic_window),
                persistent=False,
            )

        assert (
            fmin_aug_range >= 1
        ), f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        # Fixed: this check used to test fmin_aug_range a second time, so an
        # invalid fmax_aug_range only failed later inside torch.randint.
        assert (
            fmax_aug_range >= 1
        ), f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # conv1d kernel of shape (out_channels=1, in_channels=1, width=2).
        self.register_buffer(
            "preemphasis_coefficient", torch.as_tensor([[[-0.97, 1]]]), persistent=False
        )

        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=False)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=False)

        self.fast_norm = fast_norm
        self.padding = padding
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        # No-op submodule called around the manual padding in forward().
        self.iden = nn.Identity()

    def forward(self, x):
        """Convert waveforms into stacked log-mel spectrograms.

        Args:
            x: Waveform tensor. With ``preamp=True`` the ``unsqueeze(1)`` /
                ``conv1d`` pair requires a 2-D ``(batch, samples)`` input.

        Returns:
            Tensor of log-mel spectrograms stacked along dimension 1, one
            entry per configured ``(n_fft, win_length)`` pair.
        """
        if self.preamp:
            # Width-2 kernel without padding: output is one sample shorter.
            x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient)
            x = x.squeeze(1)

        # The random draws happen in both modes (as in the original code), so
        # the global RNG advances even when the values are discarded below.
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = (
            self.fmax
            + self.fmax_aug_range // 2
            - torch.randint(self.fmax_aug_range, (1,)).item()
        )
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mels = []
        for n_fft, win_length in zip(self.n_fft, self.win_length):
            x_temp = x
            if self.padding == "same":
                pad = win_length - self.hopsize
                # Identity calls marked "# printing" in the original;
                # presumably hook points for inspecting the tensor.
                self.iden(x_temp)
                x_temp = torch.nn.functional.pad(
                    x_temp, (pad // 2, pad // 2), mode="reflect"
                )
                self.iden(x_temp)

            x_temp = torch.stft(
                x_temp,
                n_fft,
                hop_length=self.hopsize,
                win_length=win_length,
                center=self.padding == "center",
                normalized=False,
                window=getattr(self, f"window_{win_length}"),
                return_complex=True,
            )
            # |X|^2 from the (real, imag) view of the complex STFT.
            x_temp = torch.view_as_real(x_temp)
            x_temp = (x_temp ** 2).sum(dim=-1)  # power mag

            # Rebuilt on every call because fmin/fmax may change per call.
            mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(
                self.n_mels,
                n_fft,
                self.sr,
                fmin,
                fmax,
                vtln_low=100.0,
                vtln_high=-500.0,
                vtln_warp_factor=1.0,
            )
            # Append one zero column on the last (frequency) axis before
            # multiplying with the STFT power spectrum.
            mel_basis = torch.as_tensor(
                torch.nn.functional.pad(mel_basis, (0, 1), mode="constant", value=0),
                device=x.device,
            )
            # The mel projection runs with CUDA autocast disabled.
            with torch.cuda.amp.autocast(enabled=False):
                x_temp = torch.matmul(mel_basis, x_temp)

            # Clamp before the log to avoid log(0).
            x_temp = torch.log(torch.clip(x_temp, min=1e-7))
            mels.append(x_temp)

        mels = torch.stack(mels, dim=1)

        if self.training:
            mels = self.freqm(mels)
            mels = self.timem(mels)
        if self.fast_norm:
            mels = (mels + 4.5) / 5.0  # fast normalization
        return mels

    def extra_repr(self):
        """Return the window and hop sizes for the module's printed repr."""
        return f"winsize={self.win_length}, hopsize={self.hopsize}"