| """ | |
| Code from: | |
| https://github.com/DCASE-REPO/DESED_task | |
| """ | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import scipy | |
| from sed_scores_eval.base_modules.scores import create_score_dataframe | |


def batched_decode_preds(
    strong_preds,
    filenames,
    encoder,
    thresholds=[0.5],
    median_filter=None,
    pad_indx=None,
):
| """Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a | |
| dictionary | |
| Args: | |
| strong_preds: torch.Tensor, batch of strong predictions. | |
| filenames: list, the list of filenames of the current batch. | |
| encoder: ManyHotEncoder object, object used to decode predictions. | |
| thresholds: list, the list of thresholds to be used for predictions. | |
| median_filter: int, the number of frames for which to apply median window (smoothing). | |
| pad_indx: list, the list of indexes which have been used for padding. | |
| Returns: | |
| dict of predictions, each keys is a threshold and the value is the DataFrame of predictions. | |
| """ | |
    # Init a DataFrame per threshold
    scores_raw = {}
    scores_postprocessed = {}
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    for j in range(strong_preds.shape[0]):  # over batches
        audio_id = Path(filenames[j]).stem
        filename = audio_id + ".wav"
        c_scores = strong_preds[j]
        if pad_indx is not None:
            # Keep only the frames that were not added as padding.
            true_len = int(c_scores.shape[-1] * pad_indx[j].item())
            c_scores = c_scores[:true_len]
        # (n_classes, n_frames) -> (n_frames, n_classes) numpy array
        c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
        scores_raw[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        if median_filter is not None:
            # Median-filter along the time axis only (window of `median_filter` frames).
            c_scores = scipy.ndimage.median_filter(c_scores, (median_filter, 1))
        scores_postprocessed[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        for c_th in thresholds:
            # Binarize the (smoothed) scores and decode them into events.
            pred = c_scores > c_th
            pred = encoder.decode_strong(pred)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = filename
            prediction_dfs[c_th] = pd.concat(
                [prediction_dfs[c_th], pred], ignore_index=True
            )

    return scores_raw, scores_postprocessed, prediction_dfs
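

# Usage sketch (not part of the original DESED_task code): a minimal smoke test
# with a stand-in encoder that only implements the pieces this function touches
# (`labels`, `_frame_to_time`, `decode_strong`). In the DCASE baseline the real
# encoder is the repo's ManyHotEncoder; the class, labels, and numbers below are
# illustrative assumptions only.
if __name__ == "__main__":
    import torch

    class _DummyEncoder:
        labels = ["Speech", "Dog"]
        frame_hop = 256
        fs = 16000

        def _frame_to_time(self, frames):
            # Convert frame indices to seconds.
            return frames * self.frame_hop / self.fs

        def decode_strong(self, onehot):
            # Run-length decode a (n_frames, n_classes) boolean matrix into
            # [event_label, onset_s, offset_s] rows.
            events = []
            for c, label in enumerate(self.labels):
                col = np.concatenate(([0], onehot[:, c].astype(int), [0]))
                changes = np.diff(col)
                onsets = np.flatnonzero(changes == 1)
                offsets = np.flatnonzero(changes == -1)
                for on, off in zip(onsets, offsets):
                    events.append(
                        [label, self._frame_to_time(on), self._frame_to_time(off)]
                    )
            return events

    torch.manual_seed(0)
    encoder = _DummyEncoder()
    # Fake frame-level probabilities: (batch, n_classes, n_frames).
    strong_preds = torch.rand(2, len(encoder.labels), 156)
    filenames = ["audio_0.wav", "audio_1.wav"]
    scores_raw, scores_post, prediction_dfs = batched_decode_preds(
        strong_preds, filenames, encoder, thresholds=[0.5], median_filter=7
    )
    print(prediction_dfs[0.5].head())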