import argparse
import json

# `process.extractOne` is used for fuzzy table matching in `parse_col` below;
# the import is assumed to come from fuzzywuzzy (its fork `thefuzz` exposes
# the same interface), since the original file does not show it.
from fuzzywuzzy import process

from utils.sql.process_sql import (
    tokenize, CLAUSE_KEYWORDS, WHERE_OPS, COND_OPS, UNIT_OPS, AGG_OPS,
    JOIN_KEYWORDS, ORDER_OPS, skip_semicolon, SQL_OPS)

KEPT_WHERE_OP = ('not', 'in', 'exists')


def parse_table_unit(toks, start_idx, tables_with_alias):
    """Parse one table reference, recording any alias; returns (next idx, table name)."""
    idx = start_idx
    len_ = len(toks)
    key = toks[idx]
    if idx + 1 < len_ and toks[idx + 1] == "as":
        tables_with_alias[toks[idx + 2]] = toks[idx]
        idx += 3
    else:
        idx += 1
    return idx, key
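

def _demo_parse_table_unit():
    # Minimal sketch (made-up tokens): shows how an "as" alias is recorded in
    # the alias map and how the index advances past "table as alias".
    aliases = {}
    idx, name = parse_table_unit(['players', 'as', 't1', 'where'], 0, aliases)
    print(idx, name, aliases)  # 3 players {'t1': 'players'}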


def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx (the column token is normalized to "table.column" in
        place, and the schema dict is updated as a side effect)
    """
    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1

    if '.' in tok:  # token is a composite like "t1.col"
        alias, col = tok.split('.')
        table = tables_with_alias[alias]
        # Add the column to the schema being built up.
        if table not in schema:
            schema[table] = []
        schema[table].append(col)
        # Also normalize the token to use the resolved table name.
        toks[start_idx] = "{}.{}".format(table, col)
        return start_idx + 1

    assert default_tables is not None and len(default_tables) > 0, \
        "Default tables should not be None or empty"

    def choose_best_table(default_tables, tok):
        # Fuzzy-match the bare column token against the candidate table names.
        # Note: candidates are lowercased, so the returned name is lowercase.
        lower_tok = tok.lower()
        candidate = process.extractOne(
            lower_tok, [table.lower() for table in default_tables])[0]
        return candidate

    if len(default_tables) != 1:
        # Several tables are in scope; pick the most plausible one.
        table = choose_best_table(default_tables, tok)
    else:
        table = default_tables[0]
    if table not in schema:
        schema[table] = []
    schema[table].append(tok)
    toks[start_idx] = "{}.{}".format(table, tok)
    return start_idx + 1
    # Retained from the original Spider-style parser:
    # for alias in default_tables:
    #     table = tables_with_alias[alias]
    #     if tok in schema.schema[table]:
    #         key = table + "." + tok
    #         return start_idx + 1, schema.idMap[key]
    # assert False, "Error col: {}".format(tok)
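

def _demo_fuzzy_table_choice():
    # Illustrative sketch of the fuzzy matching used by the nested
    # choose_best_table helper in parse_col: a bare column token is matched
    # against the tables in scope (names below are made up).
    print(process.extractOne('playerid', ['players', 'teams'])[0])  # players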


def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None, end_idx=None):
    """
    :returns next idx (aggregation/distinct markers are parsed but only the
        index is returned; the schema dict is updated as a side effect)
    """
    idx = start_idx
    # Bound the scan by end_idx when given; end_idx is an absolute index into
    # toks, so it is used directly as the upper bound.
    if end_idx is not None:
        len_ = end_idx
    else:
        len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        assert idx < len_ and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < len_ and toks[idx] == ')'
        idx += 1
        return idx

    if toks[idx] == "distinct":
        idx += 1
        isDistinct = True
    agg_id = AGG_OPS.index("none")
    idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    return idx


def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1
    col_unit1 = None
    col_unit2 = None
    unit_op = UNIT_OPS.index('none')
    idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if idx < len_ and toks[idx] in UNIT_OPS:
        unit_op = UNIT_OPS.index(toks[idx])
        idx += 1
        idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    return idx


def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx = parse_sql(toks, idx, schema)
    elif "\"" in toks[idx]:  # token is a string value
        # Replace the literal with a placeholder.
        toks[idx] = "_str_value_"
        idx += 1
    else:
        try:
            float(toks[idx])  # numeric literal
            toks[idx] = "_num_value_"
            idx += 1
        except ValueError:
            # Not a literal: treat the span up to the next delimiter as a
            # column unit.
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')' \
                    and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS \
                    and toks[end_idx] not in JOIN_KEYWORDS:
                end_idx += 1
            idx = parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables, end_idx=end_idx)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1
    return idx
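

def _demo_parse_value():
    # Minimal sketch (hypothetical tokens): numeric and string literals are
    # rewritten to placeholders in the token list, in place.
    toks = ['18']
    parse_value(toks, 0, {}, {})
    print(toks)  # ['_num_value_']
    toks = ['"alice"']
    parse_value(toks, 0, {}, {})
    print(toks)  # ['_str_value_']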


def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    while idx < len_:
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if toks[idx] == 'not':
            not_op = True
            idx += 1
        assert idx < len_ and toks[idx] in WHERE_OPS, \
            "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        if op_id == WHERE_OPS.index('between'):  # between..and..: dual values
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break
        if idx < len_ and toks[idx] in COND_OPS:
            idx += 1  # skip and/or
    return idx


def parse_from(toks, start_idx, schema):
    """Parse the FROM clause; returns (end idx, default tables, alias map)."""
    assert 'from' in toks[start_idx:], "'from' not found"
    tables_with_alias = {}
    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    while idx < len_:
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1
        if toks[idx] == 'select':
            # Sub-query in FROM.
            idx = parse_sql(toks, idx, schema)
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1  # skip join
            idx, table_name = parse_table_unit(toks, idx, tables_with_alias)
            default_tables.append(table_name)
            # Make sure the table is present in the schema being built.
            if table_name not in schema:
                schema[table_name] = []
        if idx < len_ and toks[idx] == "on":
            idx += 1  # skip on
            idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break
    return idx, default_tables, tables_with_alias


def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    assert toks[idx] == 'select', "'select' not found"
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []
    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        # val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
    return idx


def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    if idx >= len_ or toks[idx] != 'where':
        return idx
    idx += 1
    idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx


def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    col_units = []
    if idx >= len_ or toks[idx] != 'group':
        return idx
    idx += 1
    assert toks[idx] == 'by'
    idx += 1
    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        # col_units.append(col_unit)
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break
    return idx


def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    if idx >= len_ or toks[idx] != 'having':
        return idx
    idx += 1
    idx = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx


def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    val_units = []
    order_type = 'asc'  # default direction is 'asc'
    if idx >= len_ or toks[idx] != 'order':
        return idx
    idx += 1
    assert toks[idx] == 'by'
    idx += 1
    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        # val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break
    return idx


def parse_limit(toks, start_idx):
    idx = start_idx
    len_ = len(toks)
    if idx < len_ and toks[idx] == 'limit':
        idx += 2
        # Replace the limit number with a placeholder rather than assuming a
        # fixed fake value such as 1.
        toks[idx - 1] = "_limit_value_"
    return idx
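

def _demo_parse_limit():
    # Minimal sketch (hypothetical tokens): the LIMIT number becomes a placeholder.
    toks = ['limit', '5']
    parse_limit(toks, 0)
    print(toks)  # ['limit', '_limit_value_']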


def parse_sql(toks, start_idx, schema):
    isBlock = False  # indicates whether this is a parenthesized (sub-)query
    len_ = len(toks)
    idx = start_idx
    if toks[idx] == '(':
        isBlock = True
        idx += 1
    # Parse FROM first so that the default tables and aliases are known,
    # then revisit the SELECT clause with that context.
    from_end_idx, default_tables, tables_with_alias = parse_from(toks, start_idx, schema)
    _ = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    idx = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    idx = parse_limit(toks, idx)
    idx = skip_semicolon(toks, idx)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    idx = skip_semicolon(toks, idx)
    # INTERSECT / UNION / EXCEPT: parse the right-hand query recursively.
    if idx < len_ and toks[idx] in SQL_OPS:
        idx += 1
        idx = parse_sql(toks, idx, schema)
    return idx


def extract_schema_from_sql(schema, sql):
    """Populate `schema` (table -> column list) by parsing `sql`; returns the tokens."""
    toks = tokenize(sql)
    parse_sql(toks=toks, start_idx=0, schema=schema)
    return toks
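

def _demo_extract_schema_from_sql():
    # Minimal sketch (made-up query): the schema dict is filled as a side effect.
    schema = {}
    extract_schema_from_sql(schema, "select name from users where age > 18")
    print(schema)  # {'users': ['name', 'age']}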


def extract_template_from_sql(sql, schema={}):
    try:
        toks = tokenize(sql)
    except Exception:
        print("Tokenization error for {}".format(sql))
        toks = []
    template = []
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        if tok == "from":
            template.append(tok)
            if toks[idx + 1] != "(":
                template.append("[FROM_PART]")
                idx += 1
                # Swallow the whole FROM part up to the next clause keyword.
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
                if tok == "between":
                    idx += 2  # skip the first value and 'and'
        elif tok in COND_OPS:
            template.append(tok)
        elif template and template[-1] == "[WHERE_OP]":
            template.append("[VALUE]")
        elif template and template[-1] == "limit":
            template.append("[LIMIT_VALUE]")
        elif not template or template[-1] != "[MASK]":  # value / schema / join-on-as tokens
            template.append("[MASK]")
        idx += 1
    return template
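

def _demo_extract_template_from_sql():
    # Minimal sketch (made-up query): schema tokens collapse to [MASK],
    # comparison operators to [WHERE_OP], and literals to [VALUE].
    print(extract_template_from_sql("select name from users where age > 18"))
    # ['select', '[MASK]', 'from', '[FROM_PART]', 'where', '[MASK]', '[WHERE_OP]', '[VALUE]']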


def extract_partial_template_from_sql(sql, schema={}):
    toks = tokenize(sql)
    template = []
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        if tok == "from":
            template.append(tok)
            if toks[idx + 1] != "(":
                idx += 1
                # Keep the FROM part verbatim instead of masking it.
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    template.append(toks[idx])
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
                if tok == "between":
                    idx += 2  # skip the first value and 'and'
        elif tok in COND_OPS:
            template.append(tok)
        elif template and template[-1] == "[WHERE_OP]":
            template.append("[VALUE]")
        elif template and template[-1] == "limit":
            template.append("[LIMIT_VALUE]")
        else:
            template.append(tok)
        idx += 1
    return template


def is_valid_schema(schema):
    # A schema is valid if no table name is a clause keyword or contains ".",
    # and no column name contains ".", whitespace, or quotes.
    for table in schema:
        if "." in table:
            return False
        if any(keyword == table for keyword in CLAUSE_KEYWORDS):
            return False
        for column in schema[table]:
            if "." in column or " " in column or '"' in column or "'" in column:
                return False
    return True
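

def _demo_is_valid_schema():
    # Minimal sketch (made-up schemas).
    print(is_valid_schema({'users': ['name', 'age']}))  # True
    print(is_valid_schema({'users': ['users.name']}))   # False: "." in a column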


def clean_sql(sql):
    # Collapse artifacts such as duplicated or dangling JOIN keywords that
    # show up in model-predicted SQL.
    while "JOIN JOIN" in sql:
        sql = sql.replace("JOIN JOIN", "JOIN")
    if "JOIN WHERE" in sql:
        sql = sql.replace("JOIN WHERE", "WHERE")
    if "JOIN GROUP BY" in sql:
        sql = sql.replace("JOIN GROUP BY", "GROUP BY")
    return sql
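

def _demo_clean_sql():
    # Minimal sketch (made-up prediction): duplicated JOINs are collapsed and
    # a dangling JOIN before WHERE is dropped.
    print(clean_sql("SELECT a FROM t JOIN JOIN u JOIN WHERE a = 1"))
    # SELECT a FROM t JOIN u WHERE a = 1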


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str)
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--mode", type=str, choices=["debug", "verbose", "silent"])
    parser.add_argument("--task", type=str, choices=["template_extraction", "schema_extraction"])
    args = parser.parse_args()

    if args.task == "schema_extraction":
        if args.mode == "debug":
            # Scratch queries for manual inspection; only the last assignment is used.
            sql = "SELECT count(*) FROM games"
            sql = sql + " INTERSECT " + "SELECT sacks, year FROM players"
            sql = sql + " EXCEPT " + 'SELECT T1.year, T1.sacks FROM players AS T1 JOIN tackles AS T2 ON T1.id = T2.player_id WHERE T2.manager = "A" and T2.season NOT IN (SELECT season FROM match WHERE match_name = "IVL" INTERSECT SELECT T1.year, T1.sacks FROM sack AS T1) GROUP BY T1.year, T1.sacks HAVING count(T1.coach) > 10 ORDER BY T2.score LIMIT 5'
            sql = "SELECT T1.pld FROM pld AS T1 JOIN games AS T2 ON T1.crs_code = T2.crs_code JOIN GROUP BY T1.pld WHERE T2.gf = '8' AND T2.gf = '9'"
            sql = 'select * from head where height = "6-0" or height = "6-0" order by height asc'
            schema = {}
            extract_schema_from_sql(schema, sql)
            print(schema, is_valid_schema(schema))
        elif args.mode == "verbose":
            fout = open(args.output_file, "w")
            with open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    schema = {}
                    try:
                        sql = example["sql"] if "sql" in example else example["pred"]
                        sql = clean_sql(sql)
                        example["sql"] = sql
                        extract_schema_from_sql(schema, sql)
                    except Exception:
                        # Skip examples whose SQL cannot be parsed.
                        continue
                    for table in schema:
                        schema[table] = list(set(schema[table]))
                    if is_valid_schema(schema):
                        example["extracted_schema"] = schema
                        fout.write(json.dumps(example) + "\n")
        elif args.mode == "silent":
            # Same as verbose, except parse errors are not caught, so a bad
            # example fails loudly instead of being skipped.
            fout = open(args.output_file, "w")
            with open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    schema = {}
                    sql = example["sql"] if "sql" in example else example["pred"]
                    sql = clean_sql(sql)
                    example["sql"] = sql
                    extract_schema_from_sql(schema, sql)
                    for table in schema:
                        schema[table] = list(set(schema[table]))
                    if is_valid_schema(schema):
                        example["extracted_schema"] = schema
                        fout.write(json.dumps(example) + "\n")
    elif args.task == "template_extraction":
        if args.mode == "debug":
            sql = "SELECT avg(T1.Votes) FROM seats AS T1 JOIN votes AS T2 ON T1.Seat_ID = T2.Seat_ID WHERE T1.seats BETWEEN 1 AND 2 and T1.Seats = 1 AND T2.Votes = 10"
            print(extract_template_from_sql(sql))
            print(extract_partial_template_from_sql(sql))
        elif args.mode == "verbose":
            fout_json = open(args.output_file + ".json", "w")
            fout_txt = open(args.output_file + ".txt", "w")
            low_freq_txt = open(args.output_file + ".low_freq", "w")
            high_freq_txt = open(args.output_file + ".high_freq", "w")
            templates = {}
            with open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    sql = example["sql"] if "sql" in example else example["pred"]
                    if isinstance(sql, list):
                        sql = sql[-1]
                    template = extract_template_from_sql(sql)
                    template_str = " ".join(template)
                    if template_str not in templates:
                        templates[template_str] = []
                    templates[template_str].append(sql)
            print("{} has {} templates".format(args.input_file, len(templates)))
            json.dump(templates, fout_json)
            for template in sorted(templates.keys()):
                if len(templates[template]) > 1:
                    high_freq_txt.write(template + "\n")
                else:
                    low_freq_txt.write(template + "\n")
                fout_txt.write(template + "\n")