|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import random |
|
|
|
|
|
|
|
|
def frame_shift(features, label=None, net_pooling=None):
    """Circularly shift each sample along its last axis by a random offset.

    One offset per sample is drawn as ``int(random.gauss(0, 90))`` and applied
    with ``torch.roll`` on the last dimension. When ``label`` is supplied, the
    matching label is rolled by that offset floor-divided by ``net_pooling``.

    Args:
        features: 3-D tensor; its first axis is iterated sample by sample.
        label: Optional tensor indexed along the same first axis.
        net_pooling: Divisor that converts a feature offset into a label
            offset. Must be provided whenever ``label`` is.

    Returns:
        The stacked shifted features, or a ``(features, labels)`` tuple when
        ``label`` is given.
    """
    batch_size, _, _ = features.shape
    with_label = label is not None

    rolled_features = []
    rolled_labels = []
    for sample in range(batch_size):
        offset = int(random.gauss(0, 90))
        rolled_features.append(torch.roll(features[sample], offset, dims=-1))
        if with_label:
            # Floor division: a negative offset rounds toward negative
            # infinity (e.g. -5 // 4 == -2), for either sign of the offset.
            label_offset = offset // net_pooling
            rolled_labels.append(torch.roll(label[sample], label_offset, dims=-1))

    if with_label:
        return torch.stack(rolled_features), torch.stack(rolled_labels)
    return torch.stack(rolled_features)
|
|
|
|
|
|
|
|
def mixup(features, label=None, permutation=None, c=None, alpha=0.2, beta=0.2, mixup_label_type="soft", returnc=False):
    """Blend every sample with another sample of the same batch.

    Args:
        features: Tensor whose first axis is the batch axis.
        label: Optional tensor indexed along the same batch axis.
        permutation: Index tensor choosing the partner of each sample; drawn
            with ``torch.randperm`` when omitted.
        c: Mixing weight of the original sample. When omitted it is sampled
            from ``Beta(alpha, beta)``; for ``"hard"`` labels the draw is
            rescaled as ``draw * 0.4 + 0.3``.
        alpha: First shape parameter of the Beta distribution.
        beta: Second shape parameter of the Beta distribution.
        mixup_label_type: ``"soft"`` mixes labels with the same weights as the
            features; ``"hard"`` adds both labels. Both results are clamped
            to ``[0, 1]``.
        returnc: When true and ``label`` is given, also return ``c`` and the
            permutation. It has no effect when ``label`` is ``None``.

    Returns:
        ``mixed_features`` when ``label`` is ``None``; otherwise
        ``(mixed_features, mixed_label)`` or, with ``returnc``,
        ``(mixed_features, mixed_label, c, permutation)``.

    Raises:
        NotImplementedError: If ``mixup_label_type`` is neither ``"soft"``
            nor ``"hard"`` and the type is actually needed (to sample ``c``
            or to mix labels).
    """
    # Validate before the type is used. Previously an unknown type with
    # c=None left c unset and failed with a TypeError in the arithmetic below.
    needs_label_type = c is None or label is not None
    if needs_label_type and mixup_label_type not in ("soft", "hard"):
        raise NotImplementedError(f"mixup_label_type: {mixup_label_type} not implemented. choice in "
                                  f"{'soft', 'hard'}")

    with torch.no_grad():
        batch_size = features.size(0)

        if permutation is None:
            permutation = torch.randperm(batch_size)

        if c is None:
            c = np.random.beta(alpha, beta)
            if mixup_label_type == "hard":
                c = c * 0.4 + 0.3

        mixed_features = c * features + (1 - c) * features[permutation, :]
        if label is None:
            return mixed_features

        partner_label = label[permutation, :]
        if mixup_label_type == "soft":
            mixed_label = torch.clamp(c * label + (1 - c) * partner_label, min=0, max=1)
        else:
            mixed_label = torch.clamp(label + partner_label, min=0, max=1)

        if returnc:
            return mixed_features, mixed_label, c, permutation
        return mixed_features, mixed_label
|
|
|
|
|
|
|
|
def _sample_time_span(n_frame, mask_ratios):
    """Draw the ``(start, width)`` of one mask over ``n_frame`` positions."""
    min_width = int(n_frame / mask_ratios[1])
    max_width = int(n_frame / mask_ratios[0])
    if min_width == max_width:
        # torch.randint rejects an empty [low, high) range; with coinciding
        # bounds there is only one possible width anyway.
        width = min_width
    else:
        width = int(torch.randint(low=min_width, high=max_width, size=(1,)).item())

    if n_frame == width:
        # No room left to move the mask (this also covers n_frame == 0).
        start = 0
    else:
        start = int(torch.randint(low=0, high=n_frame - width, size=(1,)).item())
    return start, width


