Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from functools import lru_cache | |
| from transformers import TapexTokenizer, BartForConditionalGeneration | |
| from deep_translator import GoogleTranslator | |
| from pathlib import Path | |
| import os, json, pandas as pd, torch | |
# ------------------------
# Config
# ------------------------
# All knobs are environment-overridable; defaults target the demo Space.
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
SPLIT = os.environ.get("TABLE_SPLIT", "validation")  # "validation" ~ "dev"
INDEX = int(os.environ.get("TABLE_INDEX", "10"))
MAX_ROWS = int(os.environ.get("MAX_ROWS", "12"))

# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")

# Fully permissive CORS: this is a public demo API consumed from browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class NLQuery(BaseModel):
    """Request body for the NL→SQL endpoint: a single natural-language query."""

    # Free-form question (any language); may also already be SQL.
    nl_query: str
# ------------------------
# Model
# ------------------------
# Loaded once at import time so every request reuses the same weights.
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    # Move to GPU when available; inputs are moved per-request in nl2sql.
    model = model.to("cuda")
| # ------------------------ | |
| # Utilidades de carga robustas | |
| # ------------------------ | |
| def _read_json_or_jsonl(p: Path) -> dict: | |
| """ | |
| Lee un JSON normal (.json) o un JSONL (.jsonl) y devuelve el primer objeto. | |
| """ | |
| txt = p.read_text(encoding="utf-8").strip() | |
| if p.suffix.lower() == ".jsonl": | |
| for line in txt.splitlines(): | |
| s = line.strip() | |
| if s: | |
| return json.loads(s) | |
| raise ValueError(f"{p} está vacío.") | |
| return json.loads(txt) | |
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """Return the demo table as a DataFrame with string column names.

    Lookup order:
      1. ``./data/<split>.json`` or ``./data/<split>.jsonl`` next to this
         file (``validation`` maps to the local name ``dev``).
      2. Fallback: one example table from WikiSQL (official Parquet
         conversion) fetched through ``datasets``.

    Raises:
        RuntimeError: when no local file exists and the fallback fails.
    """
    data_dir = Path(__file__).parent / "data"
    # For the demo, the validation split is stored locally as "dev".
    stem = "dev" if split.lower() in ("validation", "dev") else split.lower()

    # 1) Local file, .json preferred over .jsonl.
    for suffix in (".json", ".jsonl"):
        candidate = data_dir / f"{stem}{suffix}"
        if not candidate.exists():
            continue
        payload = _read_json_or_jsonl(candidate)
        frame = pd.DataFrame(
            payload["rows"][:max_rows],
            columns=[str(h) for h in payload["header"]],
        )
        frame.columns = [str(c) for c in frame.columns]
        return frame

    # 2) Fallback: a single WikiSQL example.
    try:
        # Deferred import keeps startup fast when a local file is present.
        from datasets import load_dataset

        ds = load_dataset(
            "Salesforce/wikisql", split="validation", revision="refs/convert/parquet"
        )
        if not (0 <= index < len(ds)):
            index = 0  # clamp out-of-range indices to a safe default
        table = ds[index]["table"]
        frame = pd.DataFrame(
            table["rows"][:max_rows],
            columns=[str(h) for h in table["header"]],
        )
        frame.columns = [str(c) for c in frame.columns]
        return frame
    except Exception as e:
        raise RuntimeError(f"No se pudo obtener una tabla: {e}")
# ------------------------
# Endpoints
# ------------------------
@app.get("/health")  # NOTE(review): route registration was missing; path assumed — confirm
def health():
    """Liveness/config probe: report the model id and table selection in use."""
    return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX}
@app.get("/preview")  # NOTE(review): route registration was missing; path assumed — confirm
def preview():
    """Return the demo table's column names and up to 8 rows.

    Failures are reported in-band as ``{"error": ...}`` (not an HTTP error),
    preserving the original best-effort contract.
    """
    try:
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
    except Exception as e:
        return {"error": str(e)}
@app.post("/nl2sql")  # NOTE(review): route registration was missing; path assumed — confirm
def nl2sql(q: NLQuery):
    """Translate a natural-language question into SQL over the demo table.

    Non-SQL input is translated to English (best effort) before being
    encoded together with the table and decoded by the TAPEX model.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        text = (q.nl_query or "").strip()
        if not text:
            raise ValueError("Consulta vacía.")

        # Heuristic: skip translation when the input already looks like SQL.
        # (text is already stripped; no need to strip again before lowering.)
        lower = text.lower()
        looks_like_sql = lower.startswith(
            ("select", "with", "insert", "update", "delete", "create", "drop", "alter")
        )

        # Best-effort translation to English for non-SQL input.
        query_en = text
        if not looks_like_sql:
            try:
                translated = GoogleTranslator(source="auto", target="en").translate(text)
                if translated:
                    query_en = translated
            except Exception:
                query_en = text  # safe fallback: proceed with the original text

        # Encode the (table, question) pair and generate with TAPEX.
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            # Inputs must live on the same device as the model.
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]
        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))