# FlexSED/src/dataset/tsed.py
import ast
import random

import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset


class TSED_AS(Dataset):
    """Sound event detection dataset that pairs fixed-length audio clips with
    per-class query embeddings (precomputed CLAP text embeddings) and
    frame-level binary targets."""

    def __init__(self, data_dir, clap_dir, meta_dir, label_dir, class_list,
                 seg_length=10, sr=16000, label_sr=25, label_per_audio=[10, 10],
                 norm=True, mono=True, label_type='strong', debug=False, sample_method='random',
                 neg_removed_weight=0.25,
                 **kwargs):
        self.data_dir = data_dir
        self.clap_dir = clap_dir
        # Drop metadata entries with zero duration (empty or corrupted files)
        meta = pd.read_csv(meta_dir)
        meta = meta[meta['duration'] != 0]
        self.meta = meta
        # Strong labels come from a separate CSV; weak mode keeps no label table
        if label_type == 'strong':
            label = pd.read_csv(label_dir)
            self.label = label
        else:
            self.label = None
        self.label_per_audio = label_per_audio
        self.class_list = pd.read_csv(class_list)
        self.class_dict = dict(self.class_list.set_index('id')['label'])  # id -> label name
        # self.event_id = dict(self.class_list.set_index('label')['id'])
        self.cls_ids = sorted(self.class_list['id'].unique().tolist())
        self.sample_method = sample_method
        self.seg_len = seg_length
        self.sr = sr
        self.label_sr = label_sr
        self.label_type = label_type
        self.norm = norm
        self.mono = mono
        self.neg_removed_weight = neg_removed_weight
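
    # The class-list CSV read above is assumed (from the 'id'/'label' accesses)
    # to look like the following; the rows are illustrative, not from the repo:
    #   id,label
    #   0,Speech
    #   1,Music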

    def load_audio(self, audio_path):
        y, sr = torchaudio.load(audio_path)
        assert sr == self.sr, "expected {} Hz, got {} Hz".format(self.sr, sr)
        # Handle stereo or mono based on self.mono
        if self.mono:
            # Convert to mono by averaging all channels
            y = torch.mean(y, dim=0, keepdim=True)
        else:
            if y.shape[0] == 1:
                pass
            elif y.shape[0] == 2:
                # Randomly either average the two stereo channels or pick one
                if random.choice([True, False]):
                    y = torch.mean(y, dim=0, keepdim=True)
                else:
                    channel = random.choice([0, 1])
                    y = y[channel, :].unsqueeze(0)
            else:
                raise ValueError("Unsupported number of channels: {}".format(y.shape[0]))
        # Take the first seg_len seconds; shorter clips are zero-padded
        total_length = y.shape[-1]
        start = 0
        end = min(start + self.seg_len * self.sr, total_length)
        audio_clip = torch.zeros(self.seg_len * self.sr)
        audio_clip[:end - start] = y[0, start:end]
        if self.norm:
            # Peak-normalize; eps guards against all-zero clips
            eps = 1e-9
            max_val = torch.max(torch.abs(audio_clip))
            audio_clip = audio_clip / (max_val + eps)
        # audio_clip = self.augmenter(audio_clip)
        return audio_clip

    def load_label(self, filelabel, event_label):
        # Frame-level binary target at label_sr frames per second
        target = torch.zeros(self.seg_len * self.label_sr)
        if self.label_type == 'strong':
            label = filelabel[filelabel['label'] == event_label]
            for i in range(len(label)):
                row = label.iloc[i]
                onset = row['onset']
                offset = row['offset']
                # Mark active frames; int() guards against non-integer indices
                target[int(round(onset * self.label_sr)):int(round(offset * self.label_sr))] = 1
        else:
            pass
        return target.unsqueeze(0)
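
    # The strong-label CSV is assumed (from the column accesses above) to hold
    # one row per event: 'filename', 'onset', 'offset' (in seconds), 'label'.
    # Illustrative row, not taken from the repo:
    #   filename,onset,offset,label
    #   Y-0Gj8-vB1q4.wav,0.32,2.17,Speech
    # At label_sr=25 this event sets target[8:54] of the 250-frame vector,
    # since round(0.32 * 25) = 8 and round(2.17 * 25) = 54.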

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        audio = self.load_audio(self.data_dir + row['file_name'])
        # TBD balance positive and negative
        if self.sample_method == 'fix':
            # List columns are stored as string literals in the CSV
            # (same convention as the 'balance' columns below)
            cls_list = ast.literal_eval(row['ids'])
        elif self.sample_method == 'random':
            cls_queue = self.cls_ids
            # label_per_audio is a [positive, negative] pair; sample their total
            cls_list = random.sample(cls_queue, sum(self.label_per_audio))
        elif self.sample_method == 'balance':
            pos_ids = ast.literal_eval(row['pos_ids'])
            neg_ids = ast.literal_eval(row['neg_ids'])
            removed_ids = ast.literal_eval(row['removed_ids'])
            N_p, N_n = self.label_per_audio
            # If there are too few positives, fill the remaining budget with negatives
            if len(pos_ids) < N_p:
                N_n += N_p - len(pos_ids)
                assert len(neg_ids) + len(removed_ids) >= N_n
            # elif len(neg_ids) < N_n:
            #     N_p += N_n - len(neg_ids)
            sampled_pos = random.sample(pos_ids, min(N_p, len(pos_ids)))
            # Combine neg_ids and removed_ids with different sampling weights
            # (random.choices draws with replacement)
            candidates = neg_ids + removed_ids
            weights = [1.0] * len(neg_ids) + [self.neg_removed_weight] * len(removed_ids)
            sampled_neg = random.choices(candidates, weights=weights, k=min(N_n, len(candidates)))
            cls_list = sampled_pos + sampled_neg
        cls_tokens = []
        labels = []
        # Weak-label mode keeps self.label as None; load_label ignores it then
        filelabel = self.label[self.label['filename'] == row['file_name']] if self.label is not None else None
        for cls_id in cls_list:
            event_label = self.class_dict[cls_id]
            # Precomputed CLAP text embedding for this class name
            cls = torch.load(self.clap_dir + event_label + '.pt')
            cls_tokens.append(cls)
            label = self.load_label(filelabel, event_label)
            labels.append(label)
        cls_tokens = torch.cat(cls_tokens, dim=0)
        labels = torch.cat(labels, dim=0)
        return audio, cls_tokens, labels, row['file_name']

    def __len__(self):
        return len(self.meta)
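

# A minimal smoke-test sketch of how this dataset might be wired up. The paths
# and CSV names below are placeholders, not files shipped with the repo;
# adjust them to your local layout.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = TSED_AS(
        data_dir='data/audio/',              # folder of 16 kHz wav files (placeholder)
        clap_dir='data/clap_embeds/',        # one <label>.pt CLAP embedding per class (placeholder)
        meta_dir='data/meta.csv',            # needs 'file_name' and 'duration' columns
        label_dir='data/strong_labels.csv',  # needs 'filename', 'onset', 'offset', 'label'
        class_list='data/classes.csv',       # needs 'id' and 'label' columns
        sample_method='random',
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    audio, cls_tokens, labels, names = next(iter(loader))
    # audio: (B, seg_len * sr); labels: (B, n_queries, seg_len * label_sr)
    print(audio.shape, cls_tokens.shape, labels.shape)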