def time_mask(features, labels=None, net_pooling=None, mask_ratios=(10, 20)):
    """Zero one random span of the last axis, shared by the whole batch.

    The span width is drawn from
    ``[int(n_frame / mask_ratios[1]), int(n_frame / mask_ratios[0]))`` and its
    start from ``[0, n_frame - width)``. When both width bounds coincide (for
    example on short inputs, where both are 0) that single width is used
    instead of raising. Inputs are modified in place and also returned.

    Args:
        features: 3-D tensor masked along its last axis.
        labels: Optional 3-D tensor. When given, the span is sampled on the
            label axis and scaled by ``net_pooling`` for ``features``.
        net_pooling: Factor between label positions and feature positions.
            Must be provided whenever ``labels`` is.
        mask_ratios: Pair of divisors ``(for_max_width, for_min_width)``.

    Returns:
        ``features``, or ``(features, labels)`` when ``labels`` is given.
    """
    if labels is not None:
        _, _, n_frame = labels.shape
        start, width = _sample_time_span(n_frame, mask_ratios)
        features[:, :, start * net_pooling:(start + width) * net_pooling] = 0
        labels[:, :, start:start + width] = 0
        return features, labels

    _, _, n_frame = features.shape
    start, width = _sample_time_span(n_frame, mask_ratios)
    features[:, :, start:start + width] = 0
    return features
|
|
|
|
|
|
|
|
def _apply_feature_transforms(features, choice, filter_db_range, filter_bands,
                              filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Run the enabled augmentations once, leaving the input tensor untouched."""
    transformed = features
    if choice[0]:
        transformed = filt_aug(transformed, db_range=filter_db_range, n_band=filter_bands,
                               min_bw=filter_minimum_bandwidth, filter_type=filter_type)
    if choice[1]:
        # freq_mask zeroes bins of its argument in place. Masking a copy keeps
        # the caller's tensor intact and keeps repeated passes independent.
        transformed = freq_mask(transformed.clone(), mask_ratio=freq_mask_ratio)
    if choice[2]:
        transformed = add_noise(transformed, snrs=noise_snrs)
    return transformed


def feature_transformation(features, n_transform, choice, filter_db_range, filter_bands,
                           filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Produce two views of ``features`` using the enabled augmentations.

    Args:
        features: Input tensor handed to the augmentations.
        n_transform: ``2`` augments twice, giving two separately augmented
            views; ``1`` augments once and returns that result twice; any
            other value returns the input twice, unchanged.
        choice: Three flags enabling, in order, ``filt_aug``, ``freq_mask``
            and ``add_noise``.
        filter_db_range: Passed to ``filt_aug`` as ``db_range``.
        filter_bands: Passed to ``filt_aug`` as ``n_band``.
        filter_minimum_bandwidth: Passed to ``filt_aug`` as ``min_bw``.
        filter_type: Passed to ``filt_aug`` as ``filter_type``.
        freq_mask_ratio: Passed to ``freq_mask`` as ``mask_ratio``.
        noise_snrs: Passed to ``add_noise`` as ``snrs``.

    Returns:
        A list with exactly two tensors.
    """
    transform_args = (choice, filter_db_range, filter_bands, filter_minimum_bandwidth,
                      filter_type, freq_mask_ratio, noise_snrs)
    if n_transform == 2:
        return [_apply_feature_transforms(features, *transform_args) for _ in range(n_transform)]
    if n_transform == 1:
        transformed = _apply_feature_transforms(features, *transform_args)
        return [transformed, transformed]
    return [features, features]
|
|
|
|
|
|
|
|
def filt_aug(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6, filter_type="linear"):
    """Scale frequency bins by a random per-sample filter.

    The second axis of ``features`` is split into a random number of bands
    whose boundaries are shared by the whole batch. Each sample gets its own
    random gains, and the resulting filter is broadcast over the last axis.

    Args:
        features: 3-D tensor; the second axis is treated as frequency bins.
        db_range: ``(low, high)`` range, in dB, the random gains are drawn from.
        n_band: ``(low, high)`` bounds for the number of bands, drawn with
            ``torch.randint`` (``high`` is exclusive).
        min_bw: Minimum number of bins per band. It is reduced automatically
            when the bins cannot hold that many bands of this width.
        filter_type: ``"step"`` for a constant gain per band, ``"linear"`` for
            gains interpolated in dB across each band. A non-string value is
            compared against a uniform random draw: below it selects
            ``"step"`` (``n_band=(2, 5)``, ``min_bw=4``), otherwise
            ``"linear"`` (``n_band=(3, 6)``, ``min_bw=6``).

    Returns:
        ``features * filter`` as a new tensor, or ``features`` itself when
        the drawn band count is 1 or less.

    Raises:
        NotImplementedError: If a filter has to be built and ``filter_type``
            is neither ``"step"`` nor ``"linear"``.
    """
    if not isinstance(filter_type, str):
        if torch.rand(1).item() < filter_type:
            filter_type, n_band, min_bw = "step", (2, 5), 4
        else:
            filter_type, n_band, min_bw = "linear", (3, 6), 6

    batch_size, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()
    if n_freq_band <= 1:
        return features

    if filter_type not in ("step", "linear"):
        # Previously this fell through to the return with the filter unbound.
        raise NotImplementedError(f"filter_type: {filter_type} not implemented. choice in ('step', 'linear')")

    # Shrink the minimum bandwidth until at least one boundary placement
    # exists. The bound must be < 1, not < 0: torch.randint(0, 0, ...) raises.
    while n_freq_bin - n_freq_band * min_bw + 1 < 1:
        min_bw -= 1
    n_placements = n_freq_bin - n_freq_band * min_bw + 1

    # Sorted random offsets plus multiples of min_bw give increasing inner
    # boundaries with every band at least min_bw bins wide.
    inner_edges = torch.sort(torch.randint(0, n_placements, (n_freq_band - 1,)))[0] + \
        torch.arange(1, n_freq_band) * min_bw
    edges = [0] + inner_edges.tolist() + [n_freq_bin]
    bands = list(zip(edges[:-1], edges[1:]))

    db_span = db_range[1] - db_range[0]
    freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)

    if filter_type == "step":
        band_factors = torch.rand((batch_size, n_freq_band)).to(features) * db_span + db_range[0]
        band_factors = 10 ** (band_factors / 20)
        for i, (lo, hi) in enumerate(bands):
            freq_filt[:, lo:hi, :] = band_factors[:, i].unsqueeze(-1).unsqueeze(-1)
    else:
        # One gain per band edge; interpolate in dB, then convert to amplitude.
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * db_span + db_range[0]
        gains_db = band_factors.tolist()
        for i, (lo, hi) in enumerate(bands):
            for j in range(batch_size):
                freq_filt[j, lo:hi, :] = torch.linspace(
                    gains_db[j][i], gains_db[j][i + 1], hi - lo, device=freq_filt.device
                ).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)

    return features * freq_filt
