|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import sqlite3 |
|
|
|
|
|
|
|
|
# Hugging Face checkpoint used to translate questions into SQL.
model_name = "mrm8488/t5-base-finetuned-wikiSQL"

# Tokenizer and model are loaded once at import time so every request
# served by the UI reuses the same instances.
# use_fast=False requests the Python ("slow") tokenizer implementation
# instead of the Rust-backed fast one.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
|
|
def nl_to_sql(question, file):
    """Translate a plain-English question into SQL and run it on a CSV.

    The CSV is loaded into an in-memory SQLite table named ``data_table``.
    The question and the CSV's column names are passed to the module-level
    seq2seq ``model``; the decoded output is executed as a query against
    that table.

    Args:
        question: The user's question in natural language.
        file: The uploaded CSV, given either as an object exposing its path
            through a ``name`` attribute or as a plain path string.

    Returns:
        A ``(sql_query, preview)`` tuple. ``sql_query`` is the generated SQL
        text, or an ``"Error reading CSV: ..."`` message when the CSV could
        not be loaded. ``preview`` is a DataFrame with the first rows of the
        query result, a single-column ``Error`` frame when the query failed,
        or an empty frame when the CSV could not be loaded.
    """
    if file is None:
        # No upload: report it clearly rather than via an AttributeError.
        return "Error reading CSV: no file was uploaded", pd.DataFrame()

    # Accept both file-like wrappers (path in ``.name``) and raw path strings.
    csv_path = getattr(file, "name", file)
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # UI boundary: surface any parsing/IO problem as a message.
        return f"Error reading CSV: {e}", pd.DataFrame()

    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("data_table", conn, index=False, if_exists="replace")

        # Give the model the column names as context for the question.
        schema = ", ".join(map(str, df.columns))
        text = f"translate English to SQL: {question} | table columns: {schema}"

        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=256)
        sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # NOTE(review): the generated SQL is executed verbatim; confirm the
        # model actually emits ``data_table`` as the table name.
        try:
            result = pd.read_sql_query(sql_query, conn)
        except Exception as e:
            # Generated SQL may be invalid; show the error in the result pane.
            result = pd.DataFrame({"Error": [str(e)]})
    finally:
        # Always release the connection, even if loading or generation fails.
        conn.close()

    return sql_query, result.head()
|
|
|
|
|
# Gradio UI: a question box and a CSV upload in, the generated SQL and a
# preview of the query result out.
iface = gr.Interface(
    fn=nl_to_sql,
    inputs=[
        gr.Textbox(label="Ask your question (Natural Language)", placeholder="e.g., Show customers older than 30"),
        gr.File(label="Upload your CSV file"),
    ],
    outputs=[
        gr.Textbox(label="Generated SQL Query"),
        gr.Dataframe(label="Result Preview"),
    ],
    title="🧠 Natural Language to SQL Generator",
    description="Upload a CSV and ask questions in plain English. Generates SQL and shows results instantly.",
)
|
|
|
|
|
# Start the web UI only when the file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()
|
|
|