Update app.py
Browse files
app.py
CHANGED
|
@@ -5,14 +5,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, Ma
|
|
| 5 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 6 |
|
| 7 |
MODEL_OPTIONS = [
|
| 8 |
-
"Helsinki-NLP (
|
| 9 |
-
"FLAN-T5-base (Google
|
| 10 |
]
|
| 11 |
|
| 12 |
# Cache
|
| 13 |
CACHE = {}
|
| 14 |
|
| 15 |
-
# --- FLAN loader ---
|
| 16 |
def load_flan():
|
| 17 |
if "flan" not in CACHE:
|
| 18 |
tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
|
@@ -26,33 +26,33 @@ def load_flan():
|
|
| 26 |
|
| 27 |
def run_flan(sentence: str) -> str:
|
| 28 |
model, tok = load_flan()
|
| 29 |
-
prompt = f"
|
| 30 |
inputs = tok(prompt, return_tensors="pt").to(DEVICE)
|
| 31 |
with torch.no_grad():
|
| 32 |
out = model.generate(**inputs, max_new_tokens=96, num_beams=4)
|
| 33 |
return tok.decode(out[0], skip_special_tokens=True).strip()
|
| 34 |
|
| 35 |
-
# ---
|
| 36 |
-
def
|
| 37 |
-
if "
|
| 38 |
-
tok1 =
|
| 39 |
-
mdl1 =
|
| 40 |
-
tok2 =
|
| 41 |
-
mdl2 =
|
| 42 |
-
CACHE["
|
| 43 |
-
return CACHE["
|
| 44 |
|
| 45 |
def run_roundtrip(sentence: str) -> str:
|
| 46 |
-
mdl1, tok1, mdl2, tok2 =
|
| 47 |
-
#
|
| 48 |
inputs = tok1(sentence, return_tensors="pt").to(DEVICE)
|
| 49 |
es_tokens = mdl1.generate(**inputs, max_length=128, num_beams=4)
|
| 50 |
spanish = tok1.decode(es_tokens[0], skip_special_tokens=True)
|
| 51 |
-
# Spanish →
|
| 52 |
inputs2 = tok2(spanish, return_tensors="pt").to(DEVICE)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
return
|
| 56 |
|
| 57 |
# --- Dispatcher ---
|
| 58 |
def polish(sentence: str, choice: str) -> str:
|
|
@@ -66,8 +66,8 @@ def polish(sentence: str, choice: str) -> str:
|
|
| 66 |
return "Unknown option."
|
| 67 |
|
| 68 |
# --- Gradio UI ---
|
| 69 |
-
with gr.Blocks(title="HizkuntzLagun:
|
| 70 |
-
gr.Markdown("### HizkuntzLagun:
|
| 71 |
gr.Markdown(
|
| 72 |
"""
|
| 73 |
> ⚡ **Note:**
|
|
@@ -76,11 +76,11 @@ with gr.Blocks(title="HizkuntzLagun: English Fixer (CPU enabled)") as demo:
|
|
| 76 |
> Expect quick corrections, not deep grammar analysis.
|
| 77 |
> Drop in anytime — a quick fix a day keeps awkward grammar away.
|
| 78 |
""")
|
| 79 |
-
inp = gr.Textbox(lines=3, label="
|
| 80 |
-
choice = gr.Dropdown(choices=MODEL_OPTIONS, value="
|
| 81 |
-
btn = gr.Button("
|
| 82 |
-
out = gr.Textbox(label="
|
| 83 |
btn.click(polish, inputs=[inp, choice], outputs=out)
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
-
demo.launch()
|
|
|
|
| 5 |
# Run on GPU when available; every model and tensor below is moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Labels shown in the Gradio dropdown; the dispatcher (polish) switches on these.
MODEL_OPTIONS = [
    "Helsinki-NLP (Tira ondo)",  # Round-trip OPUS-MT eu→es→eu (see load_euskera)
    "FLAN-T5-base (Google gaizki xamar)"
]

# Cache of loaded model/tokenizer pipelines so each is built only once per process.
CACHE = {}
|
| 14 |
|
| 15 |
+
# --- FLAN loader (Google-style Euskera correction) ---
|
| 16 |
def load_flan():
|
| 17 |
if "flan" not in CACHE:
|
| 18 |
tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
|
|
|
|
| 26 |
|
| 27 |
def run_flan(sentence: str) -> str:
    """Correct a Basque sentence with FLAN-T5 via an instruction prompt.

    Returns the model's corrected text with surrounding whitespace removed.
    """
    flan_model, flan_tok = load_flan()
    # Instruction prompt (Basque): "correct the grammar and write naturally".
    task = f"Euskara zuzen gramatikalki eta idatzi modu naturalean: {sentence}"
    encoded = flan_tok(task, return_tensors="pt").to(DEVICE)
    with torch.no_grad():  # inference only — no autograd graph needed
        generated = flan_model.generate(**encoded, max_new_tokens=96, num_beams=4)
    decoded = flan_tok.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()
|
| 34 |
|
| 35 |
+
# --- Euskera round-trip loader ---
|
| 36 |
+
def load_euskera():
    """Lazily build and cache the Basque round-trip translation pipelines.

    Returns the cached tuple (mdl1, tok1, mdl2, tok2): the first pair
    translates eu→es, the second es→eu. Built at most once per process.
    """
    if "eus" in CACHE:
        return CACHE["eus"]
    eu_es = "Helsinki-NLP/opus-mt-eu-es"
    es_eu = "Helsinki-NLP/opus-mt-es-eu"
    tok1 = AutoTokenizer.from_pretrained(eu_es)
    mdl1 = AutoModelForSeq2SeqLM.from_pretrained(eu_es).to(DEVICE)
    tok2 = AutoTokenizer.from_pretrained(es_eu)
    mdl2 = AutoModelForSeq2SeqLM.from_pretrained(es_eu).to(DEVICE)
    CACHE["eus"] = (mdl1, tok1, mdl2, tok2)
    return CACHE["eus"]
|
| 44 |
|
| 45 |
def run_roundtrip(sentence: str) -> str:
    """Normalize a Basque sentence by round-trip translation (eu → es → eu).

    Translating to Spanish and back through the OPUS-MT models tends to
    smooth out grammatical errors while preserving meaning. Returns the
    round-tripped Basque text, stripped of surrounding whitespace.
    """
    mdl1, tok1, mdl2, tok2 = load_euskera()
    # Inference only: skip gradient tracking (consistent with run_flan).
    with torch.no_grad():
        # Euskera → Spanish
        spanish = _translate(mdl1, tok1, sentence)
        # Spanish → Euskera
        euskera = _translate(mdl2, tok2, spanish)
    return euskera.strip()


def _translate(model, tok, text: str) -> str:
    """Run one beam-search translation step with a model/tokenizer pair."""
    inputs = tok(text, return_tensors="pt").to(DEVICE)
    tokens = model.generate(**inputs, max_length=128, num_beams=4)
    return tok.decode(tokens[0], skip_special_tokens=True)
|
| 56 |
|
| 57 |
# --- Dispatcher ---
|
| 58 |
def polish(sentence: str, choice: str) -> str:
|
|
|
|
| 66 |
return "Unknown option."
|
| 67 |
|
| 68 |
# --- Gradio UI ---
|
| 69 |
+
with gr.Blocks(title="HizkuntzLagun: Euskera Fixer (CPU enabled)") as demo:
|
| 70 |
+
gr.Markdown("### HizkuntzLagun: Euskera Fixer\n")
|
| 71 |
gr.Markdown(
|
| 72 |
"""
|
| 73 |
> ⚡ **Note:**
|
|
|
|
| 76 |
> Expect quick corrections, not deep grammar analysis.
|
| 77 |
> Drop in anytime — a quick fix a day keeps awkward grammar away.
|
| 78 |
""")
|
| 79 |
+
inp = gr.Textbox(lines=3, label="Sarrera (Euskara)", placeholder="Idatzi zuzentzeko esaldi bat...")
|
| 80 |
+
choice = gr.Dropdown(choices=MODEL_OPTIONS, value="eu_sp_eu", label="Metodoa")
|
| 81 |
+
btn = gr.Button("Euskara zuzen")
|
| 82 |
+
out = gr.Textbox(label="Irteera")
|
| 83 |
btn.click(polish, inputs=[inp, choice], outputs=out)
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
+
demo.launch()
|