import random

import numpy as np
import torch


def frame_shift(mels, labels, net_pooling=4):
    """Circularly shift every batch item along its last axis by a random offset.

    For each item, an integer offset is drawn from a Gaussian (mean 0,
    standard deviation 90) using the ``random`` module and applied to the
    features with ``torch.roll``. The matching labels are rolled by that
    offset divided by ``net_pooling``, truncated toward zero, so a shift of
    ``+k`` and a shift of ``-k`` move the labels by the same magnitude.

    Args:
        mels: 3-D tensor; the first axis is the batch and the last axis is
            the one being shifted.
        labels: tensor indexed by batch on its first axis; each item is
            rolled along its last axis.
        net_pooling: int, divisor converting the feature offset into the
            label offset (presumably the time-pooling factor of the network
            producing the label resolution -- confirm against the caller).

    Returns:
        Tuple ``(shifted_mels, shifted_labels)`` of stacked tensors.
    """
    # The three-way unpack also rejects inputs that are not 3-D.
    bsz, _, _ = mels.shape
    shifted = []
    new_labels = []
    for bindx in range(bsz):
        shift = int(random.gauss(0, 90))
        shifted.append(torch.roll(mels[bindx], shift, dims=-1))
        # Divide the magnitude first and re-apply the sign afterwards:
        # `-abs(shift) // net_pooling` parses as `(-abs(shift)) // net_pooling`
        # and floors away from zero, which made negative shifts move the
        # labels one frame further than the equivalent positive shift.
        label_shift = abs(shift) // net_pooling
        if shift < 0:
            label_shift = -label_shift
        new_labels.append(torch.roll(labels[bindx], label_shift, dims=-1))
    return torch.stack(shifted), torch.stack(new_labels)


def mixup(data, target=None, alpha=0.2, beta=0.2, mixup_label_type="soft"):
    """Mixup data augmentation by permuting the batch and blending it with itself.

    A single mixing coefficient ``c`` is drawn from ``np.random.beta(alpha,
    beta)`` and one random permutation of the batch is drawn; every item is
    blended with its permuted partner as ``c * x + (1 - c) * x_perm``. The
    whole computation runs under ``torch.no_grad()``.

    Args:
        data: input tensor, must be a batch (first axis) so it can be
            permuted and mixed.
        target: tensor of targets to be mixed with the same coefficient and
            permutation. If None, targets are neither mixed nor returned.
        alpha: float, first parameter of the beta distribution.
        beta: float, second parameter of the beta distribution.
        mixup_label_type: str, how targets are combined, one of
            {'soft', 'hard'}. 'soft' blends the targets with ``c``; 'hard'
            sums them. Both results are clamped to [0, 1]. Only consulted
            when ``target`` is given.

    Returns:
        The mixed data tensor, or a tuple ``(mixed_data, mixed_target)`` when
        ``target`` is given.

    Raises:
        NotImplementedError: if ``target`` is given and ``mixup_label_type``
            is not 'soft' or 'hard'.
    """
    with torch.no_grad():
        batch_size = data.size(0)
        c = np.random.beta(alpha, beta)
        perm = torch.randperm(batch_size)

        mixed_data = c * data + (1 - c) * data[perm, :]
        if target is None:
            return mixed_data

        if mixup_label_type == "soft":
            mixed_target = torch.clamp(
                c * target + (1 - c) * target[perm, :], min=0, max=1
            )
        elif mixup_label_type == "hard":
            mixed_target = torch.clamp(target + target[perm, :], min=0, max=1)
        else:
            raise NotImplementedError(
                f"mixup_label_type: {mixup_label_type} not implemented. choice in "
                f"{'soft', 'hard'}"
            )
        return mixed_data, mixed_target


def add_noise(mels, snrs=(6, 30), dims=(1, 2)):
    """Add white Gaussian noise to mel spectrograms at a given SNR.

    The noise standard deviation is the standard deviation of ``mels`` over
    ``dims`` divided by the linear-domain SNR ``10 ** (snr / 20)``.

    Args:
        mels: torch.Tensor, mel spectrograms to apply the white noise to.
            When ``snrs`` is a list/tuple the per-item SNR is reshaped to
            ``(-1, 1, 1)``, so a 3-D batch is expected in that case.
        snrs: number, or list/tuple of two numbers. A single number is used
            as the SNR (in dB) for the whole input; a pair is the range from
            which one SNR per batch item is drawn uniformly.
        dims: tuple, the dimensions over which the standard deviation is
            computed (defaults to (1, 2), assuming a batch of mel
            spectrograms).

    Returns:
        torch.Tensor of mels with noise applied.
    """
    if isinstance(snrs, (list, tuple)):
        # One uniform draw per batch item, scaled to span snrs[1]..snrs[0].
        snr = (snrs[0] - snrs[1]) * torch.rand(
            (mels.shape[0],), device=mels.device
        ).reshape(-1, 1, 1) + snrs[1]
    else:
        snr = snrs

    snr = 10 ** (snr / 20)  # linear domain
    sigma = torch.std(mels, dim=dims, keepdim=True) / snr
    mels = mels + torch.randn(mels.shape, device=mels.device) * sigma
    return mels