Upload metrics.py with huggingface_hub
metrics.py  CHANGED  (+41 -11)
@@ -7,6 +7,7 @@ from typing import Any, Dict, Generator, List, Optional
 import evaluate
 import nltk
 import numpy
+from editdistance import eval
 
 from .dataclass import InternalField
 from .operator import (
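The new import brings in editdistance.eval, which returns the Levenshtein (edit) distance between two sequences as an integer; note that importing it under the name eval shadows the Python builtin in this module. A minimal sketch of the call, with made-up strings:

import editdistance

# Minimum number of single-character insertions, deletions, and
# substitutions needed to turn one string into the other.
print(editdistance.eval("kitten", "sitting"))  # 3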
@@ -21,12 +22,12 @@ from .stream import MultiStream, Stream
 nltk.download("punkt")
 
 
-def
+def abstract_factory():
     return {}
 
 
 def abstract_field():
-    return field(default_factory=
+    return field(default_factory=abstract_factory)
 
 
 class UpdateStream(StreamInstanceOperator):
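abstract_field gives dataclass-style fields a fresh mutable default: default_factory is called once per instance, so instances do not share a single dict. The repo uses its own .dataclass module, so the following sketch with the standard-library dataclasses module is only illustrative:

from dataclasses import dataclass, field


def abstract_factory():
    return {}


def abstract_field():
    return field(default_factory=abstract_factory)


@dataclass
class Example:
    data: dict = abstract_field()


a, b = Example(), Example()
a.data["x"] = 1
print(b.data)  # {} -- each instance gets its own dict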
@@ -253,7 +254,7 @@ class F1(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         assert all(
             len(reference) == 1 for reference in references
-        ), "
+        ), "Only a single reference per prediction is allowed in F1 metric"
         self.str_to_id = {}
         self.id_to_str = {}
         formatted_references = [self.get_str_id(reference[0]) for reference in references]
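The rewritten assertion message states the contract: F1 accepts exactly one reference string per prediction, with each reference still wrapped in a list. Roughly how the shapes and the string-to-id mapping look, with invented labels and a simplified stand-in for get_str_id (the prediction mapping is not shown in this hunk and is inferred):

references = [["positive"], ["negative"], ["positive"]]  # one reference per prediction
predictions = ["positive", "positive", "negative"]

str_to_id = {}


def get_str_id(s):
    # Assign the next integer id to unseen strings, reuse ids otherwise.
    return str_to_id.setdefault(s, len(str_to_id))


formatted_references = [get_str_id(ref[0]) for ref in references]
formatted_predictions = [get_str_id(pred) for pred in predictions]
print(formatted_references, formatted_predictions)  # [0, 1, 0] [0, 0, 1]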
@@ -287,7 +288,6 @@ class F1MultiLabel(GlobalMetric):
     _metric = None
     main_score = "f1_macro"
     average = None  # Report per class then aggregate by mean
-    seperator = ","
 
     def prepare(self):
         super(F1MultiLabel, self).prepare()
@@ -310,17 +310,15 @@ class F1MultiLabel(GlobalMetric):
     def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
         self.str_to_id = {}
         self.id_to_str = {}
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in F1 metric"
+        references = [reference[0] for reference in references]
         labels = list(set([label for reference in references for label in reference]))
         for label in labels:
-            assert (
-                not self.seperator in label
-            ), "Reference label (f{label}) can not contain multi label seperator (f{self.seperator}) "
             self.add_str_to_id(label)
         formatted_references = [self.get_one_hot_vector(reference) for reference in references]
-        split_predictions = [
-            [label.strip() for label in prediction.split(self.seperator)] for prediction in predictions
-        ]
-        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in split_predictions]
+        formatted_predictions = [self.get_one_hot_vector(prediction) for prediction in predictions]
         result = self._metric.compute(
             predictions=formatted_predictions, references=formatted_references, average=self.average
         )
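F1MultiLabel no longer splits predictions on a separator character: predictions are now expected as lists of label strings, mirroring the references, and both sides are converted to one-hot vectors before the underlying metric is called. A standalone illustration of that formatting, using sklearn's f1_score directly rather than the class's loaded _metric, with an invented label set:

from sklearn.metrics import f1_score

labels = ["sports", "politics", "tech"]
label_to_idx = {label: i for i, label in enumerate(labels)}


def one_hot(label_list):
    # Binary indicator vector over the full label set.
    vec = [0] * len(labels)
    for label in label_list:
        vec[label_to_idx[label]] = 1
    return vec


references = [["sports"], ["politics", "tech"]]
predictions = [["sports", "tech"], ["politics"]]

y_true = [one_hot(r) for r in references]
y_pred = [one_hot(p) for p in predictions]
print(f1_score(y_true, y_pred, average="macro"))  # ~0.667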
@@ -356,6 +354,38 @@ class Rouge(HuggingfaceMetric):
         return super().compute(references, predictions)
 
 
+# Computes char edit distance, ignoring repeating whitespace
+class CharEditDistanceAccuracy(SingleReferenceInstanceMetric):
+    reduction_map = {"mean": ["char_edit_dist_accuracy"]}
+    main_score = "char_edit_dist_accuracy"
+
+    def compute(self, reference, prediction: str) -> dict:
+        formatted_prediction = " ".join(prediction.split())
+        formatted_reference = " ".join(reference.split())
+        max_length = max(len(formatted_reference), len(formatted_prediction))
+        if max_length == 0:
+            return 0
+        edit_dist = eval(formatted_reference, formatted_prediction)
+        return {"char_edit_dist_accuracy": (1 - edit_dist / max_length)}
+
+
+class Wer(HuggingfaceMetric):
+    metric_name = "wer"
+    main_score = "wer"
+
+    def prepare(self):
+        super().prepare()
+        self.metric = evaluate.load(self.metric_name)
+
+    def compute(self, references: List[List[str]], predictions: List[str]) -> dict:
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only single reference per prediction is allowed in wer metric"
+        formatted_references = [reference[0] for reference in references]
+        result = self.metric.compute(predictions=predictions, references=formatted_references)
+        return {self.main_score: result}
+
+
 class Bleu(HuggingfaceMetric):
     metric_name = "bleu"
     main_score = "bleu"
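The two added classes wrap existing libraries: CharEditDistanceAccuracy collapses repeated whitespace, computes the character edit distance with editdistance.eval, and rescales it to an accuracy in [0, 1] (note the max_length == 0 branch returns a bare 0 rather than a score dict), while Wer defers to the Hugging Face evaluate "wer" metric. A standalone sketch of both computations outside the metric classes:

import editdistance
import evaluate


def char_edit_dist_accuracy(reference: str, prediction: str) -> float:
    # Collapse repeated whitespace so spacing differences are not penalized.
    ref = " ".join(reference.split())
    pred = " ".join(prediction.split())
    max_length = max(len(ref), len(pred))
    if max_length == 0:
        return 0.0
    return 1 - editdistance.eval(ref, pred) / max_length


print(char_edit_dist_accuracy("hello  world", "hello word"))  # ~0.909

# Word error rate via the evaluate library (needs the jiwer backend installed).
wer = evaluate.load("wer")
print(wer.compute(predictions=["hello word"], references=["hello world"]))  # 0.5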