import os
import random

from generation.prompt import OpenAIQAPromptBuilder
from generation.generator import Generator
from retrieval.retriever import OpenAIQARetriever
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem

num_parallel_prompts = 10
num_qa_shots = 8  # Number of few-shot examples retrieved for each QA prompt.
infinite_rows_len = 50  # If the table contains more rows than this, it is processed in chunks of this many rows.
max_tokens = 1024

ROOT_DIR = os.path.join(os.path.dirname(__file__), "../../")


class OpenAIQAModel(object):
    def __init__(self, args, keys=None):
        super().__init__()

        # Prepare keys (the list is shuffled in place).
        self.key_current_id = 0
        self.keys = keys
        random.seed(42)
        random.shuffle(self.keys)

        retrieve_pool = OpenAIQARetrievePool(
            data_path=os.path.join(ROOT_DIR, args.qa_retrieve_pool_file)
        )
        self.retriever = OpenAIQARetriever(retrieve_pool)
        self.generator = Generator(args=None, keys=self.keys)  # Only used for its API-calling function.

        self.prompting_method = 'new_db'
        self.answer_split_token: str = ';'
        self.db_mapping_token = "\t"
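
    # `keys` is expected to be a list of OpenAI API keys (random.shuffle above would fail on None).
    # A minimal construction sketch, assuming `args` is any object (e.g. an argparse.Namespace)
    # exposing a `qa_retrieve_pool_file` attribute; the key below is a placeholder:
    #   model = OpenAIQAModel(args=my_args, keys=["YOUR_OPENAI_API_KEY"])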

    def call_openai_api_completion(self, prompt):
        completion = self.generator._call_codex_api(engine="text-davinci-002",
                                                    prompt=prompt,
                                                    max_tokens=max_tokens,
                                                    temperature=0,
                                                    top_p=1,
                                                    n=1,
                                                    stop=["\n\n"])
        return completion

    def call_openai_for_completion_text(self, prompt, openai_usage_type="completion"):
        if openai_usage_type == "completion":
            completion = self.call_openai_api_completion(prompt)
            return completion.choices[0].text
        else:
            raise ValueError("The model usage type '{}' doesn't exist!".format(openai_usage_type))
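
    # The code above assumes `_call_codex_api` returns an object shaped like the legacy OpenAI
    # Completions response, i.e. the generated text lives at `completion.choices[0].text`;
    # `_call_codex_api` itself is defined in generation.generator, not in this file.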

    @staticmethod
    def merge_tables(tables, by='row_id'):
        assert len(set([len(_table['rows']) for _table in tables])) == 1, \
            "Tables must have the same number of rows!"
        merged_header = [by]
        by_idx = tables[0]['header'].index(by)
        merged_rows = [[_row[by_idx]] for _row in tables[0]['rows']]

        for _table in tables:
            header, rows = _table['header'], _table['rows']
            for col_idx, col in enumerate(header):
                if col == by:
                    continue
                if col in merged_header:
                    # When the column name is a duplicate, append a postfix _1, _2, etc.
                    col = "{}_{}".format(col, merged_header.count(col))
                merged_header.append(col)
                for i, row in enumerate(rows):
                    merged_rows[i].append(row[col_idx])

        return {"header": merged_header, "rows": merged_rows}
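
    # Worked example for merge_tables (made-up data): given
    #   [{"header": ["row_id", "city"],       "rows": [[0, "Paris"], [1, "Tokyo"]]},
    #    {"header": ["row_id", "population"], "rows": [[0, "2.1m"],  [1, "14m"]]}]
    # the result is
    #   {"header": ["row_id", "city", "population"],
    #    "rows":   [[0, "Paris", "2.1m"], [1, "Tokyo", "14m"]]}
    # and a repeated non-key column name would be renamed with a numeric postfix (e.g. "city_1").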

    def wrap_with_prompt_for_table_qa(self,
                                      question,
                                      sub_table,
                                      table_title=None,
                                      answer_split_token=None,
                                      qa_type="ans",
                                      prompting_method="new_db",
                                      db_mapping_token="π ",
                                      verbose=True):
        prompt = "Question Answering Over Database:\n\n"
        if qa_type in ['map', 'ans'] and num_qa_shots > 0:
            query_item = QAItem(qa_question=question, table=sub_table, title=table_title)
            retrieved_items = self.retriever.retrieve(item=query_item, num_shots=num_qa_shots, qa_type=qa_type)
            few_shot_prompt_list = []
            for item in retrieved_items:
                one_shot_prompt = OpenAIQAPromptBuilder.build_one_shot_prompt(
                    item=item,
                    answer_split_token=answer_split_token,
                    verbose=verbose,
                    prompting_method=prompting_method,
                    db_mapping_token=db_mapping_token
                )
                few_shot_prompt_list.append(one_shot_prompt)
            few_shot_prompt = '\n'.join(few_shot_prompt_list[:num_qa_shots])
            prompt = few_shot_prompt

        prompt += "\nGive a database as shown below:\n{}\n\n".format(
            OpenAIQAPromptBuilder.table2codex_prompt(sub_table, table_title)
        )

        if qa_type == "map":
            prompt += "Q: Answer question \"{}\" row by row.".format(question)
            assert answer_split_token is not None
            if prompting_method == "basic":
                prompt += " The answer should be a list split by '{}' and have {} items in total.".format(
                    answer_split_token, len(sub_table['rows']))
        elif qa_type == "ans":
            prompt += "Q: Answer question \"{}\" for the table.".format(question)
            prompt += " "
        else:
            raise ValueError("The QA type is not supported!")

        prompt += "\n"

        if qa_type == "map":
            if prompting_method == "basic":
                prompt += "A:"
        elif qa_type == "ans":
            prompt += "A:"

        return prompt
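
    # Rough shape of the prompt this method builds (illustrative placeholders in angle brackets):
    #   <few-shot QA examples retrieved from the pool, joined by newlines>
    #   Give a database as shown below:
    #   <sub_table rendered by OpenAIQAPromptBuilder.table2codex_prompt>
    #
    #   Q: Answer question "<question>" row by row.    (qa_type == "map")
    #   Q: Answer question "<question>" for the table. (qa_type == "ans", followed by "A:")
    # With prompting_method == "basic", a list-format instruction and "A:" are also appended for "map".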

    def qa(self, question, sub_tables, qa_type: str, verbose: bool = True, **args):
        # If it is not a problem the API can handle, answer it with the QA model.
        merged_table = OpenAIQAModel.merge_tables(sub_tables)
        if verbose:
            print("Make Question {} on {}".format(question, merged_table))
        if qa_type == "map":
            # Map: col(s) --question--> one col
            # Ask the model to run QA over a sub-table:
            # col(s) -> one col, with all rows answered in one call.
            def do_map(_table):
                _prompt = self.wrap_with_prompt_for_table_qa(question,
                                                             _table,
                                                             args['table_title'],
                                                             self.answer_split_token,
                                                             qa_type,
                                                             prompting_method=self.prompting_method,
                                                             db_mapping_token=self.db_mapping_token,
                                                             verbose=verbose)
                completion_str = self.call_openai_for_completion_text(_prompt).lower().strip(' []')

                if verbose:
                    print(f'QA map@ input:\n{_prompt}')
                    print(f'QA map@ output:\n{completion_str}')

                if self.prompting_method == "basic":
                    answers = [_answer.strip(" '").lower() for _answer in
                               completion_str.split(self.answer_split_token)]
                elif self.prompting_method == "new_db":
                    # Drop the first two lines and the last line of the completion, keeping
                    # the text after the mapping token on each remaining line.
                    answers = [line.split(self.db_mapping_token)[-1] for line in completion_str.split("\n")[2:-1]]
                else:
                    raise ValueError("No such prompting method: '{}'!".format(self.prompting_method))
                return answers

            # Handle overly long tables chunk by chunk.
            answers = []
            rows_len = len(merged_table['rows'])
            run_times = int(rows_len / infinite_rows_len) if rows_len % infinite_rows_len == 0 else int(
                rows_len / infinite_rows_len) + 1
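            # For example (illustrative arithmetic): with rows_len = 120 and infinite_rows_len = 50,
            # run_times = 3 and the chunks below are rows[0:50], rows[50:100] and rows[100:].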

            for run_idx in range(run_times):
                _table = {
                    "header": merged_table['header'],
                    "rows": merged_table['rows'][run_idx * infinite_rows_len:]
                } if run_idx == run_times - 1 else \
                    {
                        "header": merged_table['header'],
                        "rows": merged_table['rows'][run_idx * infinite_rows_len:(run_idx + 1) * infinite_rows_len]
                    }

                answers.extend(do_map(_table))
            if verbose:
                print("The map@ openai answers are {}".format(answers))
            # Also return row_id so each answer can be aligned with its corresponding row.
            return {"header": ['row_id'] + args['new_col_name_s'],
                    "rows": [[row[0], answer] for row, answer in zip(merged_table['rows'], answers)]}
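            # e.g. with new_col_name_s == ["in_europe"], the returned value looks like
            #   {"header": ["row_id", "in_europe"], "rows": [[0, "yes"], [1, "no"], ...]}
            # (illustrative values only).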
        elif qa_type == "ans":
            # Ans: col(s) --question--> answer
            prompt = self.wrap_with_prompt_for_table_qa(question,
                                                        merged_table,
                                                        args['table_title'],
                                                        prompting_method=self.prompting_method,
                                                        verbose=verbose)
            answers = [self.call_openai_for_completion_text(prompt).lower().strip(' []')]

            if verbose:
                print(f'QA ans@ input:\n{prompt}')
                print(f'QA ans@ output:\n{answers}')
            return answers
        else:
            raise ValueError("Please choose either 'map' or 'ans' as the qa type!")
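

# A minimal usage sketch, assuming a hypothetical argparse.Namespace carrying
# `qa_retrieve_pool_file` (the pool path below is a placeholder) and a placeholder API key.
# The sub-tables must share a 'row_id' column and the same number of rows, as merge_tables requires.
if __name__ == "__main__":
    from argparse import Namespace

    qa_args = Namespace(qa_retrieve_pool_file="templates/qa_retrieve_pool.json")  # hypothetical path
    model = OpenAIQAModel(args=qa_args, keys=["YOUR_OPENAI_API_KEY"])  # placeholder key

    sub_tables = [
        {"header": ["row_id", "city"], "rows": [[0, "Paris"], [1, "Tokyo"]]},
        {"header": ["row_id", "population"], "rows": [[0, "2.1m"], [1, "14m"]]},
    ]
    # Map each row to a new column answering the question.
    new_col = model.qa(question="Is the city in Europe?",
                       sub_tables=sub_tables,
                       qa_type="map",
                       table_title="Cities",
                       new_col_name_s=["in_europe"])
    print(new_col)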