|
|
|
|
|
import math
import os
from pathlib import Path

import numpy as np
import pandas as pd
import scipy
import scipy.ndimage
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Sampler

from utils.evaluation_measures import compute_sed_eval_metrics
|
|
|
|
|
|
|
|
class Encoder:
    """Converts event annotations (label / onset / offset in seconds) to and from
    frame-level binary label matrices.

    Args:
        labels: list (or np.ndarray) of class names; their order defines the
            column order of every encoded matrix.
        audio_len: clip length in seconds.
        frame_len: analysis frame length in samples (stored for reference only;
            not used in the frame-count computation below).
        frame_hop: hop size in samples between frames.
        net_pooling: temporal pooling factor applied by the network.
        sr: audio sample rate in Hz.
    """

    def __init__(self, labels, audio_len, frame_len, frame_hop, net_pooling=1, sr=16000):
        # Bug fix: the original tested `type(labels) in [np.ndarray, np.array]`,
        # but np.array is a function, not a type, so that entry never matched.
        if isinstance(labels, np.ndarray):
            labels = labels.tolist()
        self.labels = labels
        self.audio_len = audio_len
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.sr = sr
        self.net_pooling = net_pooling
        n_samples = self.audio_len * self.sr
        # Frame count is rounded up to an even number of hops before dividing
        # by the network pooling factor (matches the network's output length).
        self.n_frames = int(math.ceil(n_samples / 2 / self.frame_hop) * 2 / self.net_pooling)

    def _time_to_frame(self, time):
        """Convert seconds to a fractional frame index, clipped to [0, n_frames]."""
        sample = time * self.sr
        frame = sample / self.frame_hop
        return np.clip(frame / self.net_pooling, a_min=0, a_max=self.n_frames)

    def _frame_to_time(self, frame):
        """Convert a frame index back to seconds, clipped to the clip length."""
        time = frame * self.net_pooling * self.frame_hop / self.sr
        return np.clip(time, a_min=0, a_max=self.audio_len)

    def encode_strong_df(self, events_df):
        """Encode a DataFrame with columns [event_label, onset, offset] into a
        (n_frames, n_classes) binary matrix.

        Rows whose event_label is NaN (clips without events) are skipped.
        """
        true_labels = np.zeros((self.n_frames, len(self.labels)))
        for _, row in events_df.iterrows():
            if not pd.isna(row['event_label']):
                label_idx = self.labels.index(row["event_label"])
                onset = int(self._time_to_frame(row["onset"]))
                # Round the offset up so partially covered frames stay active.
                offset = int(np.ceil(self._time_to_frame(row["offset"])))
                true_labels[onset:offset, label_idx] = 1
        return true_labels

    def encode_weak(self, events):
        """Encode an iterable of event labels into a (n_classes,) binary vector."""
        labels = np.zeros((len(self.labels)))
        if len(events) == 0:
            return labels
        for event in events:
            labels[self.labels.index(event)] = 1
        return labels

    def decode_strong(self, outputs):
        """Decode a (n_frames, n_classes) binary matrix into a list of
        [label, onset_seconds, offset_seconds] rows."""
        pred = []
        for i, label_column in enumerate(outputs.T):
            change_indices = self.find_contiguous_regions(label_column)
            for row in change_indices:
                onset = self._frame_to_time(row[0])
                offset = self._frame_to_time(row[1])
                # _frame_to_time already clips, but keep the explicit clamp
                # as a guard against out-of-range region indices.
                onset = np.clip(onset, a_min=0, a_max=self.audio_len)
                offset = np.clip(offset, a_min=0, a_max=self.audio_len)
                pred.append([self.labels[i], onset, offset])
        return pred

    def decode_weak(self, outputs):
        """Decode a (n_classes,) binary vector into the list of active labels."""
        result_labels = []
        for i, value in enumerate(outputs):
            if value == 1:
                result_labels.append(self.labels[i])
        return result_labels

    def find_contiguous_regions(self, array):
        """Return an (n_regions, 2) array of [start, end) indices of the
        truthy runs in a 1-D binary array."""
        # Indices where consecutive elements differ, shifted so each index
        # points at the first element AFTER the change.
        change_indices = np.logical_xor(array[1:], array[:-1]).nonzero()[0]
        change_indices += 1
        if array[0]:
            # A region is already active at the very start.
            change_indices = np.r_[0, change_indices]
        if array[-1]:
            # The last region runs to the very end.
            change_indices = np.r_[change_indices, array.size]
        # Pairs of (onset, offset) indices.
        return change_indices.reshape((-1, 2))
