from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from functools import lru_cache from transformers import TapexTokenizer, BartForConditionalGeneration from deep_translator import GoogleTranslator from pathlib import Path import os, json, pandas as pd, torch # ------------------------ # Config # ------------------------ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best") SPLIT = os.getenv("TABLE_SPLIT", "validation") # "validation" ~ "dev" INDEX = int(os.getenv("TABLE_INDEX", "10")) MAX_ROWS = int(os.getenv("MAX_ROWS", "12")) # ------------------------ # App # ------------------------ app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True, ) class NLQuery(BaseModel): nl_query: str # ------------------------ # Modelo # ------------------------ tok = TapexTokenizer.from_pretrained(HF_MODEL_ID) model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID) if torch.cuda.is_available(): model = model.to("cuda") # ------------------------ # Utilidades de carga robustas # ------------------------ def _read_json_or_jsonl(p: Path) -> dict: """ Lee un JSON normal (.json) o un JSONL (.jsonl) y devuelve el primer objeto. """ txt = p.read_text(encoding="utf-8").strip() if p.suffix.lower() == ".jsonl": for line in txt.splitlines(): s = line.strip() if s: return json.loads(s) raise ValueError(f"{p} está vacío.") return json.loads(txt) @lru_cache(maxsize=32) def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame: """ 1) Intenta cargar ./data/.json o ./data/.jsonl (mapeando 'validation'->'dev'). 2) Si no existe, hace fallback a un ejemplo de WikiSQL (conversión Parquet oficial). """ base_dir = Path(__file__).parent data_dir = base_dir / "data" # Normalizamos nombre local (para demo usamos 'dev') local_name = "dev" if split.lower() in ("validation", "dev") else split.lower() # 1) Buscar archivo local for candidate in (data_dir / f"{local_name}.json", data_dir / f"{local_name}.jsonl"): if candidate.exists(): js = _read_json_or_jsonl(candidate) header = [str(h) for h in js["header"]] rows = js["rows"][:max_rows] df = pd.DataFrame(rows, columns=header) df.columns = [str(c) for c in df.columns] return df # 2) Fallback: cargar un ejemplo del dataset WikiSQL (Parquet convertido) try: from datasets import load_dataset # import diferido para arrancar más rápido ds = load_dataset("Salesforce/wikisql", split="validation", revision="refs/convert/parquet") if not (0 <= index < len(ds)): index = 0 # seguridad ex = ds[index] header = [str(h) for h in ex["table"]["header"]] rows = ex["table"]["rows"][:max_rows] df = pd.DataFrame(rows, columns=header) df.columns = [str(c) for c in df.columns] return df except Exception as e: raise RuntimeError(f"No se pudo obtener una tabla: {e}") # ------------------------ # Endpoints # ------------------------ @app.get("/api/health") def health(): return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX} @app.get("/api/preview") def preview(): try: df = get_table(SPLIT, INDEX, MAX_ROWS) return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")} except Exception as e: return {"error": str(e)} @app.post("/api/nl2sql") def nl2sql(q: NLQuery): try: text = (q.nl_query or "").strip() if not text: raise ValueError("Consulta vacía.") # Detectar si parece SQL lower = text.lower().strip() looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter")) # Traducir a inglés si no es SQL query_en = text if not looks_like_sql: try: translated = GoogleTranslator(source="auto", target="en").translate(text) if translated: query_en = translated except Exception: query_en = text # fallback seguro # Procesar con TAPEX df = get_table(SPLIT, INDEX, MAX_ROWS) enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True) if torch.cuda.is_available(): enc = {k: v.to("cuda") for k, v in enc.items()} out = model.generate(**enc, max_length=160, num_beams=1) sql = tok.batch_decode(out, skip_special_tokens=True)[0] return { "consulta_original": text, "consulta_traducida": query_en, "sql_generado": sql } except Exception as e: raise HTTPException(status_code=500, detail=str(e))