Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import sqlite3
|
|
| 5 |
|
| 6 |
# Load model
|
| 7 |
model_name = "mrm8488/t5-base-finetuned-wikiSQL"
|
| 8 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 10 |
|
| 11 |
def nl_to_sql(question, file):
|
|
@@ -26,7 +26,7 @@ def nl_to_sql(question, file):
|
|
| 26 |
outputs = model.generate(**inputs, max_length=256)
|
| 27 |
sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 28 |
|
| 29 |
-
#
|
| 30 |
try:
|
| 31 |
result = pd.read_sql_query(sql_query, conn)
|
| 32 |
except Exception as e:
|
|
|
|
| 5 |
|
| 6 |
# Load model
|
| 7 |
model_name = "mrm8488/t5-base-finetuned-wikiSQL"
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # <-- use slow tokenizer
|
| 9 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 10 |
|
| 11 |
def nl_to_sql(question, file):
|
|
|
|
| 26 |
outputs = model.generate(**inputs, max_length=256)
|
| 27 |
sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 28 |
|
| 29 |
+
# Execute SQL query
|
| 30 |
try:
|
| 31 |
result = pd.read_sql_query(sql_query, conn)
|
| 32 |
except Exception as e:
|