Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
# app.py
|
| 2 |
|
| 3 |
from fastapi import FastAPI, HTTPException
|
| 4 |
-
from fastapi.responses import HTMLResponse, JSONResponse
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 6 |
from pydantic import BaseModel
|
| 7 |
|
| 8 |
import os
|
|
@@ -14,151 +14,82 @@ from datasets import load_dataset
|
|
| 14 |
from deep_translator import GoogleTranslator
|
| 15 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
| 16 |
|
| 17 |
-
|
| 18 |
-
# --------- Configuración y defaults ----------
|
| 19 |
# Model checkpoint and WikiSQL table selection. Each value can be overridden
# through the environment variable of the same name.
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
TABLE_SPLIT = os.environ.get("TABLE_SPLIT", "validation")
TABLE_INDEX = int(os.environ.get("TABLE_INDEX", "10"))
MAX_ROWS = int(os.environ.get("MAX_ROWS", "100"))

# Point the Hugging Face cache at a directory under /app and make sure it exists.
# NOTE(review): this assignment runs after the transformers/datasets imports
# above; those libraries may resolve HF_HOME at import time — confirm that the
# override actually takes effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# FastAPI application object; every route below is registered on it.
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (HF Space)")

# Browser origins allowed to call the API. Defaults to the wildcard (the
# previous hard-coded behaviour); set CORS_ALLOW_ORIGINS to a comma-separated
# list of origins to restrict access without touching the code.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive — restrict the origin list once the front-end domain is known.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
# --------- Carga perezosa (lazy) ---------
|
| 41 |
@lru_cache(maxsize=1)
def get_model_and_tokenizer():
    """Load the TAPEX tokenizer and BART model identified by HF_MODEL_ID.

    The function is wrapped in ``lru_cache(maxsize=1)``, so the load happens on
    the first call and later calls reuse the same pair.

    Returns:
        tuple: ``(tokenizer, model)``, with the model switched to eval mode.
    """
    tokenizer = TapexTokenizer.from_pretrained(HF_MODEL_ID)
    # Free Spaces hardware is CPU-only; device_map="auto" is deliberately not
    # used so that accelerate is not required.
    network = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
    network.eval()
    return tokenizer, network
|
| 48 |
|
| 49 |
-
|
| 50 |
@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """Load one WikiSQL table and return it as a DataFrame.

    The regular dataset split is loaded instead of the parquet revision, which
    sometimes fails on Spaces.

    Args:
        split: Dataset split passed to ``load_dataset``.
        index: Position of the example whose table is returned.
        max_rows: Upper bound on the number of table rows kept.

    Returns:
        DataFrame whose columns are the table header, coerced to strings.

    Raises:
        IndexError: If ``index`` is outside the split.
    """
    dataset = load_dataset("Salesforce/wikisql", split=split)
    total = len(dataset)
    if not 0 <= index < total:
        raise IndexError(f"TABLE_INDEX fuera de rango (0..{total-1}).")

    table = dataset[index]["table"]
    column_names = [str(name) for name in table["header"]]
    frame = pd.DataFrame(table["rows"][:max_rows], columns=column_names)
    # Normalise column names to strings.
    frame.columns = [str(col) for col in frame.columns]
    return frame
|
| 67 |
|
| 68 |
-
|
| 69 |
-
# --------- Esquema de petición ----------
|
| 70 |
# Request body for POST /api/nl2sql.
class NLQuery(BaseModel):
    # Natural-language question to convert into SQL.
    nl_query: str
