import ast
import random

import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset
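
# Assumed on-disk layout, inferred from how the constructor arguments are used
# below (not confirmed by external documentation):
#   meta_dir:   CSV with 'file_name', 'duration', and, depending on sample_method,
#               'ids' / 'pos_ids' / 'neg_ids' / 'removed_ids' columns holding
#               Python-literal lists of class ids.
#   label_dir:  strong-label CSV with 'filename', 'label', 'onset', 'offset'.
#   class_list: CSV mapping 'id' to 'label'.
#   clap_dir:   directory with one '<label>.pt' CLAP text embedding per class.
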
class TSED_AS(Dataset):
    def __init__(self, data_dir, clap_dir, meta_dir, label_dir, class_list,
                 seg_length=10, sr=16000, label_sr=25, label_per_audio=[10, 10],
                 norm=True, mono=True, label_type='strong', debug=False,
                 sample_method='random', neg_removed_weight=0.25,
                 **kwargs):
        self.data_dir = data_dir
        self.clap_dir = clap_dir
        # Drop zero-length recordings from the metadata.
        meta = pd.read_csv(meta_dir)
        meta = meta[meta['duration'] != 0]
        self.meta = meta
        # Frame-level (strong) labels come from a separate CSV; otherwise none.
        if label_type == 'strong':
            self.label = pd.read_csv(label_dir)
        else:
            self.label = None
        self.label_per_audio = label_per_audio
        self.class_list = pd.read_csv(class_list)
        self.class_dict = dict(self.class_list.set_index('id')['label'])  # class id -> label name
        # self.event_id = dict(self.class_list.set_index('label')['id'])
        self.cls_ids = sorted(self.class_list['id'].unique().tolist())
        self.sample_method = sample_method
        self.seg_len = seg_length
        self.sr = sr
        self.label_sr = label_sr
        self.label_type = label_type
        self.norm = norm
        self.mono = mono
        self.neg_removed_weight = neg_removed_weight
    def load_audio(self, audio_path):
        y, sr = torchaudio.load(audio_path)
        assert sr == self.sr, f"expected {self.sr} Hz audio, got {sr} Hz"
        # Handle stereo or mono based on self.mono.
        if self.mono:
            # Convert to mono by averaging all channels.
            y = torch.mean(y, dim=0, keepdim=True)
        else:
            if y.shape[0] == 1:
                pass
            elif y.shape[0] == 2:
                # Randomly either average the two stereo channels or pick one of them.
                if random.choice([True, False]):
                    y = torch.mean(y, dim=0, keepdim=True)
                else:
                    channel = random.choice([0, 1])
                    y = y[channel, :].unsqueeze(0)
            else:
                raise ValueError("Unsupported number of channels: {}".format(y.shape[0]))
        # Take the first seg_len seconds; shorter clips are zero-padded to fixed length.
        total_length = y.shape[-1]
        start = 0
        end = min(start + self.seg_len * self.sr, total_length)
        audio_clip = torch.zeros(self.seg_len * self.sr)
        audio_clip[:end - start] = y[0, start:end]
        if self.norm:
            # Peak-normalize; eps guards against division by zero on silent clips.
            eps = 1e-9
            max_val = torch.max(torch.abs(audio_clip))
            audio_clip = audio_clip / (max_val + eps)
        # audio_clip = self.augmenter(audio_clip)
        return audio_clip
    def load_label(self, filelabel, event_label):
        # Frame-level binary target at label_sr frames per second.
        target = torch.zeros(self.seg_len * self.label_sr)
        if self.label_type == 'strong':
            label = filelabel[filelabel['label'] == event_label]
            for i in range(len(label)):
                row = label.iloc[i]
                onset = row['onset']
                offset = row['offset']
                # Mark the frames covered by this event as active.
                target[round(onset * self.label_sr):round(offset * self.label_sr)] = 1
        else:
            pass
        return target.unsqueeze(0)
    def __getitem__(self, index):
        row = self.meta.iloc[index]
        audio = self.load_audio(self.data_dir + row['file_name'])
        # TBD balance positive and negative
        if self.sample_method == 'fix':
            # Use the class ids stored with the file, parsed like the 'balance'
            # columns (assumes the CSV stores a Python-literal list).
            cls_list = ast.literal_eval(row['ids'])
        elif self.sample_method == 'random':
            # Sample class ids uniformly; label_per_audio is (N_pos, N_neg),
            # so the total number of queried classes is assumed to be its sum.
            cls_list = random.sample(self.cls_ids, sum(self.label_per_audio))
        elif self.sample_method == 'balance':
            pos_ids = ast.literal_eval(row['pos_ids'])
            neg_ids = ast.literal_eval(row['neg_ids'])
            removed_ids = ast.literal_eval(row['removed_ids'])
            N_p, N_n = self.label_per_audio
            # If there are not enough positives, fill the quota with extra negatives.
            if len(pos_ids) < N_p:
                N_n += N_p - len(pos_ids)
                assert len(neg_ids) + len(removed_ids) >= N_n
            # elif len(neg_ids) < N_n:
            #     N_p += N_n - len(neg_ids)
            sampled_pos = random.sample(pos_ids, min(N_p, len(pos_ids)))
            # Combine neg_ids and removed_ids with different sampling weights;
            # random.choices samples with replacement, so duplicates are possible.
            candidates = neg_ids + removed_ids
            weights = [1.0] * len(neg_ids) + [self.neg_removed_weight] * len(removed_ids)
            sampled_neg = random.choices(candidates, weights=weights,
                                         k=min(N_n, len(candidates)))
            cls_list = sampled_pos + sampled_neg
        else:
            raise ValueError("Unknown sample_method: {}".format(self.sample_method))
        cls_tokens = []
        labels = []
        # self.label is None unless label_type == 'strong'; guard the lookup.
        filelabel = (self.label[self.label['filename'] == row['file_name']]
                     if self.label is not None else None)
        for cls_id in cls_list:
            event_label = self.class_dict[cls_id]
            # Pre-computed CLAP text embedding for this class, one .pt file per label.
            cls = torch.load(self.clap_dir + event_label + '.pt')
            cls_tokens.append(cls)
            label = self.load_label(filelabel, event_label)
            labels.append(label)
        cls_tokens = torch.cat(cls_tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        return audio, cls_tokens, labels, row['file_name']
    def __len__(self):
        return len(self.meta)
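

# Minimal usage sketch. The paths below are placeholders, not part of this
# repository; the printed shapes assume each saved CLAP embedding has shape
# (1, embed_dim) and that every item yields the same number of classes (the
# default collate function cannot batch ragged tensors).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = TSED_AS(
        data_dir='data/audio/',
        clap_dir='data/clap_embeddings/',
        meta_dir='data/meta.csv',
        label_dir='data/strong_labels.csv',
        class_list='data/class_list.csv',
        sample_method='balance',
        label_per_audio=[10, 10],
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    audio, cls_tokens, labels, file_names = next(iter(loader))
    # audio:      (batch, seg_len * sr)                   fixed-length waveforms
    # cls_tokens: (batch, n_classes, embed_dim)           CLAP text embeddings
    # labels:     (batch, n_classes, seg_len * label_sr)  frame-level targets
    print(audio.shape, cls_tokens.shape, labels.shape)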