File size: 7,533 Bytes
3b6a091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# Some code is adapted from https://github.com/DCASE-REPO/DESED_task
import torch
import numpy as np
import random
def frame_shift(features, label=None, net_pooling=None):
    """Circularly shift each batch item along its last axis by a random amount.

    For every item an integer shift is drawn as ``int(random.gauss(0, 90))``
    and applied with ``torch.roll(..., dims=-1)``.  When ``label`` is given,
    the matching label item is rolled by the same shift divided by
    ``net_pooling``.

    Args:
        features: 3-D tensor; the first axis is iterated as the batch.
        label: optional tensor indexed by the same batch position; rolled
            along its last axis.
        net_pooling: integer divisor converting a feature shift into a label
            shift.  Required whenever ``label`` is given.

    Returns:
        The stacked shifted features, or a ``(features, label)`` tuple of
        stacked tensors when ``label`` is provided.
    """
    batch_size, _, _ = features.shape
    shifted_features = []
    shifted_labels = []
    for idx in range(batch_size):
        shift = int(random.gauss(0, 90))
        shifted_features.append(torch.roll(features[idx], shift, dims=-1))
        if label is not None:
            # Truncate toward zero so that left and right shifts of equal
            # magnitude move the labels by the same number of frames.
            # ``-abs(shift) // net_pooling`` would floor instead, because
            # unary minus binds tighter than ``//``.
            label_shift = abs(shift) // net_pooling
            if shift < 0:
                label_shift = -label_shift
            shifted_labels.append(torch.roll(label[idx], label_shift, dims=-1))
    if label is None:
        return torch.stack(shifted_features)
    return torch.stack(shifted_features), torch.stack(shifted_labels)
def mixup(features, label=None, permutation=None, c=None, alpha=0.2, beta=0.2, mixup_label_type="soft", returnc=False):
    """Blend each batch item with another item of the same batch.

    Computes ``c * features + (1 - c) * features[permutation]`` under
    ``torch.no_grad()``.

    Args:
        features: tensor whose first axis is the batch.
        label: optional tensor indexed the same way as ``features``.
        permutation: batch index order of the mixing partners; drawn with
            ``torch.randperm`` when ``None``.
        c: mixing weight of the un-permuted item; when ``None`` it is sampled
            from ``Beta(alpha, beta)`` ("soft") or from that distribution
            rescaled into ``[0.3, 0.7]`` ("hard").
        alpha, beta: parameters of the Beta distribution.
        mixup_label_type: ``"soft"`` mixes labels with the same weights as the
            features; ``"hard"`` adds both labels.  Either result is clamped
            to ``[0, 1]``.
        returnc: when true (and ``label`` is given) also return ``c`` and
            ``permutation``.

    Returns:
        ``mixed_features`` when ``label`` is ``None``; otherwise
        ``(mixed_features, mixed_label)`` or, with ``returnc``,
        ``(mixed_features, mixed_label, c, permutation)``.

    Raises:
        NotImplementedError: if ``mixup_label_type`` is unknown and is
            actually needed (to sample ``c`` or to mix labels).
    """
    # Fail early with the intended error.  Without this check an unknown type
    # with ``c=None`` leaves ``c`` unset and dies on ``c * features`` with an
    # unrelated TypeError.
    if mixup_label_type not in ("soft", "hard") and (c is None or label is not None):
        raise NotImplementedError(f"mixup_label_type: {mixup_label_type} not implemented. choice in "
                                  f"{'soft', 'hard'}")

    with torch.no_grad():
        batch_size = features.size(0)
        if permutation is None:
            permutation = torch.randperm(batch_size)
        if c is None:
            c = np.random.beta(alpha, beta)
            if mixup_label_type == "hard":
                c = c * 0.4 + 0.3  # c in [0.3, 0.7]

        mixed_features = c * features + (1 - c) * features[permutation, :]
        if label is None:
            return mixed_features

        if mixup_label_type == "soft":
            mixed_label = torch.clamp(c * label + (1 - c) * label[permutation, :], min=0, max=1)
        else:  # "hard": an event active in either source stays active
            mixed_label = torch.clamp(label + label[permutation, :], min=0, max=1)

        if returnc:
            return mixed_features, mixed_label, c, permutation
        return mixed_features, mixed_label
