stvnnnnnn committed on
Commit
619ce97
·
verified ·
1 Parent(s): c1c647f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -83
app.py CHANGED
@@ -1,8 +1,8 @@
1
- # app.py — NL→SQL (TAPEX + WikiSQL) backend for HF Spaces
2
 
3
  from fastapi import FastAPI, HTTPException
4
- from fastapi.responses import HTMLResponse, JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
 
6
  from pydantic import BaseModel
7
 
8
  import os
@@ -14,151 +14,82 @@ from datasets import load_dataset
14
  from deep_translator import GoogleTranslator
15
  from transformers import TapexTokenizer, BartForConditionalGeneration
16
 
17
-
18
- # --------- Configuración y defaults ----------
19
  HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
20
  TABLE_SPLIT = os.getenv("TABLE_SPLIT", "validation")
21
  TABLE_INDEX = int(os.getenv("TABLE_INDEX", "10"))
22
- MAX_ROWS = int(os.getenv("MAX_ROWS", "100")) # límite prudente en CPU
23
 
24
- # Asegura caché escribible en Space
25
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
26
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
27
 
28
-
29
- # --------- App & CORS ----------
30
- app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (HF Space)")
31
  app.add_middleware(
32
  CORSMiddleware,
33
- allow_origins=["*"], # cambia a tu dominio de Vercel cuando lo tengas
34
  allow_credentials=True,
35
  allow_methods=["*"],
36
  allow_headers=["*"],
37
  )
38
 
39
-
40
- # --------- Carga perezosa (lazy) ---------
41
  @lru_cache(maxsize=1)
42
  def get_model_and_tokenizer():
43
  tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
44
- # En Spaces free es CPU; no uses device_map="auto" para evitar dependencia de accelerate
45
  model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
46
  model.eval()
47
  return tok, model
48
 
49
-
50
  @lru_cache(maxsize=32)
51
  def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
52
- """
53
- Carga una tabla de WikiSQL y devuelve un DataFrame.
54
- Evitamos el revision parquet (que a veces falla en Spaces) y usamos el split normal.
55
- """
56
- ds = load_dataset("Salesforce/wikisql", split=split)
57
  if index < 0 or index >= len(ds):
58
  raise IndexError(f"TABLE_INDEX fuera de rango (0..{len(ds)-1}).")
59
-
60
  ex = ds[index]
61
  header = [str(h) for h in ex["table"]["header"]]
62
  rows = ex["table"]["rows"][:max_rows]
63
  df = pd.DataFrame(rows, columns=header)
64
- # Normaliza nombres de columnas a string
65
  df.columns = [str(c) for c in df.columns]
66
  return df
67
 
68
-
69
- # --------- Esquema de petición ----------
70
  class NLQuery(BaseModel):
71
  nl_query: str
72
 