|
| 72 |
|
| 73 |
-
|
| 74 |
-
# --------- Mini UI en / ----------
|
| 75 |
-
INDEX_HTML = """
|
| 76 |
-
<!doctype html>
|
| 77 |
-
<html lang="es">
|
| 78 |
-
<meta charset="utf-8">
|
| 79 |
-
<title>NL→SQL (TAPEX + WikiSQL)</title>
|
| 80 |
-
<style>
|
| 81 |
-
body{font-family:system-ui,-apple-system,Segoe UI,Roboto,Ubuntu,sans-serif;max-width:860px;margin:30px auto;padding:0 16px;color:#eaeaea;background:#0f1115}
|
| 82 |
-
h1{font-size:1.6rem;margin:0 0 8px}
|
| 83 |
-
.card{background:#171923;border:1px solid #232736;border-radius:12px;padding:16px;margin:18px 0}
|
| 84 |
-
input,button,select{font-size:1rem}
|
| 85 |
-
input{width:100%;padding:10px 12px;border-radius:8px;border:1px solid #2a3042;background:#0f1115;color:#eaeaea}
|
| 86 |
-
button{padding:10px 16px;border-radius:8px;border:1px solid #2a3042;background:#1f2433;color:#eaeaea;cursor:pointer}
|
| 87 |
-
pre{white-space:pre-wrap;background:#0f1115;border:1px solid #2a3042;padding:12px;border-radius:8px}
|
| 88 |
-
.row{display:flex;gap:12px;align-items:center}
|
| 89 |
-
</style>
|
| 90 |
-
<h1>🧠 NL → SQL (TAPEX + WikiSQL)</h1>
|
| 91 |
-
<div class="card">
|
| 92 |
-
<p><b>Backend:</b> este Space ofrece endpoints REST; prueba una consulta:</p>
|
| 93 |
-
<div class="row">
|
| 94 |
-
<input id="q" placeholder="Ej.: Muestra los jugadores que son Guards." />
|
| 95 |
-
<button onclick="run()">Generar SQL</button>
|
| 96 |
-
<button onclick="prev()">Ver preview tabla</button>
|
| 97 |
-
</div>
|
| 98 |
-
<p style="font-size:.9rem;opacity:.8">Swagger: <a href="./docs" target="_blank">/docs</a> · Salud: <a href="./api/health" target="_blank">/api/health</a></p>
|
| 99 |
-
<pre id="out">Listo para generar...</pre>
|
| 100 |
-
</div>
|
| 101 |
-
<script>
|
| 102 |
-
async function run(){
|
| 103 |
-
const q = document.getElementById('q').value.trim();
|
| 104 |
-
const r = await fetch('./api/nl2sql', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify({nl_query: q}) });
|
| 105 |
-
document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
|
| 106 |
-
}
|
| 107 |
-
async function prev(){
|
| 108 |
-
const r = await fetch('./api/preview');
|
| 109 |
-
document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
|
| 110 |
-
}
|
| 111 |
-
</script>
|
| 112 |
-
</html>
|
| 113 |
-
"""
|
| 114 |
-
|
| 115 |
-
# Root route: serves the embedded mini UI page (INDEX_HTML) as an HTML response.
@app.get("/", response_class=HTMLResponse)
def home():
    return INDEX_HTML
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
# --------- Rutas API ----------
|
| 121 |
@app.get("/api/health")
def health():
    # Liveness probe: echoes the configured model id and table selection.
    return dict(ok=True, model=HF_MODEL_ID, split=TABLE_SPLIT, index=TABLE_INDEX)
|
| 124 |
|
| 125 |
-
|
| 126 |
@app.get("/api/preview")
def preview():
    # Returns the column names and the first 8 rows of the configured table.
    try:
        table = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
        payload = {
            "columns": list(table.columns),
            "rows": table.head(8).to_dict(orient="records"),
        }
    except Exception as exc:
        # Send a simple message (so as not to flood the logs).
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return payload
|
| 136 |
|
| 137 |
-
|
| 138 |
@app.post("/api/nl2sql")
|
| 139 |
def nl2sql(q: NLQuery):
|
| 140 |
nl = (q.nl_query or "").strip()
|
| 141 |
if not nl:
|
| 142 |
raise HTTPException(status_code=400, detail="Consulta vacía.")
|
| 143 |
|
| 144 |
-
# Translate ES→EN only when the query contains non-ASCII characters
|
| 145 |
try:
|
| 146 |
is_ascii = all(ord(c) < 128 for c in nl)
|
| 147 |
nl_en = nl if is_ascii else GoogleTranslator(source="auto", target="en").translate(nl)
|
| 148 |
except Exception:
|
| 149 |
-
# Si la traducción falla, seguimos con el texto original
|
| 150 |
nl_en = nl
|
| 151 |
|
| 152 |
try:
|
| 153 |
df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
|
| 154 |
tok, model = get_model_and_tokenizer()
|
| 155 |
-
|
| 156 |
-
# Tokenización limitada
|
| 157 |
enc = tok(table=df, query=nl_en, return_tensors="pt", truncation=True, max_length=512)
|
| 158 |
-
|
| 159 |
with torch.inference_mode():
|
| 160 |
out = model.generate(**enc, max_length=160, num_beams=1)
|
| 161 |
-
|
| 162 |
sql = tok.batch_decode(out, skip_special_tokens=True)[0]
|
| 163 |
return {"consulta_original": nl, "consulta_traducida": nl_en, "sql_generado": sql}
|
| 164 |
except Exception as e:
|
|
|
|
| 1 |
+
# app.py — NL→SQL (TAPEX + WikiSQL) backend (solo API)
|
| 2 |
|
| 3 |
from fastapi import FastAPI, HTTPException
|
|
|
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
from pydantic import BaseModel
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 14 |
from deep_translator import GoogleTranslator
|
| 15 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
| 16 |
|
| 17 |
+
# ------------ Config ------------
|
|
|
|
| 18 |
# Model checkpoint and WikiSQL table selection. Each value can be overridden
# through the environment variable of the same name.
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
TABLE_SPLIT = os.environ.get("TABLE_SPLIT", "validation")
TABLE_INDEX = int(os.environ.get("TABLE_INDEX", "10"))
MAX_ROWS = int(os.environ.get("MAX_ROWS", "100"))