|
|
|
|
|
|
|
|
def decode_pred_batch(outputs, weak_preds, filenames, encoder, thresholds, median_filter, decode_weak, pad_idx=None):
    """Decode a batch of frame-level predictions into per-threshold event DataFrames.

    Args:
        outputs: torch.Tensor (batch, n_classes, n_frames) of strong scores.
        weak_preds: torch.Tensor (batch, n_classes) of clip-level scores.
        filenames: list of audio file paths, one per batch item.
        encoder: object whose ``decode_strong`` maps a (frames, classes) binary
            matrix to [label, onset, offset] rows (see Encoder above).
        thresholds: iterable of binarization thresholds.
        median_filter: per-class median-filter sizes (in frames).
        decode_weak: falsy -> threshold strong scores only; 1 -> additionally
            zero classes whose weak score is below the threshold; >1 -> decide
            from weak predictions alone (whole clip active when it passes).
        pad_idx: optional tensor of per-item valid-length ratios in (0, 1];
            padded frames beyond the valid length are dropped before decoding.

    Returns:
        dict mapping each threshold to a pd.DataFrame with columns
        ["event_label", "onset", "offset", "filename"].
    """
    pred_dfs = {threshold: pd.DataFrame() for threshold in thresholds}
    for batch_idx in range(outputs.shape[0]):
        for c_th in thresholds:
            output = outputs[batch_idx]
            if pad_idx is not None:
                # Bug fix: `.item` was missing its call parentheses (it was a
                # bound method, so the multiplication raised at runtime), and
                # the slice must apply to the frame axis (last dim), not dim 0.
                true_len = int(output.shape[-1] * pad_idx[batch_idx].item())
                output = output[..., :true_len]
            # (n_classes, n_frames) -> (n_frames, n_classes) numpy array.
            output = output.transpose(0, 1).detach().cpu().numpy()
            if decode_weak:
                for class_idx in range(weak_preds.size(1)):
                    if weak_preds[batch_idx, class_idx] < c_th:
                        output[:, class_idx] = 0
                    elif decode_weak > 1:
                        # Weak-only decoding: mark the whole clip active.
                        output[:, class_idx] = 1
            if decode_weak < 2:
                output = output > c_th
                for mf_idx in range(len(median_filter)):
                    # scipy.ndimage.filters is a removed-deprecated namespace;
                    # use scipy.ndimage.median_filter directly.
                    output[:, mf_idx] = scipy.ndimage.median_filter(output[:, mf_idx], (median_filter[mf_idx]))
            pred = encoder.decode_strong(output)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = Path(filenames[batch_idx]).stem + ".wav"
            # pd.concat replaces the private DataFrame._append API.
            pred_dfs[c_th] = pd.concat([pred_dfs[c_th], pred], ignore_index=True)
    return pred_dfs
