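# Utility functions for reading SQLite database schemas, matching question
# n-grams against indexed database contents, and checking whether generated
# SQL is executable.
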
import os
import json
import sqlite3

from func_timeout import func_set_timeout, FunctionTimedOut
from utils.bridge_content_encoder import get_matched_entries
from nltk.tokenize import word_tokenize
from nltk import ngrams
from whoosh.qparser import QueryParser

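# log an incoming question; assumes data/history/history.sqlite already
# contains a table `record` with (question, db_id) columns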
def add_a_record(question, db_id):
    conn = sqlite3.connect('data/history/history.sqlite')
    cursor = conn.cursor()
    cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))
    conn.commit()
    conn.close()

def obtain_n_grams(sequence, max_n):
    import nltk
    # make sure the tokenizer models are available; quiet=True suppresses
    # repeated download messages on subsequent calls
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)

    tokens = word_tokenize(sequence)
    all_grams = []
    for n in range(1, max_n + 1):
        all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])

    return all_grams

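# e.g., obtain_n_grams("list all singers", 2) -> ["list", "all", "singers", "list all", "all singers"]
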
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
    try:
        if not os.path.exists(sqlite_path):
            print("Opening a new connection %s" % sqlite_path)
        connection = sqlite3.connect(sqlite_path, check_same_thread=False)
    except Exception as e:
        print(sqlite_path)
        raise e
    # decode stored bytes leniently so malformed text does not crash queries
    connection.text_factory = lambda b: b.decode(errors="ignore")
    cursor = connection.cursor()
    return cursor

# execute predicted sql with a time limitation
# (the concrete timeout values below are assumptions; on expiry func_timeout
# raises FunctionTimedOut, which is caught in check_sql_executability)
@func_set_timeout(120)
def execute_sql(cursor, sql):
    cursor.execute(sql)
    return cursor.fetchall()

# execute predicted sql with a long time limitation (for building content index)
@func_set_timeout(2000)
def execute_sql_long_time_limitation(cursor, sql):
    cursor.execute(sql)
    return cursor.fetchall()

def check_sql_executability(generated_sql, db):
    if generated_sql.strip() == "":
        return "Error: empty string"
    try:
        cursor = get_cursor_from_path(db)
        execute_sql(cursor, generated_sql)
        execution_error = None
    except FunctionTimedOut as fto:
        print("SQL execution time out error: {}.".format(fto))
        execution_error = "SQL execution times out."
    except Exception as e:
        print("SQL execution runtime error: {}.".format(e))
        execution_error = str(e)
    return execution_error

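# Usage sketch (hypothetical path): check_sql_executability("SELECT 1;", "path/to/db.sqlite")
# returns None on success, otherwise an error message string.
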
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def detect_special_char(name):
    for special_char in ['(', '-', ')', ' ', '/']:
        if special_char in name:
            return True
    return False

def add_quotation_mark(s):
    return "`" + s + "`"

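# e.g., add_quotation_mark("order id") -> "`order id`", keeping names with
# special characters valid in the serialized schema
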
def get_column_contents(column_name, table_name, cursor):
    select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name)
    results = execute_sql_long_time_limitation(cursor, select_column_sql)
    column_contents = [str(result[0]).strip() for result in results]
    # remove empty and extremely-long contents
    column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25]

    return column_contents

def get_matched_contents(question, searcher):
    # coarse-grained matching between the input text and all contents in the database
    grams = obtain_n_grams(question, 4)
    hits = []

    # parse each n-gram into a valid Whoosh query object;
    # "content" must match the indexed field name in the searcher's schema
    query_parser = QueryParser("content", schema=searcher.schema)
    for query in grams:
        parsed_query = query_parser.parse(query)
        hits.extend(searcher.search(parsed_query, limit=10))

    coarse_matched_contents = dict()
    for hit in hits:
        matched_result = json.loads(hit.raw)
        # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
        tc_name = ".".join(matched_result["id"].split("-**-")[:2])
        if tc_name in coarse_matched_contents:
            if matched_result["contents"] not in coarse_matched_contents[tc_name]:
                coarse_matched_contents[tc_name].append(matched_result["contents"])
        else:
            coarse_matched_contents[tc_name] = [matched_result["contents"]]

    fine_matched_contents = dict()
    for tc_name, contents in coarse_matched_contents.items():
        # fine-grained matching between the question and the coarse matched contents
        fm_contents = get_matched_entries(question, contents)
        if fm_contents is None:
            continue
        for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
            if match_score < 0.9:
                continue
            if tc_name in fine_matched_contents:
                if len(fine_matched_contents[tc_name]) < 25:
                    fine_matched_contents[tc_name].append(field_value.strip())
            else:
                fine_matched_contents[tc_name] = [field_value.strip()]

    return fine_matched_contents

def get_db_schema_sequence(schema):
    schema_sequence = "database schema :\n"
    for table in schema["schema_items"]:
        table_name, table_comment = table["table_name"], table["table_comment"]
        if detect_special_char(table_name):
            table_name = add_quotation_mark(table_name)

        # if table_comment != "":
        #     table_name += " ( comment : " + table_comment + " )"

        column_info_list = []
        for column_name, column_type, column_comment, column_content, pk_indicator in \
            zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]):
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            additional_column_info = []
            # column type
            additional_column_info.append(column_type)
            # pk indicator
            if pk_indicator != 0:
                additional_column_info.append("primary key")
            # column comment
            if column_comment != "":
                additional_column_info.append("comment : " + column_comment)
            # representative column values
            if len(column_content) != 0:
                additional_column_info.append("values : " + " , ".join(column_content))
            column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")

        schema_sequence += "table " + table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"

    if len(schema["foreign_keys"]) != 0:
        schema_sequence += "foreign keys :\n"
        for foreign_key in schema["foreign_keys"]:
            for i in range(len(foreign_key)):
                if detect_special_char(foreign_key[i]):
                    foreign_key[i] = add_quotation_mark(foreign_key[i])
            schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
    else:
        schema_sequence += "foreign keys : None\n"

    return schema_sequence.strip()

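# Illustrative output for a hypothetical database, following the
# string-building rules above:
# database schema :
# table singer , columns = [ singer.id ( integer | primary key ) , singer.name ( text | values : Adele , Bono ) ]
# foreign keys :
# singer.country_id = country.id
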
def get_matched_content_sequence(matched_contents):
    content_sequence = ""
    if len(matched_contents) != 0:
        content_sequence += "matched contents :\n"
        for tc_name, contents in matched_contents.items():
            table_name = tc_name.split(".")[0]
            column_name = tc_name.split(".")[1]
            if detect_special_char(table_name):
                table_name = add_quotation_mark(table_name)
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
    else:
        content_sequence = "matched contents : None"

    return content_sequence.strip()

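# Illustrative output (hypothetical value match):
# matched contents :
# singer.name ( Adele )
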
def get_db_schema(db_path, db_comments, db_id):
    if db_id in db_comments:
        db_comment = db_comments[db_id]
    else:
        db_comment = None

    cursor = get_cursor_from_path(db_path)

    # obtain table names
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0].lower() for result in results]

    schema = dict()
    schema["schema_items"] = []
    foreign_keys = []
    # for each table
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        # obtain column names in the current table
        results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
        column_names_in_one_table = [result[0].lower() for result in results]
        column_types_in_one_table = [result[1].lower() for result in results]
        pk_indicators_in_one_table = [result[2] for result in results]

        column_contents = []
        for column_name in column_names_in_one_table:
            column_contents.append(get_column_contents(column_name, table_name, cursor))

        # obtain foreign keys in the current table
        results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
        for result in results:
            if None not in [result[3], result[2], result[4]]:
                foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])

        # obtain comments for each schema item
        if db_comment is not None:
            if table_name in db_comment:  # record comments for tables and columns
                table_comment = db_comment[table_name]["table_comment"]
                column_comments = [db_comment[table_name]["column_comments"][column_name] \
                    if column_name in db_comment[table_name]["column_comments"] else "" \
                    for column_name in column_names_in_one_table]
            else:  # the current database has comment information, but the current table does not
                table_comment = ""
                column_comments = ["" for _ in column_names_in_one_table]
        else:  # the current database has no comment information
            table_comment = ""
            column_comments = ["" for _ in column_names_in_one_table]

        schema["schema_items"].append({
            "table_name": table_name,
            "table_comment": table_comment,
            "column_names": column_names_in_one_table,
            "column_types": column_types_in_one_table,
            "column_comments": column_comments,
            "column_contents": column_contents,
            "pk_indicators": pk_indicators_in_one_table
        })

    schema["foreign_keys"] = foreign_keys

    return schema

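# Minimal usage sketch; the SQLite path and db_id are hypothetical placeholders.
if __name__ == "__main__":
    example_schema = get_db_schema("data/example/example.sqlite", db_comments=dict(), db_id="example")
    print(get_db_schema_sequence(example_schema))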