# Writable Hugging Face cache on Spaces.
# NOTE(review): this assignment runs after the transformers/datasets imports
# above; those libraries may resolve HF_HOME at import time — confirm that the
# override actually takes effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
|
| 26 |
|
| 27 |
+
# ------------ App & CORS ------------
|
| 28 |
+
# FastAPI application object (API only); every route below is registered on it.
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")

# Browser origins allowed to call the API. Defaults to the wildcard (the
# previous hard-coded behaviour); set CORS_ALLOW_ORIGINS to a comma-separated
# list (e.g. the Vercel front-end domain) to restrict access without a code
# change.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive — restrict the origin list once the front-end domain is known.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 36 |
|
| 37 |
+
# ------------ Carga perezosa ------------
|
|
|
|
| 38 |
@lru_cache(maxsize=1)
def get_model_and_tokenizer():
    """Return the TAPEX tokenizer and BART model for HF_MODEL_ID.

    Both objects are created with ``from_pretrained`` on the first call; the
    ``lru_cache(maxsize=1)`` wrapper hands back the same pair afterwards.

    Returns:
        tuple: ``(tokenizer, model)``, with the model set to eval mode.
    """
    tokenizer = TapexTokenizer.from_pretrained(HF_MODEL_ID)
    generator = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
    generator.eval()
    loaded = (tokenizer, generator)
    return loaded
|
| 44 |
|
|
|
|
| 45 |
@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """Fetch one WikiSQL table as a DataFrame.

    Args:
        split: Dataset split passed to ``load_dataset``.
        index: Position of the example whose table is returned.
        max_rows: Upper bound on the number of table rows kept.

    Returns:
        DataFrame whose columns are the table header, coerced to strings.

    Raises:
        IndexError: If ``index`` is outside the split.
    """
    # Load the regular split; this avoids the parquet revision.
    dataset = load_dataset("Salesforce/wikisql", split=split)
    size = len(dataset)
    if not 0 <= index < size:
        raise IndexError(f"TABLE_INDEX fuera de rango (0..{size-1}).")

    table = dataset[index]["table"]
    column_names = [str(name) for name in table["header"]]
    frame = pd.DataFrame(table["rows"][:max_rows], columns=column_names)
    frame.columns = [str(col) for col in frame.columns]
    return frame
|
| 56 |
|
| 57 |
+
# ------------ Schemas ------------
|
|
|
|
| 58 |
# Request body for POST /api/nl2sql.
class NLQuery(BaseModel):
    # Natural-language question to convert into SQL.
    nl_query: str
|
| 60 |
|
| 61 |
+
# ------------ Endpoints ------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
@app.get("/api/health")
def health():
    # Liveness probe: reports the configured model id and table selection.
    status = {"ok": True}
    status.update(model=HF_MODEL_ID, split=TABLE_SPLIT, index=TABLE_INDEX)
    return status
|
| 65 |
|
|
|
|
| 66 |
@app.get("/api/preview")
def preview():
    # Returns the column names and the first 8 rows of the configured table.
    # Any failure is reported as a 500 response carrying the error text.
    try:
        frame = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
        column_names = list(frame.columns)
        first_rows = frame.head(8).to_dict(orient="records")
        return {"columns": column_names, "rows": first_rows}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
|
| 73 |
|
|
|
|
| 74 |
@app.post("/api/nl2sql")
|
| 75 |
def nl2sql(q: NLQuery):
|
| 76 |
nl = (q.nl_query or "").strip()
|
| 77 |
if not nl:
|
| 78 |
raise HTTPException(status_code=400, detail="Consulta vacía.")
|
| 79 |
|
| 80 |
+
# Traducción ES→EN si no-ASCII
|
| 81 |
try:
|
| 82 |
is_ascii = all(ord(c) < 128 for c in nl)
|
| 83 |
nl_en = nl if is_ascii else GoogleTranslator(source="auto", target="en").translate(nl)
|
| 84 |
except Exception:
|
|
|
|
| 85 |
nl_en = nl
|
| 86 |
|
| 87 |
try:
|
| 88 |
df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
|
| 89 |
tok, model = get_model_and_tokenizer()
|
|
|
|
|
|
|
| 90 |
enc = tok(table=df, query=nl_en, return_tensors="pt", truncation=True, max_length=512)
|
|
|
|
| 91 |
with torch.inference_mode():
|
| 92 |
out = model.generate(**enc, max_length=160, num_beams=1)
|
|
|
|
| 93 |
sql = tok.batch_decode(out, skip_special_tokens=True)[0]
|
| 94 |
return {"consulta_original": nl, "consulta_traducida": nl_en, "sql_generado": sql}
|
| 95 |
except Exception as e:
|