from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from functools import lru_cache
from huggingface_hub import hf_hub_download
from transformers import TapexTokenizer, BartForConditionalGeneration
from deep_translator import GoogleTranslator
import os, json, pandas as pd, torch

# ------------------------
# Config
# ------------------------
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
WIKISQL_REPO = os.getenv("WIKISQL_REPO", "Salesforce/wikisql")  # official dataset repo
SPLIT = os.getenv("TABLE_SPLIT", "validation")                  # "validation" == "dev" in WikiSQL
INDEX = int(os.getenv("TABLE_INDEX", "10"))
MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))
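
# Example (hypothetical values): the served table can be switched without touching the code,
# e.g. by exporting TABLE_SPLIT=test TABLE_INDEX=3 MAX_ROWS=20 before starting the server.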

# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True,
)

class NLQuery(BaseModel):
    nl_query: str
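
# Example request body for POST /api/nl2sql (illustrative question, not taken from the dataset):
#   {"nl_query": "¿Cuántos años tiene el jugador más joven?"}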

# ------------------------
# Model
# ------------------------
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")

# ------------------------
# Util: load WikiSQL (JSONL)
# ------------------------
def _read_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

def _download_file(filename: str) -> str:
    # download a file from the Hugging Face dataset repo
    return hf_hub_download(repo_id=WIKISQL_REPO, filename=filename, repo_type="dataset")
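
# Note: hf_hub_download caches files locally, so repeated requests for the same filename
# do not re-download it from the Hub.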

@lru_cache(maxsize=32)
def get_table_from_wikisql(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """
    Carga la tabla de WikiSQL sin scripts, usando directamente los JSONL del repo:
      - dev.jsonl            (validation = 'dev' en terminología original)
      - dev.tables.jsonl
    Si cambias split a 'train' o 'test', intenta los nombres equivalentes.
    """
    # Mapeo simple: validation->dev, train->train, test->test
    split_map = {"validation": "dev", "dev": "dev", "train": "train", "test": "test"}
    base = split_map.get(split.lower(), "dev")

    # Possible file names in the repo (some mirrors use variants)
    qa_candidates = [f"data/{base}.jsonl", f"data/{base}.json", f"{base}.jsonl"]
    tbl_candidates = [f"data/{base}.tables.jsonl", f"{base}.tables.jsonl"]

    qa_path = None
    tbl_path = None

    # Download the QA file
    for cand in qa_candidates:
        try:
            qa_path = _download_file(cand)
            break
        except Exception:
            continue
    if qa_path is None:
        raise RuntimeError(f"No se encontró el archivo QA para split={split}. Intentos: {qa_candidates}")

    # Download the tables file
    for cand in tbl_candidates:
        try:
            tbl_path = _download_file(cand)
            break
        except Exception:
            continue
    if tbl_path is None:
        raise RuntimeError(f"No se encontró el archivo de tablas para split={split}. Intentos: {tbl_candidates}")

    # Read question N (to take its table_id); if you don't need the question itself, you can skip this
    qa_list = list(_read_jsonl(qa_path))
    if not (0 <= index < len(qa_list)):
        raise IndexError(f"index={index} out of range (0..{len(qa_list)-1}) for split={split}")
    table_id = qa_list[index].get("table_id") or qa_list[index].get("table", {}).get("id")
    if table_id is None:
        raise RuntimeError("No se pudo extraer 'table_id' del registro de QA.")

    # Look up that table in the .tables.jsonl file
    header, rows = None, None
    for obj in _read_jsonl(tbl_path):
        if obj.get("id") == table_id:
            header = [str(h) for h in obj["header"]]
            rows = obj["rows"]
            break
    if header is None or rows is None:
        raise RuntimeError(f"No se encontró la tabla con id={table_id} en {os.path.basename(tbl_path)}")

    # trim rows to max_rows
    rows = rows[:max_rows]
    df = pd.DataFrame(rows, columns=header)
    df.columns = [str(c) for c in df.columns]
    return df
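
# Note: thanks to @lru_cache, repeated calls with the same (split, index, max_rows) return the
# cached DataFrame, so /api/preview and /api/nl2sql do not rebuild the table on every request.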

# ------------------------
# Endpoints
# ------------------------
@app.get("/api/health")
def health():
    return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX}

@app.get("/api/preview")
def preview():
    try:
        df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
        return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
    except Exception as e:
        return {"error": str(e)}

@app.post("/api/nl2sql")
def nl2sql(q: NLQuery):
    try:
        text = (q.nl_query or "").strip()
        if not text:
            raise ValueError("Empty query.")

        # ES->EN translation if we detect accents or other non-ASCII characters
        is_ascii = all(ord(c) < 128 for c in text)
        query_en = text if is_ascii else GoogleTranslator(source="auto", target="en").translate(text)

        df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]

        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
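
# ------------------------
# Local usage (sketch)
# ------------------------
# Assuming this module is saved as main.py (hypothetical name), the API could be started with:
#   uvicorn main:app --host 0.0.0.0 --port 8000
# and exercised with, for example:
#   curl http://localhost:8000/api/health
#   curl http://localhost:8000/api/preview
#   curl -X POST http://localhost:8000/api/nl2sql \
#        -H "Content-Type: application/json" \
#        -d '{"nl_query": "how many rows are there?"}'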