def time_mask(features, labels=None, net_pooling=None, mask_ratios=(10, 20)):
    """Zero out one random contiguous span along the last axis, in place.

    The span width is drawn from
    ``[n_frame / mask_ratios[1], n_frame / mask_ratios[0])`` and its start from
    ``[0, n_frame - width)``, where ``n_frame`` is the last-axis length of
    ``labels`` when given, otherwise of ``features``.  The same span is used
    for every batch item.

    Args:
        features: 3-D tensor, modified in place.
        labels: optional 3-D tensor, modified in place; when present the span
            is chosen at label resolution and scaled by ``net_pooling`` for
            ``features``.
        net_pooling: factor between label indices and feature indices; only
            used when ``labels`` is given.
        mask_ratios: pair of divisors bounding the span width as shown above.

    Returns:
        ``features`` alone, or ``(features, labels)`` when labels were given.
    """
    has_labels = labels is not None
    reference = labels if has_labels else features
    _, _, n_frame = reference.shape

    narrowest = int(n_frame / mask_ratios[1])
    widest = int(n_frame / mask_ratios[0])
    width = int(torch.randint(low=narrowest, high=widest, size=(1,)))  # [low, high)
    start = int(torch.randint(low=0, high=n_frame - width, size=(1,)))
    stop = start + width

    if not has_labels:
        features[:, :, start:stop] = 0
        return features

    # Span was drawn in label frames; stretch it to feature resolution.
    features[:, :, start * net_pooling:stop * net_pooling] = 0
    labels[:, :, start:stop] = 0
    return features, labels
def feature_transformation(features, n_transform, choice, filter_db_range, filter_bands,
                           filter_minimum_bandwidth, filter_type, freq_mask_ratio, noise_snrs):
    """Build a pair of augmented views of ``features``.

    The augmentations are applied in a fixed order, each gated by one entry
    of ``choice``: ``choice[0]`` -> ``filt_aug``, ``choice[1]`` ->
    ``freq_mask``, ``choice[2]`` -> ``add_noise``.

    Args:
        features: input tensor handed to the augmentation functions.
        n_transform: ``2`` produces two independently augmented views, ``1``
            produces one augmented view returned twice, any other value
            returns the input twice unchanged.
        choice: three truthy/falsy flags selecting the augmentations.
        filter_db_range, filter_bands, filter_minimum_bandwidth, filter_type:
            forwarded to ``filt_aug`` as ``db_range``, ``n_band``, ``min_bw``
            and ``filter_type``.
        freq_mask_ratio: forwarded to ``freq_mask`` as ``mask_ratio``.
        noise_snrs: forwarded to ``add_noise`` as ``snrs``.

    Returns:
        A list of two tensors.
    """

    def _augment_once():
        view = features
        if choice[0]:
            view = filt_aug(view, db_range=filter_db_range, n_band=filter_bands,
                            min_bw=filter_minimum_bandwidth, filter_type=filter_type)
        if choice[1]:
            # freq_mask writes zeros into its argument.  If `view` is still
            # the caller's tensor (filt_aug skipped, or it returned its input
            # untouched), mask a copy instead.  Otherwise both views would be
            # the same object carrying both masks, and the caller's input
            # would be overwritten.
            if view is features:
                view = features.clone()
            view = freq_mask(view, mask_ratio=freq_mask_ratio)
        if choice[2]:
            view = add_noise(view, snrs=noise_snrs)
        return view

    if n_transform == 2:
        return [_augment_once() for _ in range(n_transform)]
    if n_transform == 1:
        view = _augment_once()
        return [view, view]
    return [features, features]