|
|
|
|
|
|
|
|
class ConcatDatasetBatchSampler(Sampler):
    """Batch sampler that draws a fixed-size sub-batch from each of several
    samplers and concatenates them into one batch.

    Indices from sampler i are shifted by the total length of samplers 0..i-1,
    so the yielded indices address the corresponding ConcatDataset directly.
    """

    def __init__(self, samplers, batch_sizes, epoch=0):
        self.batch_sizes = batch_sizes
        self.samplers = samplers
        # Offset of each dataset inside the concatenated index space.
        self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1]

        self.epoch = epoch
        self.set_epoch(self.epoch)

    def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset):
        """Yield full batches of offset indices from a single sampler."""
        batch = []
        for idx in c_sampler:
            batch.append(c_offset + idx)
            if len(batch) == c_batch_size:
                yield batch
                # Bug fix: reset the accumulator after yielding; previously the
                # list kept growing and only the first batch was ever produced.
                batch = []

    def set_epoch(self, epoch):
        """Propagate the epoch to epoch-aware child samplers (e.g. DistributedSampler)."""
        if hasattr(self.samplers[0], "epoch"):
            for s in self.samplers:
                s.set_epoch(epoch)

    def __iter__(self):
        iterators = [iter(i) for i in self.samplers]
        tot_batch = []
        for b_num in range(len(self)):
            # One fixed-size sub-batch per sampler, concatenated in order.
            for samp_idx in range(len(self.samplers)):
                c_batch = []
                while len(c_batch) < self.batch_sizes[samp_idx]:
                    c_batch.append(self.offsets[samp_idx] + next(iterators[samp_idx]))
                tot_batch.extend(c_batch)
            yield tot_batch
            tot_batch = []

    def __len__(self):
        # The number of batches is limited by the sampler that runs out first.
        min_len = float("inf")
        for idx, sampler in enumerate(self.samplers):
            c_len = (len(sampler)) // self.batch_sizes[idx]
            min_len = min(c_len, min_len)
        return min_len
|
|
|
|
|
|
|
|
class ExponentialWarmup(object):
    """Exponential warm-up schedule wrapped around an optimizer.

    The learning rate ramps from ~max_lr * exp(exponent) up to max_lr over
    `rampup_length` steps, following exp(exponent * (1 - progress)^2).
    """

    def __init__(self, optimizer, max_lr, rampup_length, exponent=-5.0):
        self.optimizer = optimizer
        self.rampup_length = rampup_length
        self.max_lr = max_lr
        self.step_num = 1
        self.exponent = exponent

    def zero_grad(self):
        """Proxy to the wrapped optimizer."""
        self.optimizer.zero_grad()

    def _get_lr(self):
        """Current learning rate: max_lr scaled by the warm-up factor."""
        return self.max_lr * self._get_scaling_factor()

    def _set_lr(self, lr):
        """Push `lr` into every parameter group of the wrapped optimizer."""
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    def step(self):
        """Advance one step and update the optimizer's learning rate."""
        self.step_num += 1
        self._set_lr(self._get_lr())

    def _get_scaling_factor(self):
        """Warm-up factor in (0, 1]; reaches 1.0 once the ramp-up is complete."""
        if self.rampup_length == 0:
            return 1.0
        progress = np.clip(self.step_num, 0.0, self.rampup_length) / self.rampup_length
        remaining = 1.0 - progress
        return float(np.exp(self.exponent * remaining * remaining))
|
|
|
|
|
|
|
|
def update_ema(net, ema_net, step, ema_factor):
    """Update `ema_net` in place as an exponential moving average of `net`.

    The decay ramps from 0 (full copy at step 1) up toward `ema_factor`
    as `step` grows. Returns `ema_net` for convenience.
    """
    decay = min(1 - 1 / step, ema_factor)
    for shadow, current in zip(ema_net.parameters(), net.parameters()):
        shadow.data.copy_(decay * shadow.data + (1 - decay) * current.data)
    return ema_net
|
|
|
|
|
|
|
|
def log_sedeval_metrics(predictions, ground_truth, save_dir=None):
    """Compute sed_eval event-based and segment-based metrics.

    Args:
        predictions: pd.DataFrame of predicted events.
        ground_truth: str, path to a tab-separated ground-truth file.
        save_dir: str, optional folder where the event/segment reports are written.

    Returns:
        tuple, event-based macro-F1 and micro-F1, segment-based macro-F1 and
        micro-F1. All zeros when `predictions` is empty.
    """
    if predictions.empty:
        return 0.0, 0.0, 0.0, 0.0

    gt = pd.read_csv(ground_truth, sep="\t")
    event_res, segment_res = compute_sed_eval_metrics(predictions, gt)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        reports = {"event_f1.txt": event_res, "segment_f1.txt": segment_res}
        for report_name, result in reports.items():
            with open(os.path.join(save_dir, report_name), "w") as f:
                f.write(str(result))

    def _f1(res, average):
        # Extract an F1 score from the sed_eval results dict.
        return res.results()[average]["f_measure"]["f_measure"]

    return (
        _f1(event_res, "class_wise_average"),
        _f1(event_res, "overall"),
        _f1(segment_res, "class_wise_average"),
        _f1(segment_res, "overall"),
    )