73
-
74
- # --------- Mini UI en / ----------
75
- INDEX_HTML = """
76
- <!doctype html>
77
- <html lang="es">
78
- <meta charset="utf-8">
79
- <title>NL→SQL (TAPEX + WikiSQL)</title>
80
- <style>
81
- body{font-family:system-ui,-apple-system,Segoe UI,Roboto,Ubuntu,sans-serif;max-width:860px;margin:30px auto;padding:0 16px;color:#eaeaea;background:#0f1115}
82
- h1{font-size:1.6rem;margin:0 0 8px}
83
- .card{background:#171923;border:1px solid #232736;border-radius:12px;padding:16px;margin:18px 0}
84
- input,button,select{font-size:1rem}
85
- input{width:100%;padding:10px 12px;border-radius:8px;border:1px solid #2a3042;background:#0f1115;color:#eaeaea}
86
- button{padding:10px 16px;border-radius:8px;border:1px solid #2a3042;background:#1f2433;color:#eaeaea;cursor:pointer}
87
- pre{white-space:pre-wrap;background:#0f1115;border:1px solid #2a3042;padding:12px;border-radius:8px}
88
- .row{display:flex;gap:12px;align-items:center}
89
- </style>
90
- <h1>🧠 NL → SQL (TAPEX + WikiSQL)</h1>
91
- <div class="card">
92
- <p><b>Backend:</b> este Space ofrece endpoints REST; prueba una consulta:</p>
93
- <div class="row">
94
- <input id="q" placeholder="Ej.: Muestra los jugadores que son Guards." />
95
- <button onclick="run()">Generar SQL</button>
96
- <button onclick="prev()">Ver preview tabla</button>
97
- </div>
98
- <p style="font-size:.9rem;opacity:.8">Swagger: <a href="./docs" target="_blank">/docs</a> · Salud: <a href="./api/health" target="_blank">/api/health</a></p>
99
- <pre id="out">Listo para generar...</pre>
100
- </div>
101
- <script>
102
- async function run(){
103
- const q = document.getElementById('q').value.trim();
104
- const r = await fetch('./api/nl2sql', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify({nl_query: q}) });
105
- document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
106
- }
107
- async function prev(){
108
- const r = await fetch('./api/preview');
109
- document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
110
- }
111
- </script>
112
- </html>
113
- """
114
-
115
- @app.get("/", response_class=HTMLResponse)
116
- def home():
117
- return INDEX_HTML
118
-
119
-
120
- # --------- Rutas API ----------
121
  @app.get("/api/health")
122
  def health():
123
  return {"ok": True, "model": HF_MODEL_ID, "split": TABLE_SPLIT, "index": TABLE_INDEX}
124
 
125
-
126
  @app.get("/api/preview")
127
  def preview():
128
  try:
129
  df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
130
- # Regresa primeras 8 filas para no saturar
131
- data = df.head(8).to_dict(orient="records")
132
- return {"columns": list(df.columns), "rows": data}
133
  except Exception as e:
134
- # Envía mensaje simple (para no saturar logs)
135
  return JSONResponse(status_code=500, content={"error": str(e)})
136
 
137
-
138
  @app.post("/api/nl2sql")
139
  def nl2sql(q: NLQuery):
140
  nl = (q.nl_query or "").strip()
141
  if not nl:
142
  raise HTTPException(status_code=400, detail="Consulta vacía.")
143
 
144
- # Traducción ES→EN si detectamos caracteres no ASCII
145
  try:
146
  is_ascii = all(ord(c) < 128 for c in nl)
147
  nl_en = nl if is_ascii else GoogleTranslator(source="auto", target="en").translate(nl)
148
  except Exception:
149
- # Si la traducción falla, seguimos con el texto original
150
  nl_en = nl
151
 
152
  try:
153
  df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
154
  tok, model = get_model_and_tokenizer()
155
-
156
- # Tokenización limitada
157
  enc = tok(table=df, query=nl_en, return_tensors="pt", truncation=True, max_length=512)
158
-
159
  with torch.inference_mode():
160
  out = model.generate(**enc, max_length=160, num_beams=1)
161
-
162
  sql = tok.batch_decode(out, skip_special_tokens=True)[0]
163
  return {"consulta_original": nl, "consulta_traducida": nl_en, "sql_generado": sql}
164
  except Exception as e:
 
1
+ # app.py — NL→SQL (TAPEX + WikiSQL) backend (solo API)
2
 
3
  from fastapi import FastAPI, HTTPException
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
+ from fastapi.responses import JSONResponse
6
  from pydantic import BaseModel
7
 
8
  import os
 
14
  from deep_translator import GoogleTranslator
15
  from transformers import TapexTokenizer, BartForConditionalGeneration
16
 
17
+ # ------------ Config ------------
 
18
  HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
19
  TABLE_SPLIT = os.getenv("TABLE_SPLIT", "validation")
20
  TABLE_INDEX = int(os.getenv("TABLE_INDEX", "10"))
21
+ MAX_ROWS = int(os.getenv("MAX_ROWS", "100"))
22
 
