FlexSED / src /utils /data_aug.py
OpenSound's picture
Upload 544 files
3b6a091 verified
# Some codes are adopted from https://github.com/DCASE-REPO/DESED_task
import torch
import numpy as np
import random
def frame_shift(features, label=None, net_pooling=None):
    """Circularly roll every sample of a batch by a random number of frames.

    For each sample an offset ``int(random.gauss(0, 90))`` is drawn and the
    sample is rolled along its last axis with ``torch.roll``.

    Args:
        features: tensor unpacked as ``(batch, _, _)``; rolled along ``dims=-1``.
        label: optional tensor indexed per sample; each sample's label is rolled
            along its last axis by the feature offset divided by ``net_pooling``.
        net_pooling: integer ratio between feature frames and label frames;
            required whenever ``label`` is given.

    Returns:
        The stacked rolled features, or a ``(features, labels)`` tuple of
        stacked tensors when ``label`` is provided.
    """
    batch_size, _, _ = features.shape
    rolled_features = []
    rolled_labels = []
    for idx in range(batch_size):
        offset = int(random.gauss(0, 90))
        rolled_features.append(torch.roll(features[idx], offset, dims=-1))
        if label is not None:
            # Floor division for both signs (``-abs(s) // p == s // p`` when s < 0,
            # since unary minus binds before ``//``).
            rolled_labels.append(torch.roll(label[idx], offset // net_pooling, dims=-1))
    if label is None:
        return torch.stack(rolled_features)
    return torch.stack(rolled_features), torch.stack(rolled_labels)
def mixup(features, label=None, permutation=None, c=None, alpha=0.2, beta=0.2, mixup_label_type="soft", returnc=False):
    """Mix every sample of a batch with another sample of the same batch.

    ``mixed = c * features + (1 - c) * features[permutation]`` is computed
    under ``torch.no_grad()``.

    Args:
        features: tensor whose first axis is the batch.
        label: optional label tensor with the same batch axis.
        permutation: index tensor choosing the partner of each sample; a random
            ``torch.randperm`` is drawn when omitted.
        c: mixing weight; drawn from ``Beta(alpha, beta)`` when omitted
            ("hard" rescales the draw into ``[0.3, 0.7]``).
        alpha, beta: Beta distribution parameters used when ``c`` is omitted.
        mixup_label_type: ``"soft"`` mixes labels with the same weight ``c``;
            ``"hard"`` takes the union of both labels (sum clamped to [0, 1]).
        returnc: when a label is given, also return ``c`` and ``permutation``.
            Ignored when ``label`` is None.

    Returns:
        ``mixed_features`` when ``label`` is None; otherwise
        ``(mixed_features, mixed_label)`` or, with ``returnc``,
        ``(mixed_features, mixed_label, c, permutation)``.

    Raises:
        NotImplementedError: if ``mixup_label_type`` is not "soft"/"hard" and
            it is needed (to sample ``c`` or to mix a label).
    """
    def _unsupported_type():
        # Plain literal keeps the message deterministic; interpolating the set
        # {'soft', 'hard'} would print in hash-randomized order.
        return NotImplementedError(
            f"mixup_label_type: {mixup_label_type} not implemented. "
            "choice in {'soft', 'hard'}"
        )

    with torch.no_grad():
        batch_size = features.size(0)
        if permutation is None:
            permutation = torch.randperm(batch_size)
        if c is None:
            if mixup_label_type == "soft":
                c = np.random.beta(alpha, beta)
            elif mixup_label_type == "hard":
                c = np.random.beta(alpha, beta) * 0.4 + 0.3  # c in [0.3, 0.7]
            else:
                # Without this, c would stay None and the mixing line would fail
                # with an unhelpful TypeError.
                raise _unsupported_type()
        mixed_features = c * features + (1 - c) * features[permutation, :]
        if label is None:
            return mixed_features
        if mixup_label_type == "soft":
            mixed_label = torch.clamp(c * label + (1 - c) * label[permutation, :], min=0, max=1)
        elif mixup_label_type == "hard":
            mixed_label = torch.clamp(label + label[permutation, :], min=0, max=1)
        else:
            raise _unsupported_type()
        if returnc:
            return mixed_features, mixed_label, c, permutation
        return mixed_features, mixed_label
def _sample_mask_span(n_frame, mask_ratios):
    """Draw a random ``(start, width)`` span on an axis of ``n_frame`` frames."""
    low = int(n_frame / mask_ratios[1])
    # Widen a degenerate [low, high) range (e.g. very short clips) instead of
    # letting torch.randint fail on low >= high.
    high = max(int(n_frame / mask_ratios[0]), low + 1)
    width = int(torch.randint(low=low, high=high, size=(1,)).item())  # [low, high)
    # start in [0, n_frame - width] inclusive so the final frames can be masked too
    start = int(torch.randint(low=0, high=n_frame - width + 1, size=(1,)).item())
    return start, width


def time_mask(features, labels=None, net_pooling=None, mask_ratios=(10, 20)):
    """Zero one random time span across the whole batch, in place.

    The span width is drawn from ``[n_frame / mask_ratios[1], n_frame / mask_ratios[0])``
    (truncated to int) and the same span is masked for every sample and channel.

    Args:
        features: tensor unpacked as ``(_, _, n_time)``; modified in place.
        labels: optional tensor unpacked as ``(_, _, n_frame)``; when given, the
            span is drawn in label frames and mapped onto ``features`` by
            multiplying with ``net_pooling``. Modified in place.
        net_pooling: feature frames per label frame; required with ``labels``.
        mask_ratios: ``(r0, r1)`` divisors bounding the span width.

    Returns:
        ``(features, labels)`` when ``labels`` is given, otherwise ``features``
        (the same tensor objects that were passed in).
    """
    if labels is not None:
        _, _, n_frame = labels.shape
        start, width = _sample_mask_span(n_frame, mask_ratios)
        features[:, :, start * net_pooling:(start + width) * net_pooling] = 0
        labels[:, :, start:start + width] = 0
        return features, labels
    _, _, n_frame = features.shape
    start, width = _sample_mask_span(n_frame, mask_ratios)
    features[:, :, start:start + width] = 0
    return features
def feature_transformation(features, n_transform, choice, filter_db_range, filter_bands,
                           filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Return two (possibly augmented) views of ``features``.

    The augmentation chain is, in order: ``filt_aug`` if ``choice[0]``,
    ``freq_mask`` if ``choice[1]``, ``add_noise`` if ``choice[2]``.

    Args:
        features: batch of features passed to the augmentations.
        n_transform: 2 -> two independently augmented views; 1 -> one augmented
            view returned twice; anything else -> ``features`` returned twice.
        choice: three truthy flags selecting the augmentations above.
        filter_db_range, filter_bands, filter_minimum_bandwidth, filter_type:
            forwarded to ``filt_aug``.
        freq_mask_ratio: forwarded to ``freq_mask``.
        noise_snrs: forwarded to ``add_noise``.

    Returns:
        A list of two tensors.
    """
    def _augment(x):
        # Apply the selected augmentations in a fixed order.
        if choice[0]:
            x = filt_aug(x, db_range=filter_db_range, n_band=filter_bands,
                         min_bw=filter_minimum_bandwidth, filter_type=filter_type)
        if choice[1]:
            x = freq_mask(x, mask_ratio=freq_mask_ratio)
        if choice[2]:
            x = add_noise(x, snrs=noise_snrs)
        return x

    if n_transform == 2:
        # Each view starts from its own copy: freq_mask zeroes bins in place and
        # filt_aug may hand back its input unchanged, so sharing ``features``
        # would leak one view's mask into the other view and the caller's tensor.
        return [_augment(features.clone()) for _ in range(n_transform)]
    if n_transform == 1:
        # NOTE: single-view path works on ``features`` directly (freq_mask may
        # modify it in place), as before.
        features = _augment(features)
    return [features, features]
def filt_aug(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6, filter_type="linear"):
    """Apply FilterAugment: a random per-sample gain curve over frequency.

    This is the updated FilterAugment algorithm used for ICASSP 2022. The
    frequency axis is split into ``n_freq_band`` random bands, each at least
    ``min_bw`` bins wide (``min_bw`` is reduced when the bands would not fit).
    Every sample gets gains drawn uniformly from ``db_range`` (in dB): one gain
    per band for ``"step"``, or gains at the band boundaries linearly
    interpolated across each band for ``"linear"``.

    Args:
        features: tensor unpacked as ``(batch, n_freq_bin, _)``.
        db_range: ``(min_db, max_db)`` range of the random gains.
        n_band: ``[low, high)`` range for the number of bands.
        min_bw: minimum band width in bins.
        filter_type: ``"step"``, ``"linear"``, or a number ``p``: use the step
            preset (n_band=(2, 5), min_bw=4) with probability ``p``, otherwise
            the linear preset (n_band=(3, 6), min_bw=6).

    Returns:
        ``features * filter`` (a new tensor), or ``features`` itself when fewer
        than two bands are drawn.

    Raises:
        ValueError: if ``filter_type`` is a string other than "step"/"linear".
    """
    if not isinstance(filter_type, str):
        if torch.rand(1).item() < filter_type:
            filter_type, n_band, min_bw = "step", (2, 5), 4
        else:
            filter_type, n_band, min_bw = "linear", (3, 6), 6
    if filter_type not in ("step", "linear"):
        raise ValueError(f"filter_type must be 'step' or 'linear', got {filter_type!r}")

    batch_size, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()  # [low, high)
    if n_freq_band <= 1:
        return features

    # Shrink min_bw until the bands fit; randint below needs an upper bound > 0.
    while n_freq_bin - n_freq_band * min_bw + 1 <= 0:
        min_bw -= 1
    # Sorted random offsets plus i * min_bw keep consecutive boundaries >= min_bw apart.
    band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
                                                (n_freq_band - 1,)))[0] + \
        torch.arange(1, n_freq_band) * min_bw
    band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin])))
    band_widths = band_bndry_freqs[1:] - band_bndry_freqs[:-1]
    # Index of the band each frequency bin belongs to, shape (n_freq_bin,).
    band_of_bin = torch.repeat_interleave(torch.arange(n_freq_band), band_widths).to(features.device)

    db_span = db_range[1] - db_range[0]
    if filter_type == "step":
        band_factors = torch.rand((batch_size, n_freq_band)).to(features) * db_span + db_range[0]
        band_factors = 10 ** (band_factors / 20)
        freq_filt = band_factors[:, band_of_bin].unsqueeze(-1)
    else:  # "linear"
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * db_span + db_range[0]
        # Relative position of each bin inside its band (0 at the lower edge,
        # 1 at the upper edge), i.e. torch.linspace(0, 1, width) per band.
        position = torch.cat([torch.linspace(0, 1, int(w)) for w in band_widths]).to(features)
        start_db = band_factors[:, band_of_bin]
        end_db = band_factors[:, band_of_bin + 1]
        freq_filt = (start_db + (end_db - start_db) * position).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)
    return features * freq_filt
def freq_mask(features, mask_ratio=16):
    """Zero one random band of frequency bins per sample, in place.

    Args:
        features: tensor unpacked as ``(batch, n_freq_bin, _)``; modified in place.
        mask_ratio: band widths are drawn from ``[1, n_freq_bin / mask_ratio)``
            (truncated to int); when that range holds at most one value, a
            width of 1 is used.

    Returns:
        ``features`` (the same tensor object).
    """
    batch_size, n_freq_bin, _ = features.shape
    max_mask = int(n_freq_bin / mask_ratio)
    if max_mask <= 1:
        # Too few bins for a random width: mask exactly one bin. Widths are
        # Python ints so they can serve as slice and randint bounds.
        f_widths = [1] * batch_size
    else:
        f_widths = torch.randint(low=1, high=max_mask, size=(batch_size,)).tolist()  # [low, high)
    for i, f_width in enumerate(f_widths):
        # start in [0, n_freq_bin - f_width] inclusive so the top bins can be masked too
        f_low = int(torch.randint(low=0, high=n_freq_bin - f_width + 1, size=(1,)).item())
        features[i, f_low:f_low + f_width, :] = 0
    return features
def add_noise(features, snrs=(15, 30), dims=(1, 2)):
    """Add Gaussian noise whose level is set by a signal-to-noise ratio in dB.

    Args:
        features: tensor whose first axis is the batch. With a list/tuple
            ``snrs`` the per-sample SNR is reshaped to ``(-1, 1, 1)``, which
            assumes a 3-D input.
        snrs: ``(a, b)`` -> one SNR per sample drawn as ``(a - b) * U + b`` with
            ``U ~ torch.rand``; any other value is used as a fixed SNR in dB.
        dims: axes over which the standard deviation of ``features`` is taken.

    Returns:
        A new tensor ``features + noise``, where the noise std is
        ``std(features, dims) / 10 ** (snr / 20)``.
    """
    if not isinstance(snrs, (list, tuple)):
        snr_db = snrs
    else:
        first, second = snrs[0], snrs[1]
        uniform = torch.rand((features.shape[0],), device=features.device).reshape(-1, 1, 1)
        snr_db = (first - second) * uniform + second
    amplitude_ratio = 10 ** (snr_db / 20)
    noise_std = torch.std(features, dim=dims, keepdim=True) / amplitude_ratio
    noise = torch.randn(features.shape, device=features.device)
    return features + noise * noise_std