|
|
import torch |
|
|
import torchaudio |
|
|
import torch.nn as nn |
|
|
import pandas as pd |
|
|
import random |
|
|
from torch.utils.data import Dataset |
|
|
import ast |
|
|
|
|
|
|
|
|
class TSED_AS(Dataset):
    """Target-sound event detection dataset (AudioSet-style).

    Each item yields a fixed-length audio clip, the CLAP text embeddings of a
    sampled set of sound classes, and per-class frame-level activity targets.

    Parameters
    ----------
    data_dir : str
        Directory prefix for the audio files named in the metadata CSV.
    clap_dir : str
        Directory prefix for per-class CLAP embedding ``<label>.pt`` files.
    meta_dir : str
        Path to the metadata CSV (must contain ``file_name`` and ``duration``).
    label_dir : str
        Path to the strong-label CSV (``filename``/``label``/``onset``/``offset``);
        only read when ``label_type == 'strong'``.
    class_list : str
        Path to a CSV with ``id`` and ``label`` columns defining the class set.
    seg_length : int
        Clip length in seconds.
    sr : int
        Expected audio sample rate; files with another rate are rejected.
    label_sr : int
        Frame rate (frames per second) of the target label sequence.
    label_per_audio : list[int] | int
        For ``sample_method='balance'``, a ``[n_pos, n_neg]`` pair; for
        ``'random'``, the total class count (an int, or a pair whose sum is used).
    norm : bool
        Peak-normalize each clip.
    mono : bool
        Always downmix to mono. When False, stereo files are randomly either
        downmixed or reduced to one random channel (augmentation).
    label_type : str
        ``'strong'`` loads frame-level labels; anything else keeps no label
        table and produces all-zero targets.
    debug : bool
        Unused; kept for interface compatibility.
    sample_method : str
        ``'fix'``, ``'random'`` or ``'balance'`` — how classes are chosen per item.
    neg_removed_weight : float
        Sampling weight of "removed" negatives relative to ordinary negatives.
    """

    def __init__(self, data_dir, clap_dir, meta_dir, label_dir, class_list,
                 seg_length=10, sr=16000, label_sr=25, label_per_audio=[10, 10],
                 norm=True, mono=True, label_type='strong', debug=False, sample_method='random',
                 neg_removed_weight=0.25,
                 **kwargs):
        self.data_dir = data_dir
        self.clap_dir = clap_dir

        meta = pd.read_csv(meta_dir)
        # Drop rows whose audio has zero duration (empty/unreadable files).
        self.meta = meta[meta['duration'] != 0]

        # Strong (frame-level) labels come from a separate CSV; otherwise no
        # label table is kept and load_label() returns all-zero targets.
        self.label = pd.read_csv(label_dir) if label_type == 'strong' else None

        self.label_per_audio = label_per_audio

        self.class_list = pd.read_csv(class_list)
        # id -> human-readable label (also the stem of the CLAP .pt filename).
        self.class_dict = dict(self.class_list.set_index('id')['label'])
        self.cls_ids = sorted(self.class_list['id'].unique().tolist())
        self.sample_method = sample_method

        self.seg_len = seg_length
        self.sr = sr
        self.label_sr = label_sr
        self.label_type = label_type

        self.norm = norm
        self.mono = mono

        self.neg_removed_weight = neg_removed_weight

    def load_audio(self, audio_path):
        """Load ``audio_path``, reduce it to one channel, and return a clip of
        exactly ``seg_len * sr`` samples (cropped from the start, zero-padded
        if the file is shorter), optionally peak-normalized.

        Raises
        ------
        ValueError
            If the file's sample rate differs from ``self.sr`` or it has more
            than two channels.
        """
        y, sr = torchaudio.load(audio_path)
        # Explicit raise instead of a bare assert (asserts vanish under -O).
        if sr != self.sr:
            raise ValueError("Expected sample rate {}, got {}".format(self.sr, sr))

        if self.mono:
            y = torch.mean(y, dim=0, keepdim=True)
        else:
            n_channels = y.shape[0]
            if n_channels == 2:
                # Stereo augmentation: randomly downmix, or keep one channel.
                if random.choice([True, False]):
                    y = torch.mean(y, dim=0, keepdim=True)
                else:
                    channel = random.choice([0, 1])
                    y = y[channel, :].unsqueeze(0)
            elif n_channels != 1:
                raise ValueError("Unsupported number of channels: {}".format(y.shape[0]))

        # Take the leading seg_len seconds; shorter files stay zero-padded.
        n_samples = self.seg_len * self.sr
        end = min(n_samples, y.shape[-1])
        audio_clip = torch.zeros(n_samples)
        audio_clip[:end] = y[0, :end]

        if self.norm:
            # Peak-normalize; eps guards against division by zero on silence.
            eps = 1e-9
            audio_clip = audio_clip / (torch.max(torch.abs(audio_clip)) + eps)

        return audio_clip

    def load_label(self, filelabel, event_label):
        """Rasterize the (onset, offset) rows of ``event_label`` into a binary
        frame-activity vector of length ``seg_len * label_sr``.

        ``filelabel`` is the slice of the strong-label table for one file; it
        is ignored (all-zero target) when ``label_type != 'strong'``.
        Returns a (1, seg_len * label_sr) float tensor.
        """
        target = torch.zeros(self.seg_len * self.label_sr)
        if self.label_type == 'strong':
            events = filelabel[filelabel['label'] == event_label]
            for i in range(len(events)):
                row = events.iloc[i]
                onset_f = round(row['onset'] * self.label_sr)
                offset_f = round(row['offset'] * self.label_sr)
                target[onset_f:offset_f] = 1
        return target.unsqueeze(0)

    def __getitem__(self, index):
        """Return ``(audio, cls_tokens, labels, file_name)`` for one item."""
        row = self.meta.iloc[index]
        audio = self.load_audio(self.data_dir + row['file_name'])

        if self.sample_method == 'fix':
            cls_list = row['ids']
            # CSV cells hold the python repr of the id list; parse if needed
            # (iterating the raw string would yield characters, not ids).
            if isinstance(cls_list, str):
                cls_list = ast.literal_eval(cls_list)
        elif self.sample_method == 'random':
            # label_per_audio may be an int or a [n_pos, n_neg] pair; the
            # original passed the raw list as k, which raises TypeError.
            if isinstance(self.label_per_audio, int):
                n_cls = self.label_per_audio
            else:
                n_cls = sum(self.label_per_audio)
            cls_list = random.sample(self.cls_ids, n_cls)
        elif self.sample_method == 'balance':
            pos_ids = ast.literal_eval(row['pos_ids'])
            neg_ids = ast.literal_eval(row['neg_ids'])
            removed_ids = ast.literal_eval(row['removed_ids'])
            N_p, N_n = self.label_per_audio
            # Too few positives: backfill with extra negatives so each item
            # still carries N_p + N_n classes in total.
            if len(pos_ids) < N_p:
                N_n += N_p - len(pos_ids)
            assert len(neg_ids) + len(removed_ids) >= N_n

            sampled_pos = random.sample(pos_ids, min(N_p, len(pos_ids)))

            # "Removed" negatives are down-weighted. NOTE(review):
            # random.choices samples WITH replacement, so duplicate negatives
            # are possible — confirm that is intended.
            candidates = neg_ids + removed_ids
            weights = [1.0] * len(neg_ids) + [self.neg_removed_weight] * len(removed_ids)
            sampled_neg = random.choices(candidates, weights=weights, k=min(N_n, len(candidates)))

            cls_list = sampled_pos + sampled_neg
        else:
            # Original code fell through to a NameError here; fail loudly.
            raise ValueError("Unknown sample_method: {}".format(self.sample_method))

        # Weak-label mode keeps no label table (self.label is None); pass None
        # through — load_label() ignores it and returns zeros in that mode.
        if self.label is None:
            filelabel = None
        else:
            filelabel = self.label[self.label['filename'] == row['file_name']]

        cls_tokens = []
        labels = []
        for cls_id in cls_list:
            event_label = self.class_dict[cls_id]
            cls_tokens.append(torch.load(self.clap_dir + event_label + '.pt'))
            labels.append(self.load_label(filelabel, event_label))

        cls_tokens = torch.cat(cls_tokens, dim=0)
        labels = torch.cat(labels, dim=0)

        return audio, cls_tokens, labels, row['file_name']

    def __len__(self):
        """Number of usable (non-zero-duration) audio files."""
        return len(self.meta)