| """ | |
| BERT Score | |
| --------------------- | |
| BERT Score is introduced in this paper (BERTScore: Evaluating Text Generation with BERT) `arxiv link`_. | |
| .. _arxiv link: https://arxiv.org/abs/1904.09675 | |
| BERT Score measures token similarity between two text using contextual embedding. | |
| To decide which two tokens to compare, it greedily chooses the most similar token from one text and matches it to a token in the second text. | |
| """ | |

import bert_score

from textattack.constraints import Constraint
from textattack.shared import utils


class BERTScore(Constraint):
    """A constraint on BERT-Score difference.

    Args:
        min_bert_score (float): minimum threshold value for BERT-Score
        model_name (str): name of the model to use for scoring
        num_layers (int): layer of the model's representation to use for scoring
        score_type (str): one of the following three choices

            - ``precision``: match words from candidate text to reference text
            - ``recall``: match words from reference text to candidate text
            - ``f1``: harmonic mean of precision and recall (recommended)

        compare_against_original (bool):
            If ``True``, compare new ``x_adv`` against the original ``x``.
            Otherwise, compare it against the previous ``x_adv``.
    """

    SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}

    def __init__(
        self,
        min_bert_score,
        model_name="bert-base-uncased",
        num_layers=None,
        score_type="f1",
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)
        if not isinstance(min_bert_score, float):
            raise TypeError("min_bert_score must be a float")
        if min_bert_score < 0.0 or min_bert_score > 1.0:
            raise ValueError("min_bert_score must be a value between 0.0 and 1.0")

        self.min_bert_score = min_bert_score
        self.model = model_name
        self.score_type = score_type
        # Turn off the idf-weighting scheme because the reference sentence set is small.
        self._bert_scorer = bert_score.BERTScorer(
            model_type=model_name,
            idf=False,
            device=utils.device,
            num_layers=num_layers,
        )

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` if the BERT Score between ``transformed_text`` and
        ``reference_text`` is greater than or equal to the minimum BERT Score."""
        cand = transformed_text.text
        ref = reference_text.text
        result = self._bert_scorer.score([cand], [ref])
        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
        return score >= self.min_bert_score

    def extra_repr_keys(self):
        return ["min_bert_score", "model", "score_type"] + super().extra_repr_keys()