def filt_aug(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6, filter_type="linear"):
    """Scale random bands of the second axis by random gains (FilterAugment).

    This is the updated FilterAugment algorithm used for ICASSP 2022.

    The second axis (``n_freq_bin``) is split into a random number of bands at
    random boundaries.  A gain in dB is drawn uniformly from ``db_range`` per
    batch item and converted with ``10 ** (dB / 20)``.  With ``"step"`` each
    band gets one gain; with ``"linear"`` the dB value is interpolated linearly
    between gains drawn at the band edges.  The resulting filter has shape
    ``(batch, n_freq_bin, 1)`` and is multiplied onto ``features``.

    Args:
        features: 3-D tensor unpacked as ``(batch, n_freq_bin, _)``.
        db_range: ``(low, high)`` gain range in dB.
        n_band: ``(low, high)`` bounds for the number of bands, sampled from
            ``[low, high)``.
        min_bw: minimum band width in bins; reduced automatically when the
            requested bands do not fit into ``n_freq_bin``.
        filter_type: ``"step"``, ``"linear"``, or a number ``p``.  With a
            number, "step" is used with probability ``p`` (``n_band=(2, 5)``,
            ``min_bw=4``), otherwise "linear" (``n_band=(3, 6)``,
            ``min_bw=6``), overriding the values passed in.

    Returns:
        A new tensor ``features * filter``, or ``features`` itself when a
        single band (or fewer) was sampled.

    Raises:
        ValueError: if ``filter_type`` is a string other than "step"/"linear"
            and a filter has to be built.
    """
    if not isinstance(filter_type, str):
        if torch.rand(1).item() < filter_type:
            filter_type, n_band, min_bw = "step", (2, 5), 4
        else:
            filter_type, n_band, min_bw = "linear", (3, 6), 6

    batch_size, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()  # [low, high)
    if n_freq_band <= 1:
        return features

    # Shrink the minimum bandwidth until the boundary sampler below has a
    # non-empty range.  torch.randint needs high > low, so the exclusive upper
    # bound must be at least 1 (a bound of exactly 0 would raise).
    while n_freq_bin - n_freq_band * min_bw + 1 < 1:
        min_bw -= 1
    free_bins = n_freq_bin - n_freq_band * min_bw + 1
    offsets = torch.sort(torch.randint(0, free_bins, (n_freq_band - 1,)))[0]
    # Adding k * min_bw to the k-th sorted offset keeps every band >= min_bw wide.
    inner_edges = offsets + torch.arange(1, n_freq_band) * min_bw
    edges = torch.cat((torch.tensor([0]), inner_edges, torch.tensor([n_freq_bin]))).tolist()

    db_low = db_range[0]
    db_span = db_range[1] - db_range[0]
    freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)

    if filter_type == "step":
        band_factors = torch.rand((batch_size, n_freq_band)).to(features) * db_span + db_low
        band_factors = 10 ** (band_factors / 20)
        for i in range(n_freq_band):
            freq_filt[:, edges[i]:edges[i + 1], :] = band_factors[:, i].unsqueeze(-1).unsqueeze(-1)
    elif filter_type == "linear":
        # One dB value per band edge; interpolate inside each band.
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * db_span + db_low
        for i in range(n_freq_band):
            lo, hi = edges[i], edges[i + 1]
            for j in range(batch_size):
                freq_filt[j, lo:hi, :] = torch.linspace(band_factors[j, i], band_factors[j, i + 1],
                                                        hi - lo).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)
    else:
        raise ValueError(f"filter_type must be 'step', 'linear' or a probability, got {filter_type!r}")

    return features * freq_filt
def freq_mask(features, mask_ratio=16):
    """Zero out one random contiguous band of the second axis per item, in place.

    For every batch item a band width is drawn from
    ``[1, n_freq_bin / mask_ratio)`` and a start index from
    ``[0, n_freq_bin - width)``; the selected rows are set to zero.

    Args:
        features: 3-D tensor unpacked as ``(batch, n_freq_bin, _)``; modified
            in place.
        mask_ratio: divisor bounding the band width as shown above.  When
            ``n_freq_bin / mask_ratio`` is below 2 the width is fixed to 1.

    Returns:
        The same ``features`` tensor.
    """
    batch_size, n_freq_bin, _ = features.shape
    if n_freq_bin < 2:
        # No valid start index exists for even a one-bin mask.
        return features

    max_mask = int(n_freq_bin / mask_ratio)
    if max_mask <= 1:
        # Integer widths are required: they are used as a randint bound and as
        # a slice length below, and a float tensor works for neither.
        f_widths = [1] * batch_size
    else:
        f_widths = torch.randint(low=1, high=max_mask, size=(batch_size,)).tolist()  # [low, high)

    for i, f_width in enumerate(f_widths):
        f_low = torch.randint(low=0, high=n_freq_bin - f_width, size=(1,)).item()
        features[i, f_low:f_low + f_width, :] = 0
    return features
def add_noise(features, snrs=(15, 30), dims=(1, 2)):
    """Add Gaussian noise scaled relative to the standard deviation of the input.

    Args:
        features: input tensor; its first axis is treated as the batch when a
            per-item SNR is sampled.
        snrs: either a single SNR value in dB, or a list/tuple whose first two
            entries bound a uniform range from which one SNR per batch item is
            drawn.
        dims: axes over which the standard deviation of ``features`` is taken.

    Returns:
        A new tensor ``features + noise`` where the noise standard deviation
        is ``std(features) / 10 ** (snr / 20)``.
    """
    if not isinstance(snrs, (list, tuple)):
        snr_db = snrs
    else:
        # One uniform draw per batch item, shaped to broadcast over (B, ., .).
        uniform = torch.rand((features.shape[0],), device=features.device)
        snr_db = (snrs[0] - snrs[1]) * uniform.reshape(-1, 1, 1) + snrs[1]

    amplitude_ratio = 10 ** (snr_db / 20)
    noise_scale = torch.std(features, dim=dims, keepdim=True) / amplitude_ratio
    unit_noise = torch.randn(features.shape, device=features.device)
    return features + unit_noise * noise_scale
|