|
|
|
|
|
|
|
|
class Scaler(nn.Module):
    """Feature normalization module.

    statistic="instance" normalizes each input on the fly over `dims`;
    statistic="dataset" uses precomputed statistics (self.mean /
    self.mean_squared, expected to arrive via a loaded state dict).
    normtype selects the normalization: "mean", "standard", or
    (instance-only) "minmax".
    """

    def __init__(self, statistic="instance", normtype="minmax", dims=(0, 2), eps=1e-8):
        super(Scaler, self).__init__()
        self.statistic = statistic
        self.normtype = normtype
        self.dims = dims
        self.eps = eps

    def load_state_dict(self, state_dict, strict=True):
        # Only dataset-level statistics are worth restoring; instance
        # normalization is stateless, so loading is skipped.
        if self.statistic == "dataset":
            super(Scaler, self).load_state_dict(state_dict, strict)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        if self.statistic == "dataset":
            super(Scaler, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                      missing_keys, unexpected_keys, error_msgs)

    def forward(self, input):
        if self.statistic == "dataset":
            # NOTE(review): relies on self.mean / self.mean_squared being set
            # from a loaded state dict — confirm against the training code.
            if self.normtype == "mean":
                return input - self.mean
            if self.normtype == "standard":
                std = torch.sqrt(self.mean_squared - self.mean ** 2)
                return (input - self.mean) / (std + self.eps)
            raise NotImplementedError

        if self.statistic == "instance":
            if self.normtype == "mean":
                return input - torch.mean(input, self.dims, keepdim=True)
            if self.normtype == "standard":
                centered = input - torch.mean(input, self.dims, keepdim=True)
                return centered / (torch.std(input, self.dims, keepdim=True) + self.eps)
            if self.normtype == "minmax":
                lo = torch.amin(input, dim=self.dims, keepdim=True)
                hi = torch.amax(input, dim=self.dims, keepdim=True)
                return (input - lo) / (hi - lo + self.eps)
            raise NotImplementedError

        raise NotImplementedError
|
|
|
|
|
|
|
|
class AsymmetricalFocalLoss(nn.Module):
    """Binary focal-style loss with separate focusing exponents for the
    positive (gamma) and negative (zeta) terms. gamma=zeta=0 reduces to
    plain binary cross-entropy on probabilities."""

    def __init__(self, gamma=0, zeta=0):
        super(AsymmetricalFocalLoss, self).__init__()
        self.gamma = gamma  # focusing exponent for target == 1
        self.zeta = zeta    # focusing exponent for target == 0

    def forward(self, pred, target):
        # Clamp logs at -100 so pred values of exactly 0 or 1 don't yield -inf.
        log_p = torch.clamp_min(torch.log(pred), -100)
        log_not_p = torch.clamp_min(torch.log(1 - pred), -100)
        positive = target * ((1 - pred) ** self.gamma) * log_p
        negative = (1 - target) * (pred ** self.zeta) * log_not_p
        return torch.mean(-(positive + negative))
|
|
|
|
|
|
|
|
def take_log(feature):
    """Convert an amplitude feature to decibels, clamped to [-50, 80] dB.

    NOTE(review): the AmplitudeToDB transform is rebuilt on every call and its
    `amin` floor is lowered to 1e-5 before use to avoid log(0).
    """
    transform = torchaudio.transforms.AmplitudeToDB(stype="amplitude")
    transform.amin = 1e-5
    decibels = transform(feature)
    return decibels.clamp(min=-50, max=80)
|
|
|
|
|
|
|
|
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|