from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from functools import lru_cache
from huggingface_hub import hf_hub_download
from transformers import TapexTokenizer, BartForConditionalGeneration
from deep_translator import GoogleTranslator
import os, json, pandas as pd, torch
# ------------------------
# Config
# ------------------------
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
WIKISQL_REPO = os.getenv("WIKISQL_REPO", "Salesforce/wikisql")  # official dataset repo
SPLIT = os.getenv("TABLE_SPLIT", "validation")  # "validation" == dev in WikiSQL
INDEX = int(os.getenv("TABLE_INDEX", "10"))
MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))
# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True,
)
class NLQuery(BaseModel):
    nl_query: str
# ------------------------
# Model
# ------------------------
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
# ------------------------
# Util: load WikiSQL (JSONL)
# ------------------------
def _read_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)
def _download_file(filename: str) -> str:
    # Download a file from the dataset repo on the Hugging Face Hub
    return hf_hub_download(repo_id=WIKISQL_REPO, filename=filename, repo_type="dataset")
@lru_cache(maxsize=32)
def get_table_from_wikisql(split: str, index: int, max_rows: int) -> pd.DataFrame:
"""
Carga la tabla de WikiSQL sin scripts, usando directamente los JSONL del repo:
- dev.jsonl (validation = 'dev' en terminología original)
- dev.tables.jsonl
Si cambias split a 'train' o 'test', intenta los nombres equivalentes.
"""
# Mapeo simple: validation->dev, train->train, test->test
split_map = {"validation": "dev", "dev": "dev", "train": "train", "test": "test"}
base = split_map.get(split.lower(), "dev")
# Posibles nombres de archivo en el repo (algunos mirrors usan variantes)
qa_candidates = [f"data/{base}.jsonl", f"data/{base}.json", f"{base}.jsonl"]
tbl_candidates = [f"data/{base}.tables.jsonl", f"{base}.tables.jsonl"]
qa_path = None
tbl_path = None
# Descarga QA
for cand in qa_candidates:
try:
qa_path = _download_file(cand)
break
except Exception:
continue
if qa_path is None:
raise RuntimeError(f"No se encontró el archivo QA para split={split}. Intentos: {qa_candidates}")
# Descarga tablas
for cand in tbl_candidates:
try:
tbl_path = _download_file(cand)
break
except Exception:
continue
if tbl_path is None:
raise RuntimeError(f"No se encontró el archivo de tablas para split={split}. Intentos: {tbl_candidates}")
# Leemos la pregunta N (para tomar su table_id) — si no necesitas la pregunta, puedes omitir esto
qa_list = list(_read_jsonl(qa_path))
if not (0 <= index < len(qa_list)):
raise IndexError(f"index={index} fuera de rango (0..{len(qa_list)-1}) para split={split}")
table_id = qa_list[index].get("table_id") or qa_list[index].get("table", {}).get("id")
if table_id is None:
raise RuntimeError("No se pudo extraer 'table_id' del registro de QA.")
# Buscamos esa tabla en dev.tables.jsonl
header, rows = None, None
for obj in _read_jsonl(tbl_path):
if obj.get("id") == table_id:
header = [str(h) for h in obj["header"]]
rows = obj["rows"]
break
if header is None or rows is None:
raise RuntimeError(f"No se encontró la tabla con id={table_id} en {os.path.basename(tbl_path)}")
# recortamos filas
rows = rows[:max_rows]
df = pd.DataFrame(rows, columns=header)
df.columns = [str(c) for c in df.columns]
return df
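# Illustrative usage of the loader above (not executed at import time; an assumption about
# how one might sanity-check it from a REPL or a separate script, using the default env values):
#   df = get_table_from_wikisql("validation", 10, 12)
#   print(df.columns.tolist()); print(df.head())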
# ------------------------
# Endpoints
# ------------------------
@app.get("/api/health")
def health():
return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX}
@app.get("/api/preview")
def preview():
    try:
        df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
        return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
    except Exception as e:
        return {"error": str(e)}
@app.post("/api/nl2sql")
def nl2sql(q: NLQuery):
    try:
        text = (q.nl_query or "").strip()
        if not text:
            raise ValueError("Empty query.")
        # Translate ES->EN if we detect accents or other non-ASCII characters
        is_ascii = all(ord(c) < 128 for c in text)
        query_en = text if is_ascii else GoogleTranslator(source="auto", target="en").translate(text)
        df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]
        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
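# Minimal local-run sketch (assumptions: the `uvicorn` package is installed and port 7860,
# the usual Hugging Face Spaces port, is free; the Space itself may start the server differently):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example requests once the server is up (illustrative; the query text is made up):
#   curl http://localhost:7860/api/health
#   curl http://localhost:7860/api/preview
#   curl -X POST http://localhost:7860/api/nl2sql \
#        -H "Content-Type: application/json" \
#        -d '{"nl_query": "how many rows have value X?"}'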