|
|
import torch |
|
|
import os |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
import sed_scores_eval |
|
|
from desed_task.evaluation.evaluation_measures import (compute_per_intersection_macro_f1, |
|
|
compute_psds_from_operating_points, |
|
|
compute_psds_from_scores) |
|
|
from local.utils import (batched_decode_preds,) |
|
|
from utils.sed import Encoder |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def _psds_summary_row(overall, per_class, name):
    """Build one report row from an overall PSDS value and per-class scores.

    Args:
        overall: Overall PSDS value returned by ``compute_psds_from_scores``.
        per_class: Mapping from event class to its PSDS value.
        name: Label stored in the ``'name'`` column of the row.

    Returns:
        A new dict holding the per-class scores plus the ``'overall'``,
        ``'macro_averaged'`` and ``'name'`` entries (in that order).
    """
    row = dict(per_class)
    # Average over the event classes only, *before* 'overall' is inserted,
    # so the aggregate score does not leak into the per-class mean.
    class_scores = list(row.values())
    macro = np.array(class_scores).mean() if class_scores else float('nan')
    row['overall'] = overall
    row['macro_averaged'] = macro
    row['name'] = name
    return row


def _psds1_from_scores(scores, ground_truth, audio_durations):
    """Run ``compute_psds_from_scores`` with the PSDS settings used here."""
    return compute_psds_from_scores(
        scores,
        ground_truth,
        audio_durations,
        dtc_threshold=0.7,
        gtc_threshold=0.7,
        cttc_threshold=None,
        alpha_ct=0.0,
        alpha_st=0.0,
    )


@torch.no_grad()
def val_psds(model, val_loader, params, epoch, split, save_path, device):
    """Evaluate ``model`` on ``split`` and report PSDS with and without label masking.

    Two score sets are produced for every clip: the raw sigmoid predictions,
    and a "tsed" variant in which the predictions of every event class that
    is not listed for that clip in the ground-truth CSV are set to zero.
    Both are decoded with ``batched_decode_preds`` and scored with
    ``compute_psds_from_scores``. A two-row table (``'psds1'`` and
    ``'psds1_tsed'``) is written to ``{save_path}/psds_cls/{epoch}.csv``.

    Args:
        model: Network exposing ``forward_to_spec(audio)`` and callable as
            ``model(mel, cls)``. It is switched to eval mode.
        val_loader: Iterable yielding ``(audio, filenames)`` batches.
        params: Nested config. Reads ``params['data'][split]`` keys
            ``'label'``, ``'csv'`` and ``'dur'``, and
            ``params['data']['train_data']['clap_dir']``.
        epoch: Used as the file name of the written CSV.
        split: Key selecting the evaluated split inside ``params['data']``.
        save_path: Directory under which ``psds_cls/`` is created.
        device: Device the class embeddings and audio are moved to.

    Returns:
        Tuple ``(psds1, psds1_tsed)`` with the overall PSDS values for the
        raw and the label-masked predictions.
    """
    label_df = pd.read_csv(params['data'][split]['label'])
    EVENTS = label_df['label'].tolist()

    # One stored embedding per event class, stacked along dim 1.
    clap_dir = params['data']['train_data']['clap_dir']
    clap_emb = [
        torch.load(clap_dir + event + '.pt').to(device).unsqueeze(1)
        for event in EVENTS
    ]
    cls = torch.cat(clap_emb, dim=1)
    N = cls.shape[1]

    encoder = Encoder(EVENTS, audio_len=10, frame_len=160, frame_hop=160, net_pooling=4, sr=16000)

    model.eval()
    test_csv = params['data'][split]["csv"]
    test_dur = params['data'][split]["dur"]

    gt = pd.read_csv(test_csv, sep='\t')
    # Build the filename -> event-label lookup once instead of filtering the
    # whole ground-truth frame for every file of every batch.
    weak_labels = {
        filename: set(labels)
        for filename, labels in gt.groupby('filename')['event_label']
    }

    test_scores_postprocessed_buffer = {}
    test_scores_postprocessed_buffer_tsed = {}
    test_thresholds = [0.5]

    for batch in tqdm(val_loader):
        audio, filenames = batch
        B = audio.shape[0]
        # Expand into a per-batch view and keep `cls` untouched: `expand` can
        # only broadcast size-1 dimensions, so overwriting `cls` here made any
        # later batch of a different size (e.g. a smaller last batch) fail.
        # NOTE(review): assumes the stacked embeddings have a leading
        # dimension of size 1 - confirm against the saved .pt files.
        cls_batch = cls.expand(B, -1, -1)

        audio = audio.to(device)
        mel = model.forward_to_spec(audio)

        preds = model(mel, cls_batch)
        preds = torch.sigmoid(preds)
        preds = preds.reshape(B, N, -1)
        preds_tsed = preds.clone()

        # Zero the scores of classes that are not annotated for the clip.
        for idx, filename in enumerate(filenames):
            present = weak_labels.get(filename, ())
            for j, event in enumerate(EVENTS):
                if event not in present:
                    preds_tsed[idx, j] = 0.0

        (_, scores_postprocessed_strong, _,) = batched_decode_preds(
            preds,
            filenames,
            encoder,
            median_filter=9,
            thresholds=list(test_thresholds),
        )
        test_scores_postprocessed_buffer.update(scores_postprocessed_strong)

        (_, scores_postprocessed_strong_tsed, _,) = batched_decode_preds(
            preds_tsed,
            filenames,
            encoder,
            median_filter=9,
            thresholds=list(test_thresholds),
        )
        test_scores_postprocessed_buffer_tsed.update(scores_postprocessed_strong_tsed)

    ground_truth = sed_scores_eval.io.read_ground_truth_events(test_csv)
    audio_durations = sed_scores_eval.io.read_audio_durations(test_dur)

    # Restrict the references to the clips that were actually scored.
    ground_truth = {
        audio_id: ground_truth[audio_id]
        for audio_id in test_scores_postprocessed_buffer
    }
    audio_durations = {
        audio_id: audio_durations[audio_id]
        for audio_id in test_scores_postprocessed_buffer
    }

    psds1_sed_scores_eval, psds1_cls = _psds1_from_scores(
        test_scores_postprocessed_buffer, ground_truth, audio_durations
    )
    psds1_sed_scores_eval_tsed, psds1_cls_tsed = _psds1_from_scores(
        test_scores_postprocessed_buffer_tsed, ground_truth, audio_durations
    )

    psds_cls = pd.DataFrame([
        _psds_summary_row(psds1_sed_scores_eval, psds1_cls, 'psds1'),
        _psds_summary_row(psds1_sed_scores_eval_tsed, psds1_cls_tsed, 'psds1_tsed'),
    ])

    os.makedirs(f'{save_path}/psds_cls/', exist_ok=True)
    psds_cls.to_csv(f'{save_path}/psds_cls/{epoch}.csv', index=False)

    return psds1_sed_scores_eval, psds1_sed_scores_eval_tsed