|
|
|
|
|
|
|
|
def freq_mask(features, mask_ratio=16):
    """Zero one random band of frequency bins per sample, in place.

    With ``max_mask = int(n_freq_bin / mask_ratio)``, each sample gets a band
    of width drawn from ``[1, max_mask)`` (width 1 when ``max_mask == 1``)
    starting at a position drawn from ``[0, n_freq_bin - width)``.

    Args:
        features: 3-D tensor; the second axis is treated as frequency bins.
            It is modified in place.
        mask_ratio: Divisor of the bin count that bounds the mask width.

    Returns:
        The same ``features`` tensor. It is returned unchanged when
        ``max_mask`` is below 1, i.e. there are fewer bins than
        ``mask_ratio``.
    """
    batch_size, n_freq_bin, _ = features.shape
    max_mask = int(n_freq_bin / mask_ratio)
    if max_mask < 1:
        # Too few bins for this ratio; torch.randint(1, 0) would raise.
        return features

    if max_mask == 1:
        # Integer widths: the float tensor from torch.ones() cannot be used
        # as a randint bound or as a slice index.
        widths = [1] * batch_size
    else:
        widths = torch.randint(low=1, high=max_mask, size=(batch_size,)).tolist()

    for i, width in enumerate(widths):
        if n_freq_bin == width:
            start = 0
        else:
            start = int(torch.randint(low=0, high=n_freq_bin - width, size=(1,)).item())
        features[i, start:start + width, :] = 0
    return features
|
|
|
|
|
|
|
|
def add_noise(features, snrs=(15, 30), dims=(1, 2)):
    """Add Gaussian noise scaled to a signal-to-noise ratio given in dB.

    Args:
        features: Input tensor. When ``snrs`` is a list or tuple, one SNR is
            drawn per entry of the first axis and reshaped to ``(-1, 1, 1)``.
        snrs: Either a single SNR value in dB, or a two-element list/tuple
            between whose values the per-sample SNR is drawn uniformly.
        dims: Axes over which the standard deviation of ``features`` is taken.

    Returns:
        ``features`` plus standard-normal noise multiplied by
        ``std(features) / 10 ** (snr / 20)``.
    """
    if not isinstance(snrs, (list, tuple)):
        snr_db = snrs
    else:
        snr_span = snrs[0] - snrs[1]
        uniform = torch.rand((features.shape[0],), device=features.device).reshape(-1, 1, 1)
        snr_db = snr_span * uniform + snrs[1]

    # dB -> amplitude ratio between the signal and the noise.
    amplitude_ratio = 10 ** (snr_db / 20)
    noise_std = torch.std(features, dim=dims, keepdim=True) / amplitude_ratio
    white_noise = torch.randn(features.shape, device=features.device)
    return features + white_noise * noise_std
|
|
|
|
|
|
|
|
|