FlexSED / src /val.py
OpenSound's picture
Upload 544 files
3b6a091 verified
import torch
import os
import pandas as pd
from tqdm import tqdm
import sed_scores_eval
from desed_task.evaluation.evaluation_measures import (compute_per_intersection_macro_f1,
compute_psds_from_operating_points,
compute_psds_from_scores)
from local.utils import (batched_decode_preds,)
from utils.sed import Encoder
import numpy as np
@torch.no_grad()
def val_psds(model, val_loader, params, epoch, split, save_path, device):
label_df = pd.read_csv(params['data'][split]['label'])
EVENTS = label_df['label'].tolist()
clap_emb = []
for event in EVENTS:
cls = torch.load(params['data']['train_data']['clap_dir'] + event + '.pt').to(device)
cls = cls.unsqueeze(1)
clap_emb.append(cls)
cls = torch.cat(clap_emb, dim=1)
encoder = Encoder(EVENTS, audio_len=10, frame_len=160, frame_hop=160, net_pooling=4, sr=16000)
model.eval()
test_csv = params['data'][split]["csv"]
test_dur = params['data'][split]["dur"]
gt = pd.read_csv(test_csv, sep='\t')
test_scores_postprocessed_buffer = {}
test_scores_postprocessed_buffer_tsed = {}
test_thresholds = [0.5]
test_psds_buffer = {k: pd.DataFrame() for k in test_thresholds}
test_psds_buffer_tsed = {k: pd.DataFrame() for k in test_thresholds}
for batch in tqdm(val_loader):
audio, filenames = batch
B = audio.shape[0]
N = cls.shape[1]
cls = cls.expand(B, -1, -1)
audio = audio.to(device)
mel = model.forward_to_spec(audio)
preds = model(mel, cls)
preds = torch.sigmoid(preds)
preds = preds.reshape(B, N, -1)
preds_tsed = preds.clone()
# tsed assumes sound exitencance is known
for idx, filename in enumerate(filenames):
weak_label = list(gt[gt['filename'] == filename]['event_label'].unique())
for j, event in enumerate(EVENTS):
if event not in weak_label:
preds_tsed[idx][j] = 0.0
# preds = preds.transpose(1, 2)
(_, scores_postprocessed_strong, _,) = \
batched_decode_preds(
preds,
filenames,
encoder,
median_filter=9,
thresholds=list(test_psds_buffer.keys()), )
test_scores_postprocessed_buffer.update(scores_postprocessed_strong)
(_, scores_postprocessed_strong_tsed, _,) = \
batched_decode_preds(
preds_tsed,
filenames,
encoder,
median_filter=9,
thresholds=list(test_psds_buffer_tsed.keys()), )
test_scores_postprocessed_buffer_tsed.update(scores_postprocessed_strong_tsed)
ground_truth = sed_scores_eval.io.read_ground_truth_events(test_csv)
audio_durations = sed_scores_eval.io.read_audio_durations(test_dur)
ground_truth = {
audio_id: ground_truth[audio_id]
for audio_id in test_scores_postprocessed_buffer
}
audio_durations = {
audio_id: audio_durations[audio_id]
for audio_id in test_scores_postprocessed_buffer
}
psds1_sed_scores_eval, psds1_cls = compute_psds_from_scores(
test_scores_postprocessed_buffer,
ground_truth,
audio_durations,
dtc_threshold=0.7,
gtc_threshold=0.7,
cttc_threshold=None,
alpha_ct=0.0,
alpha_st=0.0,
# save_dir=os.path.join(save_dir, "student", "scenario1"),
)
psds1_cls['overall'] = psds1_sed_scores_eval
psds1_cls['macro_averaged'] = np.array([v for k, v in psds1_cls.items()]).mean()
psds1_cls['name'] = 'psds1'
psds1_sed_scores_eval_tsed, psds1_cls_tsed = compute_psds_from_scores(
test_scores_postprocessed_buffer_tsed,
ground_truth,
audio_durations,
dtc_threshold=0.7,
gtc_threshold=0.7,
cttc_threshold=None,
alpha_ct=0.0,
alpha_st=0.0,
# save_dir=os.path.join(save_dir, "student", "scenario1"),
)
psds1_cls_tsed['overall'] = psds1_sed_scores_eval_tsed
psds1_cls_tsed['macro_averaged'] = np.array([v for k, v in psds1_cls_tsed.items()]).mean()
psds1_cls_tsed['name'] = 'psds1_tsed'
# psds2_sed_scores_eval, psds2_cls = compute_psds_from_scores(
# test_scores_postprocessed_buffer,
# ground_truth,
# audio_durations,
# dtc_threshold=0.1,
# gtc_threshold=0.1,
# cttc_threshold=0.3,
# alpha_ct=0.5,
# alpha_st=1,
# # save_dir=os.path.join(save_dir, "student", "scenario1"),
# )
# psds2_cls['overall'] = psds2_sed_scores_eval
# psds2_cls['macro_averaged'] = np.array([v for k, v in psds2_cls.items()]).mean()
# psds2_cls['name'] = 'psds2'
psds_cls = pd.DataFrame([psds1_cls, psds1_cls_tsed])
# psds_cls = pd.DataFrame([psds1_cls, psds2_cls])
os.makedirs(f'{save_path}/psds_cls/', exist_ok=True)
psds_cls.to_csv(f'{save_path}/psds_cls/{epoch}.csv', index=False)
return psds1_sed_scores_eval, psds1_sed_scores_eval_tsed