import os
import json
import argparse
import pandas as pd
from collections import defaultdict
from tinychart.eval.eval_metric import chartqa_evaluator, chartqapot_evaluator
from tinychart.eval.eval_metric import chartqa_oracle_merger_evaluator, chartqa_rule_merger_evaluator


def read_jsonl(jsonl_path):
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]
    return data


def write_jsonl(data, jsonl_path):
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='./output/')
    args = parser.parse_args()

    result_files = os.listdir(args.input)
    result_files = [f for f in result_files if f.endswith('.jsonl')]
    result_files.sort()

    direct_result, pot_result = None, None
    dataset2metric = defaultdict(float)
    for result_file in result_files:
        dataset_name = '.'.join(result_file.split('.')[:-1])
        file = os.path.join(args.input, result_file)
        result_data = read_jsonl(file)
        if 'chartqa-' in dataset_name:
            # Direct-answer predictions: score them and write the scored records back.
            direct_result, direct_acc = chartqa_evaluator(result_data, key='model_answer')
            write_jsonl(direct_result, file)
            dataset2metric[dataset_name] = round(direct_acc * 100, 2)
            print(f'Direct Accuracy: {direct_acc}')
        elif 'chartqagptpot-' in dataset_name or 'chartqatemplatepot-' in dataset_name:
            # Program-of-Thought predictions: execute the generated programs and score them.
            pot_result, pot_acc, error_rate = chartqapot_evaluator(result_data)
            write_jsonl(pot_result, file)
            dataset2metric[dataset_name] = round(pot_acc * 100, 2)
            print(f'PoT Accuracy: {pot_acc}')
            print(f'PoT Error Rate: {error_rate}')

    if direct_result is not None and pot_result is not None:
        print('Merging direct and PoT results with a simple divider')
        oracle_results, oracle_acc = chartqa_oracle_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-oracle'] = round(oracle_acc * 100, 2)
        print(f'Oracle Merged Accuracy: {oracle_acc}')
        write_jsonl(oracle_results, os.path.join(args.input, 'merged-oracle.jsonl'))

        rule_results, rule_acc = chartqa_rule_merger_evaluator(direct_result, pot_result)
        dataset2metric['merged-rule'] = round(rule_acc * 100, 2)
        print(f'Rule Merged Accuracy: {rule_acc}')
        write_jsonl(rule_results, os.path.join(args.input, 'merged-rule.jsonl'))

    # Save the metrics as a TSV with the dataset names as the header row.
    df = pd.DataFrame(dataset2metric, index=[0])
    tsv_name = os.path.join(args.input, 'metrics.tsv')
    if os.path.exists(tsv_name):
        # Avoid overwriting: if metrics.tsv exists, try metrics.1.tsv, metrics.2.tsv, ...
        i = 1
        tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
        while os.path.exists(tsv_name):
            i += 1
            tsv_name = os.path.join(args.input, f'metrics.{i}.tsv')
    df.to_csv(tsv_name, sep='\t', index=False)
    print(f'Metrics saved at: {tsv_name}')
    print(df)
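
A minimal usage sketch, for reference only: the script filename `eval_chartqa.py` and the example prediction filename are assumptions (use whatever this file is named in the repository); `./output/` is simply the default `--input` directory, which should contain one `.jsonl` prediction file per dataset.

# Usage sketch (assumptions: this script is saved as eval_chartqa.py, and
# ./output/ contains per-dataset prediction files such as chartqa-test.jsonl).
import subprocess

subprocess.run(
    ['python', 'eval_chartqa.py', '--input', './output/'],
    check=True,  # fail loudly if the evaluation script exits with an error
)
# Afterwards, per-dataset accuracies are written to ./output/metrics.tsv
# (or metrics.1.tsv, metrics.2.tsv, ... if a metrics file already exists).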