| import sys | |
| import logging | |
| from typing import Dict, Any, Sequence | |
| from transformers import EvalPrediction | |
| from ...utils import decode_generate_ids | |
# Root logging configuration: records go to stdout with a timestamped
# "time - level - logger name - message" layout.  ``basicConfig`` is a no-op
# when the root logger already has handlers configured.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger; INFO and above are passed on to the handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class BaseComputeMetrics:
    """Base class for exact-match accuracy over decoded predictions.

    Subclasses must implement :meth:`extract_ans`, which pulls the answer out
    of a decoded string.  An instance is callable on an ``EvalPrediction``
    and returns a metrics dict.
    """

    def __init__(self, preprocessor: Dict[str, Any]):
        """Keep ``preprocessor`` and use its ``'text'`` entry as the tokenizer.

        Raises:
            KeyError: if ``preprocessor`` has no ``'text'`` entry.
        """
        self.preprocessor = preprocessor
        self.tokenizer = self.preprocessor['text']

    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, Any]:
        """Decode predictions and targets, then compute the metrics.

        Args:
            eval_preds: unpacked into ``(preds, targets)``; both must expose
                ``.shape`` (it is logged) and be accepted by
                ``decode_generate_ids``.

        Returns:
            The dict produced by :meth:`calculate_metric`.

        Raises:
            AssertionError: if the decoded predictions and targets differ in
                length.
        """
        preds, targets = eval_preds
        logger.warning(
            "preds shape: %s. targets shape: %s", preds.shape, targets.shape
        )
        # NOTE(review): decode_generate_ids is presumed to return one decoded
        # string per row -- confirm against its definition in utils.
        preds = decode_generate_ids(self.tokenizer, preds)
        targets = decode_generate_ids(self.tokenizer, targets)
        assert len(preds) == len(targets), (
            f"decoded {len(preds)} preds but {len(targets)} targets"
        )
        return self.calculate_metric(preds, targets)

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        """Score predictions against targets by comparing extracted answers.

        Args:
            preds: decoded prediction strings.
            targets: decoded target strings, paired positionally with
                ``preds``.

        Returns:
            A dict with:
              * ``'accuracy'``: exact matches divided by ``len(targets)``
                (targets whose answer could not be extracted still count in
                the denominator); ``0.0`` when ``targets`` is empty.
              * ``'target_failed'``: targets for which :meth:`extract_ans`
                returned ``None``.
              * ``'failed'``: predictions for which :meth:`extract_ans`
                returned ``None`` (only counted when the paired target was
                extractable).
        """
        correct = 0
        failed = 0
        target_failed = 0
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                # Nothing to compare against; skip the pair without scoring it.
                target_failed += 1
                logger.warning(
                    "failed to extract ans from target. maybe the response string is truncated: %s.",
                    target,
                )
                continue
            if extract_pred is None:
                failed += 1
            if extract_pred == extract_target:
                correct += 1

        total = len(targets)
        return {
            # Guard against an empty evaluation set (previously ZeroDivisionError).
            'accuracy': 1.0 * correct / total if total else 0.0,
            'target_failed': target_failed,
            'failed': failed,
        }

    def extract_ans(self, string: str):
        """Extract the answer from ``string``; must be overridden.

        :meth:`calculate_metric` treats a ``None`` return as "extraction
        failed" and compares any other returned values with ``==``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError