Spaces:
Runtime error
Runtime error
| """ | |
| General utilities. | |
| """ | |
| import json | |
| import os | |
| from typing import List, Union, Dict | |
| from functools import cmp_to_key | |
| import math | |
| from collections.abc import Iterable | |
| from datasets import load_dataset | |
# Directory one level above the directory containing this file.
# The path is only joined (not normalised), so it keeps the literal "../" suffix.
ROOT_DIR = os.path.join(os.path.dirname(__file__), "../")
def _load_table(table_path) -> dict:
    """
    Load a tab-separated table file into a header/rows dict.

    Attention: ``table_path`` must point to the ``.tsv`` file. The first line
    is the header and every following line is one row. The result looks like:
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}

    Header cells are kept exactly as split on tabs; row cells additionally
    have surrounding whitespace stripped.

    :param table_path: path of the .tsv file to read.
    :return: dict with the keys "header" and "rows".
    :raises ValueError: if the file is empty or the rows do not all have the
        same number of columns.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(table_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    if not lines:
        # Without a first line there is no header to return.
        raise ValueError(f'table file is empty: {table_path}')

    header = lines[0].strip("\n").split("\t")
    rows = [
        [cell.strip() for cell in line.strip("\n").split("\t")]
        for line in lines[1:]
    ]

    # Defense check: every row must be as wide as the first row.
    # (Row width is deliberately not compared against the header.)
    if any(len(row) != len(rows[0]) for row in rows):
        raise ValueError('some rows have diff cols.')

    return {"header": header, "rows": rows}
def majority_vote(
        nsqls: List,
        pred_answer_list: List,
        allow_none_and_empty_answer: bool = False,
        allow_error_answer: bool = False,
        answer_placeholder: Union[str, int] = '<error|empty>',
        vote_method: str = 'prob',
        answer_biased: Union[str, int] = None,
        answer_biased_weight: float = None,
):
    """
    Determine the final nsql execution answer by majority vote.

    :param nsqls: sequence of (nsql, logprob) pairs, aligned with ``pred_answer_list``.
    :param pred_answer_list: execution result of each nsql; either a list of
        answer values or the string '<error>'.
    :param allow_none_and_empty_answer: if True, the results ``[None]`` and ``[]``
        take part in the vote as ``[answer_placeholder]`` instead of being dropped.
    :param allow_error_answer: if True, the result '<error>' takes part in the vote
        as ``[answer_placeholder]`` instead of being dropped.
    :param answer_placeholder: value standing in for invalid results.
    :param vote_method: one of
        'simple' (highest count, ties broken by the logprob of the answer's first nsql),
        'prob' (highest sum of exp(logprob) over the answer's nsqls),
        'answer_biased' (like 'simple', but the count of ``answer_biased`` is
        multiplied by ``answer_biased_weight``),
        'lf_biased' (like 'simple', but nsqls containing 'map@' or 'ans@' count 10).
    :param answer_biased: the answer value favoured by 'answer_biased'.
    :param answer_biased_weight: positive multiplier used by 'answer_biased'.
    :return: ``(answer, nsqls)``. ``answer`` is the winning answer as a list and
        ``nsqls`` the ``[nsql, logprob]`` pairs that produced it. When no candidate
        is valid, ``answer`` is ``answer_placeholder`` itself and ``nsqls`` holds
        only the first input pair (or is empty when ``nsqls`` is empty).
        Candidates with equal scores are resolved in favour of the answer that
        appeared first in the input.
    :raises ValueError: if ``vote_method`` is not supported.
    """

    def _simple_score(item):
        """Occurrence count first, then the logprob of the first recorded nsql."""
        info = item[1]
        # NOTE(review): only the first nsql's logprob is used as tie-breaker;
        # presumably the inputs arrive ordered by logprob -- confirm with callers.
        return info['count'], info['nsqls'][0][1]

    def _prob_score(item):
        """Sum of the probabilities of all nsqls that produced this answer."""
        return sum(math.exp(nsql[1]) for nsql in item[1]['nsqls'])

    # Vote answers
    candi_answer_dict = dict()
    for (nsql, logprob), pred_answer in zip(nsqls, pred_answer_list):
        if allow_none_and_empty_answer and (pred_answer == [None] or pred_answer == []):
            pred_answer = [answer_placeholder]
        if allow_error_answer and pred_answer == '<error>':
            pred_answer = [answer_placeholder]

        # Invalid execution results
        if pred_answer == '<error>' or pred_answer == [None] or pred_answer == []:
            continue

        answer_info = candi_answer_dict.setdefault(
            tuple(pred_answer), {'count': 0, 'nsqls': []})
        answer_info['count'] += 1
        answer_info['nsqls'].append([nsql, logprob])

    # All candidates execution errors
    if len(candi_answer_dict) == 0:
        if len(nsqls) == 0:
            # Nothing was generated at all, so there is no nsql to report.
            return answer_placeholder, []
        return answer_placeholder, [(nsqls[0][0], nsqls[0][-1])]

    # Choose the scoring function (and adjust counts for the biased methods)
    if vote_method == 'simple':
        score = _simple_score
    elif vote_method == 'prob':
        score = _prob_score
    elif vote_method == 'answer_biased':
        # Specifically for Tabfact entailed answer, i.e., `1`.
        # If there exists nsql that produces `1`, we consider it more significant because `0` is very common.
        assert answer_biased_weight is not None and answer_biased_weight > 0
        for answer, answer_dict in candi_answer_dict.items():
            if answer == (answer_biased,):
                answer_dict['count'] *= answer_biased_weight
        score = _simple_score
    elif vote_method == 'lf_biased':
        # Assign weights to different types of logic forms (lf) to control interpretability and coverage
        for answer_dict in candi_answer_dict.values():
            answer_dict['count'] = sum(
                10 if ('map@' in nsql or 'ans@' in nsql) else 1
                for nsql, _ in answer_dict['nsqls']
            )
        score = _simple_score
    else:
        raise ValueError(f"Vote method {vote_method} is not supported.")

    # max() returns the first maximal item, so equal scores are resolved
    # deterministically in favour of the earliest-seen answer.
    best_answer, best_info = max(candi_answer_dict.items(), key=score)
    return list(best_answer), best_info['nsqls']
def load_data_split(dataset_to_load, split, data_dir=os.path.join(ROOT_DIR, 'datasets/')):
    """
    Load one split of a dataset and unify the names of its keys.

    The split is read with ``load_dataset`` from the loading script
    ``<data_dir>/<dataset_to_load>.py``, using ``<data_dir>/data`` as cache dir.

    :param dataset_to_load: name of the dataset (and of its loading script).
    :param split: key of the split to take from the loaded dataset.
    :param data_dir: directory holding the loading scripts and the cache.
    :return: the loaded split as-is for the wikitq-style datasets; a list of
        items with the added unified keys for 'tab_fact', 'hybridqa' and 'mmqa'.
    :raises ValueError: if the dataset name is not supported (raised only
        after the split has been loaded).
    """
    script_path = os.path.join(data_dir, "{}.py".format(dataset_to_load))
    cache_path = os.path.join(data_dir, "data")
    dataset_split_loaded = load_dataset(path=script_path, cache_dir=cache_path)[split]

    # Datasets whose items already use the unified key names.
    already_unified = (
        'wikitq', 'has_squall', 'missing_squall',
        'wikitq_sql_solvable', 'wikitq_sql_unsolvable',
        'wikitq_sql_unsolvable_but_in_squall',
        'wikitq_scalability_ori',
        'wikitq_scalability_100rows',
        'wikitq_scalability_200rows',
        'wikitq_scalability_500rows',
        'wikitq_robustness',
    )
    if dataset_to_load in already_unified:
        return dataset_split_loaded

    if dataset_to_load == 'tab_fact':
        unified_items = []
        for item in dataset_split_loaded:
            item['question'] = item['statement']
            item['answer_text'] = item['label']
            item['table']['page_title'] = item['table']['caption']
            unified_items.append(item)
        return unified_items

    if dataset_to_load == 'hybridqa':
        unified_items = []
        for item in dataset_split_loaded:
            # The page title is the part of the context before the first ' | '.
            item['table']['page_title'] = item['context'].split(' | ')[0]
            unified_items.append(item)
        return unified_items

    if dataset_to_load == 'mmqa':
        unified_items = []
        for item in dataset_split_loaded:
            item['table']['page_title'] = item['table']['title']
            unified_items.append(item)
        return unified_items

    raise ValueError(f'{dataset_to_load} dataset is not supported now.')
def pprint_dict(dic):
    """Print ``dic`` to standard output as JSON indented by two spaces."""
    rendered = json.dumps(dic, indent=2)
    print(rendered)
def flatten(nested_list):
    """
    Lazily yield the leaf items of an arbitrarily nested iterable, depth first.

    Any item that is an ``Iterable`` is descended into, except ``str`` and
    ``bytes``, which are yielded whole. Empty nested iterables contribute
    nothing.

    The traversal keeps an explicit stack of iterators instead of recursing,
    so the nesting depth is not limited by the interpreter's recursion limit.

    :param nested_list: the (possibly nested) iterable to flatten.
    :return: a generator over the leaf items in left-to-right order.
    """
    pending = [iter(nested_list)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
                # Descend: the parent iterator stays on the stack and resumes
                # where it stopped once the child is exhausted.
                pending.append(iter(item))
                break
            yield item
        else:
            # The innermost iterator is exhausted; go back to its parent.
            pending.pop()