Spaces:
Sleeping
Sleeping
import os
import re
from typing import Optional

import gradio as gr
# Environment switches set before transformers is (optionally) imported further down.
# NOTE(review): transformers' documented backend switches are USE_TF / USE_FLAX / USE_TORCH;
# confirm the TRANSFORMERS_NO_* names below are actually honored by the installed version.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# Verb placed into each deterministic summary, keyed by detect_command()'s result.
# The summaries put it right after the noun or "will be" ("3 records loaded into table t",
# "Record(s) changed in table t", "... will be specified from t"), so every entry must be a
# past participle. "load" was a bare infinitive and produced ungrammatical output.
CRUD_VERB = {
    "INSERT": "loaded",
    "UPDATE": "changed",
    "DELETE": "deleted",
    "SELECT": "specified",
}
| # ---------- helpers ---------- | |
def plural(n: int, singular: str = "record", plural_word: str = "records") -> str:
    """Return "<n> <noun>", picking *singular* only when n is exactly 1."""
    noun = plural_word
    if n == 1:
        noun = singular
    return f"{n} {noun}"
def detect_command(sql_text: str) -> str:
    """Return the first CRUD keyword found in *sql_text*, upper-cased, or "OTHER".

    Matching is case-insensitive and whole-word. The leftmost of
    INSERT/UPDATE/DELETE/SELECT wins.
    """
    found = re.search(r"\b(INSERT|UPDATE|DELETE|SELECT)\b", sql_text, flags=re.IGNORECASE)
    if found is None:
        return "OTHER"
    return found.group(1).upper()
def parse_table_name(sql: str) -> str:
    """Return the first table name referenced by *sql*, or "" if none is found.

    Patterns are tried in order: INSERT INTO, UPDATE, DELETE FROM, then any FROM.
    Each keyword must be a whole word. That keeps column names such as
    ``last_update`` or ``is_deleted`` from being read as statements; previously
    ``SELECT last_update FROM accounts`` returned "FROM". The row-locking suffix
    ``FOR UPDATE`` is also excluded, so ``... FOR UPDATE NOWAIT`` does not yield
    "NOWAIT". Names may be qualified (``db..t``, ``dbo.t``) or bracketed
    (``[dbo].[t]``).
    """
    ident = r"([A-Za-z0-9\.\[\]_]+)"
    patterns = (
        rf"\bINSERT\s+INTO\s+{ident}",
        rf"(?<!FOR\s)\bUPDATE\s+{ident}",
        rf"\bDELETE\s+FROM\s+{ident}",
        rf"\bFROM\s+{ident}",
    )
    for pat in patterns:
        m = re.search(pat, sql, flags=re.IGNORECASE)
        if m:
            return m.group(1)
    return ""
def clean_statement(text: str) -> str:
    """Tidy a model-generated sentence for display.

    - Drops a leading question-style prefix up to its colon
      (e.g. "What ...: ", "Provide ...: ").
    - Ensures terminal punctuation (., ! or ?).
    - Capitalises the first character.

    Empty or whitespace-only input yields "".
    """
    stripped = re.sub(
        r"^(What|Which|How|Give|Provide)[^:]*:\s*", "", text, flags=re.IGNORECASE
    ).strip()
    if not stripped:
        return stripped
    if stripped[-1] not in ".!?":
        stripped = stripped + "."
    return stripped[:1].upper() + stripped[1:]
