Spaces:
Runtime error
Runtime error
| import json | |
| from typing import List, Dict | |
| from nsql.qa_module.openai_qa import OpenAIQAModel | |
| from nsql.qa_module.vqa import vqa_call | |
| from nsql.database import NeuralDB | |
| from nsql.parser import get_cfg_tree, get_steps, remove_duplicate, TreeNode, parse_question_paras, nsql_role_recognize, \ | |
| extract_answers | |
class NSQLExecutor(object):
    """Run NSQL programs step by step against a database object.

    Steps whose text starts with ``QA(`` are answered by ``self.qa_model``;
    any other step is executed as plain SQL through the database. Intermediate
    inputs and results of every step are dumped as JSON under ``tmp_for_vis/``.
    """

    def __init__(self, args, keys=None):
        """Create the executor and its QA model.

        Args:
            args: Passed through unchanged to ``OpenAIQAModel``.
            keys: Passed through unchanged to ``OpenAIQAModel``.
        """
        # Counter backing generate_new_col_names(); next unused column id.
        self.new_col_name_id = 0
        self.qa_model = OpenAIQAModel(args, keys)

    def generate_new_col_names(self, number):
        """Return ``number`` fresh names ``col_<id>`` and advance the counter.

        Args:
            number: How many names to generate.

        Returns:
            A list of ``number`` strings, numbered consecutively from the
            current value of ``self.new_col_name_id``.
        """
        first_id = self.new_col_name_id
        self.new_col_name_id = first_id + number
        return ["col_{}".format(i) for i in range(first_id, first_id + number)]

    def sql_exec(self, sql: str, db: "NeuralDB", verbose=True):
        """Execute ``sql`` on ``db`` and return the result of ``db.execute_query``.

        Args:
            sql: The SQL text to run.
            db: Object exposing ``execute_query(sql)``.
            verbose: When true, print the statement before executing it.
        """
        if verbose:
            print("Exec SQL '{}' with additional row_id on {}".format(sql, db))
        return db.execute_query(sql)

    @staticmethod
    def _dump_vis(file_name, payload):
        """Write ``payload`` as JSON to ``tmp_for_vis/<file_name>``."""
        import os  # local import: the only stdlib name this class adds

        # The dump directory is not guaranteed to exist (e.g. fresh checkout);
        # without this every open() below raises FileNotFoundError.
        os.makedirs("tmp_for_vis", exist_ok=True)
        with open("tmp_for_vis/{}".format(file_name), "w") as f:
            json.dump(payload, f)

    def _collect_sub_tables(self, sql_s, db: "NeuralDB", verbose):
        """Turn every QA parameter in ``sql_s`` into a ``header``/``rows`` table.

        The role reported by ``nsql_role_recognize`` decides how a parameter is
        materialised; parameters with any other role are skipped.
        """
        sub_tables = []
        for sql_item in sql_s:
            role, sql_item = nsql_role_recognize(sql_item,
                                                 db.get_header(),
                                                 db.get_passages_titles(),
                                                 db.get_images_titles())
            if role in ('col', 'complete_sql'):
                sub_tables.append(self.sql_exec(sql_item, db, verbose=verbose))
                continue

            if role == 'val':
                # SECURITY(review): eval() runs text taken from the NSQL
                # program. If NSQL can come from an untrusted source (e.g.
                # generated text), this executes arbitrary code. Consider
                # ast.literal_eval once it is confirmed values are literals.
                header, cell = ["row_id", "val"], eval(sql_item)
            elif role == 'passage_title_and_image_title':
                header = ["row_id", "{}".format(sql_item)]
                cell = db.get_passage_by_title(sql_item) + db.get_image_caption_by_title(sql_item)
            elif role == 'passage_title':
                header = ["row_id", "{}".format(sql_item)]
                cell = db.get_passage_by_title(sql_item)
            elif role == 'image_title':
                header = ["row_id", "{}".format(sql_item)]
                cell = db.get_image_caption_by_title(sql_item)
            else:
                continue
            sub_tables.append({"header": header, "rows": [["0", cell]]})
        return sub_tables

    @staticmethod
    def _append_linked_info(sub_tables, db: "NeuralDB"):
        """Append linked passage text / image captions to matching cells, in place.

        A cell that is a key of the passage linker gets `` (<passage text>)``
        appended; a cell that is a key of the image linker gets
        `` (<image caption>)`` appended. Both lookups use the cell's original
        value, so a cell present in both linkers receives both suffixes.
        """
        passage_linker = db.get_passage_linker()
        image_linker = db.get_image_linker()
        for table in sub_tables:
            for row in table['rows']:
                for j, cell in enumerate(row):
                    if cell in passage_linker:
                        # Add passage text as backup info
                        row[j] += " ({})".format(db.get_passage_by_title(passage_linker[cell]))
                    if cell in image_linker:
                        # Add image caption as backup info
                        row[j] += " ({})".format(db.get_image_caption_by_title(image_linker[cell]))

    def nsql_exec(self, stamp, nsql: str, db: "NeuralDB", verbose=True):
        """Execute an NSQL program and return its answer.

        The program is parsed into a tree, flattened into de-duplicated steps,
        and the steps are run in order. A ``QA(`` step whose question starts
        with ``map@`` produces a sub-table (added to ``db`` when the step has a
        father, returned via ``extract_answers`` otherwise); one starting with
        ``ans@`` produces an answer (substituted into the father step, or
        returned when there is no father). Any non-``QA(`` step is run as SQL
        and its extracted answers are returned immediately.

        Args:
            stamp: Prefix used in the names of the files written to
                ``tmp_for_vis/``.
            nsql: The NSQL program text.
            db: Database object the steps are executed against.
            verbose: Forwarded to SQL execution, the QA model and ``db``; also
                enables printing of the step list.

        Returns:
            The answer of the first terminating step, or ``None`` if no step
            terminates the program.

        Raises:
            ValueError: If a QA question starts with neither ``map@`` nor
                ``ans@`` (case-insensitive).
        """
        steps = []
        root_node = get_cfg_tree(nsql)  # Parse execution tree from nsql.
        get_steps(root_node, steps)  # Flatten the execution tree and get the steps.
        steps = remove_duplicate(steps)  # Remove the duplicate steps.
        if verbose:
            print("Steps:", [s.rename for s in steps])
        self._dump_vis("{}_tmp_for_vis_steps.txt".format(stamp), [s.rename for s in steps])

        col_idx = 0
        # enumerate() gives the step's position directly; looking it up with
        # steps.index(step) is an O(n) scan that relies on TreeNode equality
        # while steps are being renamed in this very loop.
        for step_idx, step in enumerate(steps):
            # All steps should be formatted as 'QA()' except for last step which could also be normal SQL.
            assert isinstance(step, TreeNode), "step must be treenode"
            step_nsql = step.rename
            input_file = "{}_result_step_{}_input.txt".format(stamp, step_idx)
            result_file = "{}_result_step_{}.txt".format(stamp, step_idx)

            if not step_nsql.startswith('QA('):
                sub_table = self.sql_exec(step_nsql, db, verbose=verbose)
                self._dump_vis(result_file, sub_table)
                return extract_answers(sub_table)

            question, sql_s = parse_question_paras(step_nsql, self.qa_model)
            # Execute all SQLs and get the results as parameters
            sql_executed_sub_tables = self._collect_sub_tables(sql_s, db, verbose)
            # If the sub_tables to execute with link, append it to the cell.
            self._append_linked_info(sql_executed_sub_tables, db)

            question_kind = question.lower()
            if question_kind.startswith("map@"):
                # When the question is a type of mapping, we return the mapped column.
                question = question[len("map@"):]
                is_final_step = not step.father
                if is_final_step:
                    new_col_name_s = ["col_{}".format(col_idx)]
                else:
                    # Rename first: the produced column names are read afterwards.
                    step.rename_father_col(col_idx=col_idx)
                    new_col_name_s = step.produced_col_name_s
                sub_table = self.qa_model.qa(question,
                                             sql_executed_sub_tables,
                                             table_title=db.table_title,
                                             qa_type="map",
                                             new_col_name_s=new_col_name_s,
                                             verbose=verbose)
                self._dump_vis(input_file, sql_executed_sub_tables)
                self._dump_vis(result_file, sub_table)
                if is_final_step:
                    return extract_answers(sub_table)
                db.add_sub_table(sub_table, verbose=verbose)
                col_idx += 1
            elif question_kind.startswith("ans@"):
                # When the question is a type of answering, we return an answer list.
                question = question[len("ans@"):]
                answer = self.qa_model.qa(question,
                                          sql_executed_sub_tables,
                                          table_title=db.table_title,
                                          qa_type="ans",
                                          verbose=verbose)
                self._dump_vis(input_file, sql_executed_sub_tables)
                self._dump_vis(result_file, answer)
                if not step.father:  # This step is the final step
                    return answer
                step.rename_father_val(answer)
            else:
                raise ValueError(
                    "Except for operators or NL question must start with 'map@' or 'ans@'!, check '{}'".format(
                        question))