# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Loads the WikiQuestions dataset. | |
| An example consists of question, table. Additionally, we store the processed | |
| columns which store the entries after performing number, date and other | |
| preprocessing as done in the baseline. | |
| columns, column names and processed columns are split into word and number | |
| columns. | |
| lookup answer (or matrix) is also split into number and word lookup matrix | |
| Author: aneelakantan (Arvind Neelakantan) | |
| """ | |
from __future__ import print_function

import math
import os
import re
import unicodedata as ud

import numpy as np
import tensorflow as tf

# Sentinel value written into a corrupted table entry in a number column.
bad_number = -200000.0


def is_nan_or_inf(number):
  return math.isnan(number) or math.isinf(number)


def strip_accents(s):
  # Decode UTF-8 bytes, drop the combining marks (accents) exposed by NFKD
  # normalization, and re-encode (Python 2 string semantics).
  u = unicode(s, "utf-8")
  u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
  return u_new.encode("utf-8")


def correct_unicode(string):
  string = strip_accents(string)
  string = re.sub("\xc2\xa0", " ", string).strip()  # non-breaking space
  string = re.sub("\xe2\x80\x93", "-", string).strip()  # en dash
  #string = re.sub(ur'[\u0300-\u036F]', "", string)
  string = re.sub("‚", ",", string)
  string = re.sub("…", "...", string)
  #string = re.sub("[·・]", ".", string)
  string = re.sub("ˆ", "^", string)
  string = re.sub("˜", "~", string)
  string = re.sub("‹", "<", string)
  string = re.sub("›", ">", string)
  #string = re.sub("[‘’´`]", "'", string)
  #string = re.sub("[“”«»]", "\"", string)
  #string = re.sub("[•†‡]", "", string)
  #string = re.sub("[‐‑–—]", "-", string)
  # Strip remaining high code points. Note: in a Python 2 byte pattern, \u is
  # not an escape, so the range has to be applied to a unicode string.
  string = re.sub(u'[\u2E00-\uFFFF]', u"", string.decode("utf-8")).encode(
      "utf-8")
  string = re.sub("\\s+", " ", string).strip()
  return string
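
# For instance (a behavior sketch, not from the original file; assumes UTF-8
# byte-string input as elsewhere in this module):
#   correct_unicode("Déjà vu…")  ->  "Deja vu..."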


def simple_normalize(string):
  string = correct_unicode(string)
  # Citations, e.g. "[2]" or "[nb 1]"
  string = re.sub("\[(nb ?)?\d+\]", "", string)
  # Trailing asterisks (footnote markers)
  string = re.sub("\*+$", "", string)
  # Year (or year range) in parenthesis
  string = re.sub("\(\d* ?-? ?\d*\)", "", string)
  # Drop strings that are entirely quoted
  string = re.sub("^\"(.*)\"$", "", string)
  return string


def full_normalize(string):
  #print "an: ", string
  string = simple_normalize(string)
  # Remove trailing info in brackets
  string = re.sub("\[[^\]]*\]", "", string)
  # Remove most unicode characters in other languages. As above, the code
  # point range only works in a unicode pattern, not a Python 2 byte pattern.
  string = re.sub(u'[\u007F-\uFFFF]', u"",
                  string.strip().decode("utf-8")).encode("utf-8")
  # Remove trailing info in parenthesis
  string = re.sub("\([^)]*\)$", "", string.strip())
  string = final_normalize(string)
  # Get rid of question marks
  string = re.sub("\?", "", string).strip()
  # Get rid of trailing colons (usually occur in column titles)
  string = re.sub("\:$", " ", string).strip()
  # Replace slashes, colons, and dashes with spaces
  # Note: need better replacement for this when parsing time
  string = re.sub(r"/", " ", string).strip()
  string = re.sub(r"\\", " ", string).strip()
  string = re.sub(r"\:", " ", string).strip()
  string = re.sub("-", " ", string).strip()
  # Convert empty strings to UNK
  # Important to do this last or near last
  if not string:
    string = "UNK"
  return string
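
# End-to-end sketch of the chain above (illustrative, not from the original
# file):
#   full_normalize("Déjà Vu [1] (1999)?")  ->  "deja vu"
# The citation and the parenthesized year are dropped by simple_normalize;
# accents, case, and the question mark by the later passes.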


def final_normalize(string):
  # Collapse whitespace and trim
  string = re.sub("\\s+", " ", string).strip()
  # Convert entirely to lowercase
  string = string.lower()
  # Get rid of strangely escaped newline characters
  string = re.sub("\\\\n", " ", string).strip()
  # Get rid of quotation marks
  string = re.sub(r"\"", "", string).strip()
  string = re.sub(r"\'", "", string).strip()
  string = re.sub(r"`", "", string).strip()
  # Get rid of asterisks
  string = re.sub("\*", "", string).strip()
  return string


def is_number(x):
  try:
    f = float(x)
    return not is_nan_or_inf(f)
  except (ValueError, TypeError):
    return False
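
# For instance (illustrative only):
#   is_number("3.14")   ->  True
#   is_number("1e309")  ->  False  (overflows float to inf)
#   is_number("abc")    ->  False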


class WikiExample(object):

  def __init__(self, id, question, answer, table_key):
    self.question_id = id
    self.question = question
    self.answer = answer
    self.table_key = table_key
    self.lookup_matrix = []
    self.is_bad_example = False
    self.is_word_lookup = False
    self.is_ambiguous_word_lookup = False
    self.is_number_lookup = False
    self.is_number_calc = False
    self.is_unknown_answer = False


class TableInfo(object):

  def __init__(self, word_columns, word_column_names, word_column_indices,
               number_columns, number_column_names, number_column_indices,
               processed_word_columns, processed_number_columns, orig_columns):
    self.word_columns = word_columns
    self.word_column_names = word_column_names
    self.word_column_indices = word_column_indices
    self.number_columns = number_columns
    self.number_column_names = number_column_names
    self.number_column_indices = number_column_indices
    self.processed_word_columns = processed_word_columns
    self.processed_number_columns = processed_number_columns
    self.orig_columns = orig_columns


class WikiQuestionLoader(object):

  def __init__(self, data_name, root_folder):
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "data")
    self.examples = []
    self.data_name = data_name

  def num_questions(self):
    return len(self.examples)

  def load_qa(self):
    data_source = os.path.join(self.data_folder, self.data_name)
    f = tf.gfile.GFile(data_source, "r")
    id_regex = re.compile("\(id ([^\)]*)\)")
    for line in f:
      id_match = id_regex.search(line)
      id = id_match.group(1)
      self.examples.append(id)

  def load(self):
    self.load_qa()
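
# load_qa expects one LISP-style example per line, as in the WikiTableQuestions
# release (hypothetical line, for illustration):
#   (example (id nt-39) (utterance "who ...?") ...)
# from which id_regex extracts the question id "nt-39".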


def is_date(word):
  # Dates are expected in the normalized form YYYY-MM-DD, where unknown fields
  # are written as "x" or "X" (e.g. "2008-xx-xx").
  if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
    return False
  if (len(word) != 10):
    return False
  if (word[4] != "-"):
    return False
  if (word[7] != "-"):
    return False
  for i in range(len(word)):
    if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or re.search(
        "[0-9]", word[i]))):
      return False
  return True
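
# For instance (illustrative only):
#   is_date("2008-01-25")  ->  True
#   is_date("2008-xx-xx")  ->  True
#   is_date("January 25")  ->  False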


class WikiQuestionGenerator(object):

  def __init__(self, train_name, dev_name, test_name, root_folder):
    self.train_name = train_name
    self.dev_name = dev_name
    self.test_name = test_name
    self.train_loader = WikiQuestionLoader(train_name, root_folder)
    self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
    self.test_loader = WikiQuestionLoader(test_name, root_folder)
    self.bad_examples = 0
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "annotated/data")
    self.annotated_examples = {}
    self.annotated_tables = {}
    self.annotated_word_reject = {}
    self.annotated_word_reject["-lrb-"] = 1
    self.annotated_word_reject["-rrb-"] = 1
    self.annotated_word_reject["UNK"] = 1

  def is_money(self, word):
    # Money values are checked after the callers strip currency symbols, so a
    # valid value contains only digits, ".", and the exponent marker "E".
    if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
      return False
    for i in range(len(word)):
      if (not (word[i] == "E" or word[i] == "." or re.search("[0-9]",
                                                             word[i]))):
        return False
    return True
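
  # For instance (illustrative only; "$", "£", "€" are stripped by callers):
  #   is_money("200.0")   ->  True
  #   is_money("1.25E3")  ->  True
  #   is_money("$200")    ->  False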

  def remove_consecutive(self, ner_tags, ner_values):
    # When two consecutive tokens carry the same non-empty NUMBER / MONEY /
    # PERCENT / DATE value, rewrite the earlier duplicate so only the last
    # copy survives: "A" means "fall back to the raw token", "," means "drop
    # it" (see pre_process_sentence).
    for i in range(len(ner_tags)):
      if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
          i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
          ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          ner_values[i] = "A"
        else:
          ner_values[i] = ","
    return ner_tags, ner_values
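
  # For instance (illustrative only), the two tokens of "January 25" share one
  # DATE value:
  #   remove_consecutive(["DATE", "DATE"], ["2008-01-25", "2008-01-25"])
  #     ->  (["DATE", "DATE"], [",", "2008-01-25"])
  # so pre_process_sentence keeps a single "2008-01-25".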

  def pre_process_sentence(self, tokens, ner_tags, ner_values):
    sentence = []
    tokens = tokens.split("|")
    ner_tags = ner_tags.split("|")
    ner_values = ner_values.split("|")
    ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
    #print "old: ", tokens
    for i in range(len(tokens)):
      word = tokens[i]
      if (ner_values[i] != "" and
          (ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        # Values with uppercase letters that are neither dates nor money
        # (including the "A" marker from remove_consecutive) fall back to the
        # raw token.
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          word = tokens[i]
        if (is_number(ner_values[i])):
          word = float(ner_values[i])
        elif (is_number(word)):
          word = float(word)
        if (tokens[i] == "score"):
          word = "score"
      if (is_number(word)):
        word = float(word)
      if (not (self.annotated_word_reject.has_key(word))):
        if (is_number(word) or is_date(word) or self.is_money(word)):
          sentence.append(word)
        else:
          word = full_normalize(word)
          if (not (self.annotated_word_reject.has_key(word)) and
              bool(re.search("[a-z0-9]", word, re.IGNORECASE))):
            sentence.append(word.replace(",", ""))
    if (len(sentence) == 0):
      sentence.append("UNK")
    return sentence
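
  # For instance (illustrative only), with CoreNLP-style pipe-separated
  # annotations:
  #   pre_process_sentence("the|2008|games", "O|DATE|O", "|2008|")
  #     ->  ["the", 2008.0, "games"]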

  def load_annotated_data(self, in_file):
    self.annotated_examples = {}
    self.annotated_tables = {}
    f = tf.gfile.GFile(in_file, "r")
    counter = 0
    for line in f:
      if (counter > 0):
        line = line.strip()
        (question_id, utterance, context, target_value, tokens, lemma_tokens,
         pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
        question = self.pre_process_sentence(tokens, ner_tags, ner_values)
        target_canon = target_canon.split("|")
        self.annotated_examples[question_id] = WikiExample(
            question_id, question, target_canon, context)
        self.annotated_tables[context] = []
      counter += 1
    print("Annotated examples loaded ", len(self.annotated_examples))
    f.close()

  def is_number_column(self, a):
    # A column is numeric if every cell is a single entry that parses as a
    # number.
    for w in a:
      if (len(w) != 1):
        return False
      if (not (is_number(w[0]))):
        return False
    return True
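
  # For instance (illustrative only; cells are lists of entries):
  #   is_number_column([[200.0], [300.0]])     ->  True
  #   is_number_column([["london"], [300.0]])  ->  False
  #   is_number_column([["a", "b"], [300.0]])  ->  False  (multi-entry cell)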

  def convert_table(self, table):
    # Render every cell back into a single space-joined string.
    answer = []
    for i in range(len(table)):
      temp = []
      for j in range(len(table[i])):
        temp.append(" ".join([str(w) for w in table[i][j]]))
      answer.append(temp)
    return answer
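
  # For instance (illustrative only):
  #   convert_table([[["new", "york"], [2008.0]]])
  #     ->  [["new york", "2008.0"]]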

  def load_annotated_tables(self):
    for table in self.annotated_tables.keys():
      annotated_table = table.replace("csv", "annotated")
      orig_columns = []
      processed_columns = []
      # First pass: the row and column indices of the last line give the
      # table dimensions.
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags,
           ner_tags, ner_values, number, date, num2,
           read_list) = line.split("\t")
        counter += 1
      f.close()
      max_row = int(row)
      max_col = int(col)
      for i in range(max_col + 1):
        orig_columns.append([])
        processed_columns.append([])
        for j in range(max_row + 1):
          orig_columns[i].append(bad_number)
          processed_columns[i].append(bad_number)
      #print orig_columns
      # Second pass: fill in the column names (row -1) and the cell entries.
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      column_names = []
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags,
           ner_tags, ner_values, number, date, num2,
           read_list) = line.split("\t")
          entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
          if (row == "-1"):
            column_names.append(entry)
          else:
            orig_columns[int(col)][int(row)] = entry
            if (len(entry) == 1 and is_number(entry[0])):
              processed_columns[int(col)][int(row)] = float(entry[0])
            else:
              for single_entry in entry:
                if (is_number(single_entry)):
                  processed_columns[int(col)][int(row)] = float(single_entry)
                  break
              nt = ner_tags.split("|")
              nv = ner_values.split("|")
              for i_entry in range(len(tokens.split("|"))):
                if (nt[i_entry] == "DATE" and
                    is_number(nv[i_entry].replace("-", "").replace("X", ""))):
                  processed_columns[int(col)][int(row)] = float(
                      nv[i_entry].replace("-", "").replace("X", ""))
                  #processed_columns[int(col)][int(row)] = float(nv[i_entry])
            if (len(entry) == 1 and (is_number(entry[0]) or is_date(entry[0]) or
                                     self.is_money(entry[0]))):
              if (len(entry) == 1 and not (is_number(entry[0])) and
                  is_date(entry[0])):
                entry[0] = entry[0].replace("X", "x")
        counter += 1
      word_columns = []
      processed_word_columns = []
      word_column_names = []
      word_column_indices = []
      number_columns = []
      processed_number_columns = []
      number_column_names = []
      number_column_indices = []
      for i in range(max_col + 1):
        if (self.is_number_column(orig_columns[i])):
          number_column_indices.append(i)
          number_column_names.append(column_names[i])
          temp = []
          for w in orig_columns[i]:
            if (is_number(w[0])):
              temp.append(w[0])
          number_columns.append(temp)
          processed_number_columns.append(processed_columns[i])
        else:
          word_column_indices.append(i)
          word_column_names.append(column_names[i])
          word_columns.append(orig_columns[i])
          processed_word_columns.append(processed_columns[i])
      table_info = TableInfo(
          word_columns, word_column_names, word_column_indices, number_columns,
          number_column_names, number_column_indices, processed_word_columns,
          processed_number_columns, orig_columns)
      self.annotated_tables[table] = table_info
      f.close()
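
  # The .annotated table files are TSVs with a header line and one line per
  # cell; row -1 holds the column names. Hypothetical fragment, for
  # illustration:
  #   row  col  id   content  tokens  ...  nerValues  ...
  #   -1   0    ...  Year     Year    ...             ...
  #   0    0    ...  2008     2008    ...  2008.0     ...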

  def answer_classification(self):
    lookup_questions = 0
    number_lookup_questions = 0
    word_lookup_questions = 0
    ambiguous_lookup_questions = 0
    number_questions = 0
    bad_questions = 0
    ice_bad_questions = 0
    tot = 0
    got = 0
    ice = {}
    with tf.gfile.GFile(
        self.root_folder + "/arvind-with-norms-2.tsv", mode="r") as f:
      lines = f.readlines()
      for line in lines:
        line = line.strip()
        if (not (self.annotated_examples.has_key(line.split("\t")[0]))):
          continue
        if (len(line.split("\t")) == 4):
          line = line + "\t" * (5 - len(line.split("\t")))
        if (not (is_number(line.split("\t")[2]))):
          ice_bad_questions += 1
        (example_id, ans_index, ans_raw, process_answer,
         matched_cells) = line.split("\t")
        if (ice.has_key(example_id)):
          ice[example_id].append(line.split("\t"))
        else:
          ice[example_id] = [line.split("\t")]
    for q_id in self.annotated_examples.keys():
      tot += 1
      example = self.annotated_examples[q_id]
      table_info = self.annotated_tables[example.table_key]
      # Figure out if the answer is numerical or a table lookup
      n_cols = len(table_info.orig_columns)
      n_rows = len(table_info.orig_columns[0])
      example.lookup_matrix = np.zeros((n_rows, n_cols))
      exact_matches = {}
      for (example_id, ans_index, ans_raw, process_answer,
           matched_cells) in ice[q_id]:
        for match_cell in matched_cells.split("|"):
          if (len(match_cell.split(",")) == 2):
            (row, col) = match_cell.split(",")
            row = int(row)
            col = int(col)
            if (row >= 0):
              exact_matches[ans_index] = 1
      answer_is_in_table = len(exact_matches) == len(example.answer)
      if (answer_is_in_table):
        # Cell (row, col) stores 1 + the index of the matched answer; zero
        # means no match.
        for (example_id, ans_index, ans_raw, process_answer,
             matched_cells) in ice[q_id]:
          for match_cell in matched_cells.split("|"):
            if (len(match_cell.split(",")) == 2):
              (row, col) = match_cell.split(",")
              row = int(row)
              col = int(col)
              example.lookup_matrix[row, col] = float(ans_index) + 1.0
      example.lookup_number_answer = 0.0
      if (answer_is_in_table):
        lookup_questions += 1
        if len(example.answer) == 1 and is_number(example.answer[0]):
          example.number_answer = float(example.answer[0])
          number_lookup_questions += 1
          example.is_number_lookup = True
        else:
          #print "word lookup"
          example.calc_answer = example.number_answer = 0.0
          word_lookup_questions += 1
          example.is_word_lookup = True
      else:
        if (len(example.answer) == 1 and is_number(example.answer[0])):
          example.number_answer = example.answer[0]
          example.is_number_calc = True
        else:
          bad_questions += 1
          example.is_bad_example = True
          example.is_unknown_answer = True
      example.is_lookup = example.is_word_lookup or example.is_number_lookup
      if not example.is_word_lookup and not example.is_bad_example:
        number_questions += 1
        example.calc_answer = example.answer[0]
        example.lookup_number_answer = example.calc_answer
      # Split up the lookup matrix into a word part and a number part
      number_column_indices = table_info.number_column_indices
      word_column_indices = table_info.word_column_indices
      example.word_columns = table_info.word_columns
      example.number_columns = table_info.number_columns
      example.word_column_names = table_info.word_column_names
      example.processed_number_columns = table_info.processed_number_columns
      example.processed_word_columns = table_info.processed_word_columns
      example.number_column_names = table_info.number_column_names
      example.number_lookup_matrix = example.lookup_matrix[
          :, number_column_indices]
      example.word_lookup_matrix = example.lookup_matrix[:, word_column_indices]
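
  # Sketch of the resulting matrices (illustrative only): for a 2x3 table
  # where answer 0 matched cell (1, 2),
  #   example.lookup_matrix == [[0., 0., 0.],
  #                             [0., 0., 1.]]
  # and number_lookup_matrix / word_lookup_matrix are its column slices for
  # the number and word columns respectively.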

  def load(self):
    train_data = []
    dev_data = []
    test_data = []
    self.load_annotated_data(
        os.path.join(self.data_folder, "training.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.train_loader.load()
    self.dev_loader.load()
    for i in range(self.train_loader.num_questions()):
      example = self.train_loader.examples[i]
      train_data.append(self.annotated_examples[example])
    for i in range(self.dev_loader.num_questions()):
      example = self.dev_loader.examples[i]
      dev_data.append(self.annotated_examples[example])
    self.load_annotated_data(
        os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.test_loader.load()
    for i in range(self.test_loader.num_questions()):
      example = self.test_loader.examples[i]
      test_data.append(self.annotated_examples[example])
    return train_data, dev_data, test_data
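

# Usage sketch (not part of the original module): the split file names follow
# the public WikiTableQuestions release, and the root folder is hypothetical.
if __name__ == "__main__":
  generator = WikiQuestionGenerator("random-split-1-train.examples",
                                    "random-split-1-dev.examples",
                                    "pristine-unseen-tables.examples",
                                    "/path/to/wikiquestions")
  train_data, dev_data, test_data = generator.load()
  print("train/dev/test sizes:", len(train_data), len(dev_data),
        len(test_data))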