def infer_in_list_count(where_clause: str) -> Optional[int]:
    """Count the literal values in the first ``IN (...)`` list of a WHERE clause.

    summarize_delete() uses this to phrase its output as "N records deleted".

    Returns:
        The number of items in the list. Returns None when any of these hold:
        - the clause is empty;
        - there is no IN list;
        - the first IN list is negated (``NOT IN`` matches rows *outside* the
          list, so its length is not a target count);
        - the list is a subquery (``IN (SELECT ...)``).
    """
    if not where_clause:
        return None
    m = re.search(
        r"\b(NOT\s+)?IN\s*\(\s*([^)]+?)\s*\)",
        where_clause,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if not m or m.group(1):
        return None
    body = m.group(2)
    if re.match(r"SELECT\b", body, flags=re.IGNORECASE):
        return None
    # Match quoted literals whole first so a comma inside 'a,b' is not treated as a separator.
    items = [x.strip() for x in re.findall(r"\s*'(?:[^']|'')*'|[^,]+", body) if x.strip()]
    return len(items) if items else None
| # ---------- deterministic CRUD summaries ---------- | |
def _split_sql_list(text: str) -> list:
    """Split *text* on commas that are outside parentheses and quoted literals."""
    parts, buf, depth, quote = [], [], 0, None
    for ch in text:
        if quote:
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append("".join(buf).strip())
            buf = []
            continue
        buf.append(ch)
    parts.append("".join(buf).strip())
    return parts


def _values_tuples(sql: str) -> list:
    """Return the inner text of each parenthesised row in the VALUES clause of *sql*."""
    m = re.search(r"\bVALUES\b", sql, flags=re.IGNORECASE)
    if not m:
        return []
    rows, buf, depth, quote = [], [], 0, None
    for ch in sql[m.end():]:
        if depth == 0:
            if ch == "(":
                depth, buf = 1, []
            elif not (ch.isspace() or ch == ","):
                # Anything else at top level (';', RETURNING, ON CONFLICT, ...) ends the row list.
                break
            continue
        # Inside a row: track nested parens (e.g. GETDATE()) and quoted literals.
        if quote:
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                rows.append("".join(buf).strip())
                continue
        buf.append(ch)
    return rows


def summarize_insert(sql: str) -> str:
    """Build a one-sentence description of an INSERT statement.

    Counts every row in a multi-row ``VALUES (...), (...), ...`` list. Nested
    calls such as ``GETDATE()`` and quoted commas do not break a row. If the
    column list contains ``user_id`` (case-insensitive), the integer user ids
    found in the rows are summarised per id.

    When the statement has no VALUES rows (e.g. ``INSERT ... SELECT``), the row
    count cannot be known statically, so no number is reported.
    """
    table = parse_table_name(sql)
    verb = CRUD_VERB["INSERT"]

    cols_match = re.search(r"\(\s*([^)]+?)\s*\)\s*VALUES\b", sql, flags=re.IGNORECASE | re.DOTALL)
    cols = (
        [c.strip().strip('[]`"').lower() for c in cols_match.group(1).split(",")]
        if cols_match
        else []
    )
    uid_idx = cols.index("user_id") if "user_id" in cols else None

    rows = _values_tuples(sql)
    if not rows:
        return f"Record(s) {verb} into table {table}."

    user_ids = []
    if uid_idx is not None:
        for row in rows:
            values = _split_sql_list(row)
            if uid_idx < len(values):
                uid_raw = values[uid_idx].strip("'\"")
                if re.fullmatch(r"-?\d+", uid_raw):
                    user_ids.append(int(uid_raw))

    head = f"{plural(len(rows))} {verb} into table {table}"
    if not user_ids:
        return f"{head}."

    groups: dict = {}
    for uid in user_ids:
        groups[uid] = groups.get(uid, 0) + 1
    if len(groups) == 1:
        return f"{head} (column user_id {next(iter(groups))})."
    detail = ", ".join(f"{n} with column user_id {uid}" for uid, n in sorted(groups.items()))
    return f"{head} ({detail})."
def summarize_update(sql: str) -> str:
    """Build a one-sentence description of an UPDATE statement.

    Lists the columns assigned in the SET clause and appends the WHERE
    condition (whitespace-collapsed) when present.

    A column is taken to be an identifier at the start of the SET clause, or
    right after a comma, that is followed by '='. Commas inside function calls
    (e.g. ``DATEADD(day, 1, x)``) are therefore no longer reported as changed
    columns.
    """
    table = parse_table_name(sql)
    changed_cols = []
    set_match = re.search(r"\bSET\b\s+(.+?)(\bWHERE\b|;|$)", sql, flags=re.IGNORECASE | re.DOTALL)
    if set_match:
        # Blank out quoted literals so commas or '=' inside them are ignored.
        set_clause = re.sub(r"'(?:[^']|'')*'", "''", set_match.group(1))
        for col in re.findall(r"(?:^|,)\s*(\[[^\]]+\]|[^\s,=()\[\]<>!]+)\s*=", set_clause):
            col = col.strip("[]")
            if col:
                changed_cols.append(col)

    where = ""
    w = re.search(r"\bWHERE\b\s+(.+?)(;|$)", sql, flags=re.IGNORECASE | re.DOTALL)
    if w:
        where = re.sub(r"\s+", " ", w.group(1)).strip()

    verb = CRUD_VERB["UPDATE"]
    base = f"Record(s) {verb} in table {table}"
    if changed_cols:
        base += f" (changed: column(s) {', '.join(changed_cols)})"
    if where:
        base += f" with {where}"
    return base + "."
def summarize_delete(sql: str) -> str:
    """Build a one-sentence description of a DELETE statement.

    If the WHERE clause has an IN list, infer_in_list_count() provides a count
    and the sentence starts with it ("3 records deleted from t"). Otherwise it
    starts with "Records". The whitespace-collapsed WHERE condition is
    appended when present.
    """
    target = parse_table_name(sql)
    where_match = re.search(r"\bWHERE\b\s+(.+?)(;|$)", sql, flags=re.IGNORECASE | re.DOTALL)
    condition = re.sub(r"\s+", " ", where_match.group(1)).strip() if where_match else ""
    how_many = infer_in_list_count(condition)
    verb = CRUD_VERB["DELETE"]

    if how_many:
        subject = plural(how_many, "record", "records")
    else:
        subject = "Records"
    sentence = f"{subject} {verb} from {target}"
    if condition:
        sentence = f"{sentence} that match {condition}"
    return f"{sentence}."
def summarize_select(sql: str) -> str:
    """Build a one-sentence description of a SELECT statement.

    Reports:
    - the selected columns ("all columns" for ``*``, "data" if no select list
      is found);
    - the table from parse_table_name();
    - the WHERE filter, if any.

    The WHERE text now stops before trailing clauses (GROUP BY, HAVING,
    ORDER BY, LIMIT, OFFSET, FETCH, UNION), so sorting/paging is no longer
    described as part of the filter condition.
    """
    table = parse_table_name(sql)
    cols = "data"
    cm = re.search(r"\bSELECT\b\s+(.+?)\bFROM\b", sql, flags=re.IGNORECASE | re.DOTALL)
    if cm:
        cols_raw = cm.group(1).strip()
        cols = "all columns" if cols_raw == "*" else re.sub(r"\s+", " ", cols_raw)

    where = ""
    w = re.search(
        r"\bWHERE\b\s+(.+?)"
        r"(?=\b(?:GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT|OFFSET|FETCH|UNION)\b|;|$)",
        sql,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if w:
        where = re.sub(r"\s+", " ", w.group(1)).strip()

    verb = CRUD_VERB["SELECT"]
    base = f"{cols} will be {verb} from {table}"
    if where:
        base += f" that match {where}"
    return base + "."
def deterministic_summary(sql_text: str) -> str:
    """Route *sql_text* to the rule-based summarizer matching its CRUD command.

    Returns "Unrecognized SQL command." when detect_command() reports anything
    other than INSERT/UPDATE/DELETE/SELECT.
    """
    summarizers = {
        "INSERT": summarize_insert,
        "UPDATE": summarize_update,
        "DELETE": summarize_delete,
        "SELECT": summarize_select,
    }
    summarize = summarizers.get(detect_command(sql_text))
    if summarize is None:
        return "Unrecognized SQL command."
    return summarize(sql_text)
| # ---------- optional T5 rephrase ---------- | |
# Optional dependency: if transformers cannot be imported for any reason, the app
# keeps working with the deterministic summaries only (_HAS_T5 stays False).
try:
    from transformers import T5Tokenizer, T5ForConditionalGeneration
except Exception:
    _HAS_T5 = False
else:
    _HAS_T5 = True
# Lazily-populated T5 tokenizer and model; load_t5() fills them on first use.
_T5_TOKENIZER = None
_T5_MODEL = None

# Instruction text placed before the deterministic summary in the T5 prompt,
# keyed by detect_command() result; "OTHER" is the fallback for unknown keys.
CRUD_PROMPT = dict(
    INSERT="Rewrite as a clear statement that new records will be added. Keep numbers the same.",
    UPDATE="Rewrite as a clear statement that existing records will be updated. Keep names and conditions.",
    DELETE="Rewrite as a clear statement that records will be deleted. Keep conditions if present.",
    SELECT="Rewrite as a clear statement that data will be retrieved. Keep table/filters.",
    OTHER="Rewrite as a short, clear statement for non-technical users.",
)
def load_t5(model_name: str = "mrm8488/t5-base-finetuned-wikiSQL-sql-to-en") -> None:
    """Load the T5 tokenizer and model into the module globals on first call.

    Args:
        model_name: Hugging Face model id (or local path) passed to both
            ``from_pretrained`` calls. Only honored on the first successful
            load; later calls are no-ops once both globals are set.

    Both objects are loaded into locals and published together. If the model
    load raises, the globals stay None and the next call retries cleanly
    instead of keeping a tokenizer-only half state.
    """
    global _T5_TOKENIZER, _T5_MODEL
    if _T5_TOKENIZER is not None and _T5_MODEL is not None:
        return
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    _T5_TOKENIZER, _T5_MODEL = tokenizer, model
def rephrase_with_t5(summary: str, cmd: str) -> str:
    """Ask the T5 model to reword a deterministic *summary*.

    Returns *summary* unchanged when transformers is unavailable. Otherwise it:
    1. ensures the model is loaded;
    2. builds a prompt from the per-command CRUD_PROMPT instruction plus the
       summary;
    3. decodes greedily (no sampling, at most 64 new tokens);
    4. passes the text through clean_statement().

    NOTE(review): the model card for this checkpoint appears to prompt with
    "translate Sql to English: <SQL>", whereas here an English summary is fed
    with a custom prefix. Confirm this is intended; it may explain why
    _bad_rephrase() has to filter answers like "True." or "What ...".
    """
    if not _HAS_T5:
        return summary
    load_t5()
    tok, model = _T5_TOKENIZER, _T5_MODEL
    instruction = CRUD_PROMPT[cmd] if cmd in CRUD_PROMPT else CRUD_PROMPT["OTHER"]
    prompt = f"explain sql in plain english statement: {instruction} {summary}"
    encoded = tok([prompt], return_tensors="pt")
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=64,
        do_sample=False,
    )
    text = tok.decode(generated[0], skip_special_tokens=True)
    return clean_statement(text)
def _bad_rephrase(text: str) -> bool:
    """Return True when a T5 output is unusable, so the caller keeps the deterministic summary.

    Rejected outputs:
    - None;
    - blank text;
    - bare boolean/null-like answers ("True.", "false", "N/A", ...), optionally
      wrapped in one pair of matching quotes;
    - anything starting with a question word (what/which/how);
    - anything shorter than 3 characters after trimming trailing punctuation.
    """
    if text is None:
        return True
    candidate = str(text).strip()
    # Peel one layer of matching single or double quotes around the whole output.
    if len(candidate) >= 2 and candidate[0] == candidate[-1] and candidate[0] in ("'", '"'):
        candidate = candidate[1:-1].strip()
    normalized = re.sub(r"[\s\.\!\?]+$", "", candidate).strip().lower()
    junk_answers = {"false", "true", "none", "null", "n/a", "na", ""}
    if normalized in junk_answers:
        return True
    if re.match(r"^(what|which|how)\b", normalized, re.IGNORECASE):
        return True
    return len(normalized) < 3
def explain(sql_text: str):
    """Gradio handler: turn a SQL statement into a plain-English sentence.

    The rule-based summary is always computed. When transformers imported
    successfully and the command is recognised, T5 is asked to rephrase that
    summary, and the rephrasing is returned only if _bad_rephrase() accepts it.
    Any exception from the T5 path is printed and the rule-based summary is
    returned instead.
    """
    query = (sql_text or "").strip()
    command = detect_command(query)
    fallback = deterministic_summary(query)
    rephrased = None
    if _HAS_T5 and command != "OTHER":
        try:
            rephrased = rephrase_with_t5(fallback, command)
        except Exception as exc:  # best-effort: the T5 step must never break the UI
            print(f"T5 rephrase failed: {exc}")
    return fallback if _bad_rephrase(rephrased) else rephrased
| # ---------- UI ---------- | |
# Sample statements, one per supported CRUD command, offered as clickable examples in the UI.
_EXAMPLE_INSERT = """INSERT INTO demo_database..user_records (record_id, person_id, created_at)
VALUES (101, 5, GETDATE()), (102, 5, GETDATE()), (103, 5, GETDATE());"""

_EXAMPLE_UPDATE = """UPDATE users
SET status = 'active', last_login = GETDATE()
WHERE user_id IN (101, 102, 103);"""

_EXAMPLE_DELETE = """DELETE FROM orders
WHERE order_date < '2024-01-01' AND status = 'cancelled';"""

_EXAMPLE_SELECT = """SELECT user_id, email, created_at
FROM accounts
WHERE email LIKE '%@example.com' AND created_at >= '2025-01-01';"""

EXAMPLES = [_EXAMPLE_INSERT, _EXAMPLE_UPDATE, _EXAMPLE_DELETE, _EXAMPLE_SELECT]
# Single-page UI: SQL input box, output box, an "Explain SQL" button and clickable examples,
# all wired to explain().
with gr.Blocks(theme=gr.themes.Glass()) as demo:
    gr.Markdown("## π CRUD-SQL2Text")
    sql_in = gr.Textbox(label="Enter SQL Query", lines=8, placeholder="Paste your SQL statement here...")
    final_out = gr.Textbox(label="Natural Language Output", lines=3)
    btn = gr.Button("Explain SQL")
    btn.click(explain, inputs=[sql_in], outputs=[final_out])
    gr.Examples(examples=EXAMPLES, inputs=[sql_in], outputs=[final_out], fn=explain, cache_examples=False)

# Start the server only when this file is run as a script, not when it is imported
# (e.g. by tests or other tooling that just needs `demo` or the helper functions).
if __name__ == "__main__":
    demo.launch(share=True)