23
+ # Caché HF escribible en Spaces
24
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
25
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
26
 
27
+ # ------------ App & CORS ------------
28
+ app = FastAPI(title="NL→SQL TAPEX + WikiSQL (API)")
 
29
  app.add_middleware(
30
  CORSMiddleware,
31
+ allow_origins=["*"], # cámbialo al dominio de Vercel cuando lo tengas
32
  allow_credentials=True,
33
  allow_methods=["*"],
34
  allow_headers=["*"],
35
  )
36
 
37
+ # ------------ Carga perezosa ------------
 
38
@lru_cache(maxsize=1)
def get_model_and_tokenizer():
    """Load the TAPEX tokenizer/model pair once and memoize it.

    Loaded lazily on first request; runs on CPU (no device_map, so no
    `accelerate` dependency is pulled in on free Spaces hardware).
    """
    tokenizer = TapexTokenizer.from_pretrained(HF_MODEL_ID)
    generator = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
    generator.eval()  # inference only — disable dropout/training behavior
    return tokenizer, generator
44
 
 
45
@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """Load one WikiSQL example table as a pandas DataFrame.

    Uses the plain dataset split (avoids the parquet revision, which
    sometimes fails on Spaces) and caps the row count at `max_rows`.

    Raises:
        IndexError: when `index` is outside the split's range.

    NOTE(review): lru_cache hands the *same* DataFrame object to every
    caller — treat the result as read-only.
    """
    ds = load_dataset("Salesforce/wikisql", split=split)
    if not 0 <= index < len(ds):
        raise IndexError(f"TABLE_INDEX fuera de rango (0..{len(ds)-1}).")
    example = ds[index]
    # Force header cells to str up front; the DataFrame columns inherit
    # these, so no second normalization pass is needed afterwards.
    header = [str(h) for h in example["table"]["header"]]
    rows = example["table"]["rows"][:max_rows]
    return pd.DataFrame(rows, columns=header)
56
 
57
+ # ------------ Schemas ------------
 
58
class NLQuery(BaseModel):
    """Request body for POST /api/nl2sql."""

    # Free-form natural-language question (may be Spanish; translated upstream).
    nl_query: str
60
 
61
+ # ------------ Endpoints ------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
@app.get("/api/health")
def health():
    """Liveness probe: echoes the active model/config so deploys are verifiable."""
    return {
        "ok": True,
        "model": HF_MODEL_ID,
        "split": TABLE_SPLIT,
        "index": TABLE_INDEX,
    }
65
 
 
66
  @app.get("/api/preview")
67
  def preview():
68
  try:
69
  df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
70
+ return {"columns": list(df.columns), "rows": df.head(8).to_dict(orient="records")}
 
 
71
  except Exception as e:
 
72
  return JSONResponse(status_code=500, content={"error": str(e)})
73
 
 
74
  @app.post("/api/nl2sql")
75
  def nl2sql(q: NLQuery):
76
  nl = (q.nl_query or "").strip()
77
  if not nl:
78
  raise HTTPException(status_code=400, detail="Consulta vacía.")
79
 
80
+ # Traducción ES→EN si no-ASCII
81
  try:
82
  is_ascii = all(ord(c) < 128 for c in nl)
83
  nl_en = nl if is_ascii else GoogleTranslator(source="auto", target="en").translate(nl)
84
  except Exception:
 
85
  nl_en = nl
86
 
87
  try:
88
  df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
89
  tok, model = get_model_and_tokenizer()
 
 
90
  enc = tok(table=df, query=nl_en, return_tensors="pt", truncation=True, max_length=512)
 
91
  with torch.inference_mode():
92
  out = model.generate(**enc, max_length=160, num_beams=1)
 
93
  sql = tok.batch_decode(out, skip_special_tokens=True)[0]
94
  return {"consulta_original": nl, "consulta_traducida": nl_en, "sql_generado": sql}
95
  except Exception as e: