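"""HizkuntzLagun: a small Gradio app that polishes Basque (Euskera) sentences.

Two CPU-friendly methods are offered:
  * Helsinki-NLP OPUS-MT round trip (eu→es→eu)
  * google/flan-t5-base with a Basque instruction prompt
"""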
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_OPTIONS = [
    "Helsinki-NLP (Tira ondo)",           # round-trip OPUS-MT eu→es→eu
    "FLAN-T5-base (Google gaizki xamar)"  # instruction-prompted FLAN-T5
]

# Lazily loaded models, keyed by method
CACHE = {}

# --- FLAN loader (Google-style Euskera correction) ---
def load_flan():
    if "flan" not in CACHE:
        tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
        mdl = AutoModelForSeq2SeqLM.from_pretrained(
            "google/flan-t5-base",
            low_cpu_mem_usage=True,
            torch_dtype="auto"
        ).to(DEVICE)
        CACHE["flan"] = (mdl, tok)
    return CACHE["flan"]
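
# Ask FLAN-T5 to rewrite the sentence via a Basque instruction prompt.
# Note: FLAN-T5 has likely seen little Basque training data, so results
# are rough (hence the "gaizki xamar" label in MODEL_OPTIONS).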
def run_flan(sentence: str) -> str:
    model, tok = load_flan()
    prompt = f"Euskara zuzen gramatikalki eta idatzi modu naturalean: {sentence}"
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=96, num_beams=4)
    return tok.decode(out[0], skip_special_tokens=True).strip()

# --- Euskera round-trip loader ---
def load_euskera():
    if "eus" not in CACHE:
        tok1 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-eu-es")
        mdl1 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-eu-es").to(DEVICE)
        tok2 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-eu")
        mdl2 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-es-eu").to(DEVICE)
        CACHE["eus"] = (mdl1, tok1, mdl2, tok2)
    return CACHE["eus"]
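
# Round-trip "correction": translate eu→es, then es→eu. The MT models tend to
# emit standard, grammatical Basque, so the round trip acts as a light polish.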
def run_roundtrip(sentence: str) -> str:
    mdl1, tok1, mdl2, tok2 = load_euskera()
    with torch.no_grad():
        # Euskera → Spanish
        inputs = tok1(sentence, return_tensors="pt").to(DEVICE)
        es_tokens = mdl1.generate(**inputs, max_length=128, num_beams=4)
        spanish = tok1.decode(es_tokens[0], skip_special_tokens=True)
        # Spanish → Euskera
        inputs2 = tok2(spanish, return_tensors="pt").to(DEVICE)
        eu_tokens = mdl2.generate(**inputs2, max_length=128, num_beams=4)
        euskera = tok2.decode(eu_tokens[0], skip_special_tokens=True)
    return euskera.strip()

# --- Dispatcher ---
def polish(sentence: str, choice: str) -> str:
    if not sentence.strip():
        return ""
    if choice.startswith("FLAN"):
        return run_flan(sentence)
    elif choice.startswith("Helsinki"):
        return run_roundtrip(sentence)
    else:
        return "Unknown option."

# --- Gradio UI ---
with gr.Blocks(title="HizkuntzLagun: AI Euskera Zuzendu (CPU enabled)") as demo:
    gr.Markdown("### HizkuntzLagun: AI Euskera Zuzendu\n")
    gr.Markdown(
        """
        > ⚡ **Oharra:**
        > Tresna honek doako, CPU‑lagunko AI ereduak erabiltzen ditu.
        > Azkarra eta eskuragarria izateko diseinatuta dago — ez beti perfektua.
        > Zuzenketa azkarrak bai, ez analisi gramatikal sakonak.
        > Edozein unetan erabil dezakezu — eguneroko zuzenketa txiki batek saihesten du esaldi traketsen lotsa.
        """)
    inp = gr.Textbox(lines=3, label="Idatzi Euskeraz esaldi bat, adibidez Gaur Kondo ikusi nuen.", placeholder="Idatzi esaldi bat...")
    choice = gr.Dropdown(choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Metodoa")
    btn = gr.Button("Euskera zuzendu")
    out = gr.Textbox(label="Zuzenketa")
    btn.click(polish, inputs=[inp, choice], outputs=out)

if __name__ == "__main__":
    demo.launch()