from typing import List, Dict, Any
from collections import defaultdict
import statistics

import datasets
import evaluate

from FLD_task import build_metrics

_DESCRIPTION = ""
_KWARGS_DESCRIPTION = ""
_CITATION = ""

class FLDMetrics(evaluate.Metric):

    def __init__(self, *args, log_samples=False, **kwargs):
        super().__init__(*args, **kwargs)
        # Score each example under two regimes: strict proof matching
        # and a relaxed variant that allows extra proof steps.
        self._metric_funcs = {
            'strct': build_metrics('strict'),
            'extr_stps': build_metrics('allow_extra_steps'),
        }
        self.log_samples = log_samples

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # One predicted proof string, a list of gold references, and a
            # context string per example.
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Sequence(datasets.Value("string")),
                    "contexts": datasets.Value("string"),
                }
            ),
        )

    def _compute(self, predictions, references, contexts=None):
        if contexts is None:
            contexts = [None] * len(predictions)
        # Gather per-example scores, keyed as "<metric_type>.<metric_name>".
        metrics: Dict[str, List[Any]] = defaultdict(list)
        for pred, golds, context in zip(predictions, references, contexts):
            for metric_type, calc_metrics in self._metric_funcs.items():
                _metrics = calc_metrics(
                    golds,
                    pred,
                    context=context,
                )
                for metric_name, metric_val in _metrics.items():
                    metrics[f"{metric_type}.{metric_name}"].append(metric_val)
        # Average each score over all examples.
        results = {}
        for metric_name, metric_vals in metrics.items():
            results[metric_name] = statistics.mean(metric_vals)
        return results
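
# A minimal usage sketch, not part of the original file: it assumes this
# script is saved locally as "fld_metrics.py" (hypothetical name) and that
# the FLD_task package is installed. evaluate.load() also accepts a path to
# a local metric script. Because "contexts" is declared in the features
# above, it can be passed to compute() alongside predictions and references.
if __name__ == "__main__":
    metric = evaluate.load("fld_metrics.py")  # hypothetical local path
    results = metric.compute(
        predictions=["<model-generated proof>"],
        references=[["<gold proof>"]],
        contexts=["<facts and hypothesis for this example>"],
    )
    # Keys look like "strct.<name>" and "extr_stps.<name>", each value the
    # mean over all examples.
    print(results)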