from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from functools import lru_cache
from transformers import TapexTokenizer, BartForConditionalGeneration
from deep_translator import GoogleTranslator
from pathlib import Path
import os, json, pandas as pd, torch
# ------------------------
# Config
# ------------------------
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
SPLIT = os.getenv("TABLE_SPLIT", "validation") # "validation" ~ "dev"
INDEX = int(os.getenv("TABLE_INDEX", "10"))
MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))
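# All of the above can be overridden via the environment for local runs,
# e.g. (illustrative values, not the only valid ones):
#   HF_MODEL_ID=stvnnnnnn/tapex-wikisql-best TABLE_SPLIT=dev TABLE_INDEX=3 uvicorn app:app --reload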
# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True,
)
class NLQuery(BaseModel):
    nl_query: str
# ------------------------
# Model
# ------------------------
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
# ------------------------
# Robust loading utilities
# ------------------------
def _read_json_or_jsonl(p: Path) -> dict:
    """
    Read a plain JSON file (.json) or a JSONL file (.jsonl) and return the first object.
    """
    txt = p.read_text(encoding="utf-8").strip()
    if p.suffix.lower() == ".jsonl":
        for line in txt.splitlines():
            s = line.strip()
            if s:
                return json.loads(s)
        raise ValueError(f"{p} is empty.")
    return json.loads(txt)
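# The loader above (and get_table below) expects the local file to hold a
# WikiSQL-style table object. A minimal sketch of the assumed shape, with
# illustrative values:
#   {"header": ["City", "Population"], "rows": [["Quito", 2800000], ["Lima", 9700000]]}
# For .jsonl files, only the first non-empty line is parsed.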
@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """
    1) Try to load ./data/<split>.json or ./data/<split>.jsonl (mapping 'validation' -> 'dev').
    2) If neither exists, fall back to a WikiSQL example (official Parquet conversion).
    """
    base_dir = Path(__file__).parent
    data_dir = base_dir / "data"
    # Normalize the local file name (the demo uses 'dev')
    local_name = "dev" if split.lower() in ("validation", "dev") else split.lower()
    # 1) Look for a local file
    for candidate in (data_dir / f"{local_name}.json", data_dir / f"{local_name}.jsonl"):
        if candidate.exists():
            js = _read_json_or_jsonl(candidate)
            header = [str(h) for h in js["header"]]
            rows = js["rows"][:max_rows]
            df = pd.DataFrame(rows, columns=header)
            df.columns = [str(c) for c in df.columns]
            return df
    # 2) Fallback: load an example from the WikiSQL dataset (converted Parquet)
    try:
        from datasets import load_dataset  # deferred import for faster startup
        ds = load_dataset("Salesforce/wikisql", split="validation", revision="refs/convert/parquet")
        if not (0 <= index < len(ds)):
            index = 0  # safety guard
        ex = ds[index]
        header = [str(h) for h in ex["table"]["header"]]
        rows = ex["table"]["rows"][:max_rows]
        df = pd.DataFrame(rows, columns=header)
        df.columns = [str(c) for c in df.columns]
        return df
    except Exception as e:
        raise RuntimeError(f"Could not obtain a table: {e}")
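# Usage sketch: with a hypothetical ./data/dev.json in place,
# get_table("validation", 10, 12) maps 'validation' -> 'dev' and loads the local
# file; otherwise it falls back to WikiSQL validation example 10, capped at 12
# rows. lru_cache memoizes the result per (split, index, max_rows) tuple, so the
# table is only built once per distinct configuration.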
# ------------------------
# Endpoints
# ------------------------
@app.get("/api/health")
def health():
return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX}
@app.get("/api/preview")
def preview():
    try:
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
    except Exception as e:
        return {"error": str(e)}
@app.post("/api/nl2sql")
def nl2sql(q: NLQuery):
    try:
        text = (q.nl_query or "").strip()
        if not text:
            raise ValueError("Empty query.")
        # Detect whether the input already looks like SQL
        lower = text.lower()
        looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))
        # Translate to English if the input is not SQL
        query_en = text
        if not looks_like_sql:
            try:
                translated = GoogleTranslator(source="auto", target="en").translate(text)
                if translated:
                    query_en = translated
            except Exception:
                query_en = text  # safe fallback
        # Run TAPEX over the selected table
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]
        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
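# Example call once the server is up (illustrative):
#   curl -X POST http://localhost:7860/api/nl2sql \
#        -H "Content-Type: application/json" \
#        -d '{"nl_query": "How many rows are there?"}'
# Local entry point: a minimal convenience sketch, assuming uvicorn is installed
# (7860 is the conventional Hugging Face Spaces port; hosted deployments usually
# launch the app themselves, so this block is optional).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)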