Spaces:
Build error
Build error
| # -*- coding:utf-8 -*- | |
| """ | |
| @Author : Bao | |
| @Date : 2020/8/24 | |
| @Desc : | |
| @Last modified by : Bao | |
| @Last modified date : 2020/9/1 | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| from collections import defaultdict | |
| import tensorflow as tf | |
| from sklearn.metrics import precision_recall_fscore_support | |
| try: | |
| from .scorer import fever_score | |
| except: | |
| from scorer import fever_score | |
# Project root used to build the default gold-file path in FeverScorer.get_scores.
# Falls back to the current directory so that merely importing this module does
# not fail when PJ_HOME is unset (callers passing an explicit actual_file do not
# need it at all).
prefix = os.environ.get('PJ_HOME', '.')
class FeverScorer:
    """Evaluate FEVER-style predictions against a gold JSONL file.

    Reports per-label precision / recall / F1 (computed with scikit-learn's
    ``precision_recall_fscore_support``) together with the evidence
    precision / recall / F1, FEVER score and label accuracy returned by
    ``fever_score``.
    """

    def __init__(self):
        # Integer class id <-> label string. Only the label strings (the keys of
        # ``label2id``, in this insertion order) are used for scoring here.
        self.id2label = {2: 'SUPPORTS', 0: 'REFUTES', 1: 'NOT ENOUGH INFO'}
        self.label2id = {value: key for key, value in self.id2label.items()}

    def get_scores(self, predicted_file, actual_file=f'{prefix}/data/fever/shared_task_dev.jsonl'):
        """Score a prediction file against a gold file.

        Both files are opened with ``tf.io.gfile.GFile`` and parsed as JSON
        Lines; blank lines are ignored.

        Args:
            predicted_file: path to the predictions. Each line is a JSON object
                with an ``'id'`` and a ``'predicted_label'``, plus whatever fields
                ``fever_score`` reads (presumably predicted evidence -- confirm
                against ``scorer.fever_score``). If an id repeats, the last line
                wins.
            actual_file: path to the gold annotations. Each line has ``'id'``,
                ``'evidence'`` and ``'label'``.

        Returns:
            dict with ``'<LABEL> Precision'``, ``'<LABEL> Recall'`` and
            ``'<LABEL> F1'`` for every label, plus ``'Evidence Precision'``,
            ``'Evidence Recall'``, ``'Evidence F1'``, ``'FEVER Score'`` and
            ``'Label Accuracy'``.

        Raises:
            KeyError: if a gold id has no prediction, or (from ``label_score``)
                if a prediction whose id is absent from the gold file carries no
                ``'label'`` field.
        """
        id2results = {}
        with tf.io.gfile.GFile(predicted_file) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                js = json.loads(line)
                id2results[js['id']] = js

        with tf.io.gfile.GFile(actual_file) as fin:
            for line in fin:
                if not line.strip():
                    continue
                gold = json.loads(line)
                guid = gold['id']
                # A gold id without a prediction cannot be scored (label_score
                # needs 'predicted_label'), so report it here with a clear message.
                if guid not in id2results:
                    raise KeyError(f'no prediction for id {guid!r} in {predicted_file}')
                # Attach the gold annotations to the prediction record; this
                # overwrites any 'evidence'/'label' the prediction already carried.
                record = id2results[guid]
                record['evidence'] = gold['evidence']
                record['label'] = gold['label']

        instances = list(id2results.values())
        results = self.label_score(instances)
        score, accuracy, precision, recall, f1 = fever_score(instances)
        results.update({
            'Evidence Precision': precision,
            'Evidence Recall': recall,
            'Evidence F1': f1,
            'FEVER Score': score,
            'Label Accuracy': accuracy,
        })
        return results

    def label_score(self, results):
        """Compute per-label precision / recall / F1 of the predicted labels.

        Args:
            results: iterable of dicts, each holding a gold ``'label'`` and a
                ``'predicted_label'``.

        Returns:
            dict mapping ``'<LABEL> Precision'``, ``'<LABEL> Recall'`` and
            ``'<LABEL> F1'`` to the score for every label in ``label2id``,
            in ``label2id`` order.
        """
        # Materialize once so generators are not exhausted by the first pass.
        instances = list(results)
        truth = np.array([v['label'] for v in instances])
        prediction = np.array([v['predicted_label'] for v in instances])
        labels = list(self.label2id)
        p, r, f, _ = precision_recall_fscore_support(truth, prediction, labels=labels)

        scores = {}
        for label, label_p, label_r, label_f in zip(labels, p, r, f):
            scores[f'{label} Precision'] = label_p
            scores[f'{label} Recall'] = label_r
            scores[f'{label} F1'] = label_f
        return scores
if __name__ == '__main__':
    # Command-line entry point: score a prediction file and print the metrics
    # as pretty-printed JSON.
    import argparse

    parser = argparse.ArgumentParser(description='Score FEVER predictions against gold annotations.')
    # Required: without it get_scores would receive None and fail obscurely.
    parser.add_argument("--predicted_file", '-i', type=str, required=True,
                        help='JSONL file of predictions, one JSON object per line.')
    parser.add_argument("--actual_file", '-g', type=str, default=None,
                        help="Gold JSONL file; defaults to get_scores' built-in dev-set path.")
    args = parser.parse_args()

    scorer = FeverScorer()
    if args.actual_file is None:
        # Keep get_scores' own default gold path.
        results = scorer.get_scores(args.predicted_file)
    else:
        results = scorer.get_scores(args.predicted_file, args.actual_file)
    print(json.dumps(results, ensure_ascii=False, indent=4))