AItool committed on
Commit c5de144 · verified · 1 Parent(s): fa414ac

Update app.py

Files changed (1)
  1. app.py +26 -26
app.py CHANGED
@@ -5,14 +5,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, Ma
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 MODEL_OPTIONS = [
-    "Helsinki-NLP (Acceptable)",  # Round-trip OPUS-MT en→es→en
-    "FLAN-T5-base (Google poor results)"
+    "Helsinki-NLP (Tira ondo)",  # Round-trip OPUS-MT eu→es→eu
+    "FLAN-T5-base (Google gaizki xamar)"
 ]
 
 # Cache
 CACHE = {}
 
-# --- FLAN loader ---
+# --- FLAN loader (Google-style Euskera correction) ---
 def load_flan():
     if "flan" not in CACHE:
         tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
@@ -26,33 +26,33 @@ def load_flan():
 
 def run_flan(sentence: str) -> str:
     model, tok = load_flan()
-    prompt = f"Correct grammar and rewrite in fluent British English: {sentence}"
+    prompt = f"Euskara zuzen gramatikalki eta idatzi modu naturalean: {sentence}"
     inputs = tok(prompt, return_tensors="pt").to(DEVICE)
     with torch.no_grad():
         out = model.generate(**inputs, max_new_tokens=96, num_beams=4)
     return tok.decode(out[0], skip_special_tokens=True).strip()
 
-# --- Marian round-trip loader ---
-def load_marian():
-    if "en_es" not in CACHE:
-        tok1 = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
-        mdl1 = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-es").to(DEVICE)
-        tok2 = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")
-        mdl2 = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-es-en").to(DEVICE)
-        CACHE["en_es"] = (mdl1, tok1, mdl2, tok2)
-    return CACHE["en_es"]
+# --- Euskera round-trip loader ---
+def load_euskera():
+    if "eus" not in CACHE:
+        tok1 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-eu-es")
+        mdl1 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-eu-es").to(DEVICE)
+        tok2 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-eu")
+        mdl2 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-es-eu").to(DEVICE)
+        CACHE["eus"] = (mdl1, tok1, mdl2, tok2)
+    return CACHE["eus"]
 
 def run_roundtrip(sentence: str) -> str:
-    mdl1, tok1, mdl2, tok2 = load_marian()
-    # English → Spanish
+    mdl1, tok1, mdl2, tok2 = load_euskera()
+    # Euskera → Spanish
     inputs = tok1(sentence, return_tensors="pt").to(DEVICE)
     es_tokens = mdl1.generate(**inputs, max_length=128, num_beams=4)
     spanish = tok1.decode(es_tokens[0], skip_special_tokens=True)
-    # Spanish → English
+    # Spanish → Euskera
     inputs2 = tok2(spanish, return_tensors="pt").to(DEVICE)
-    en_tokens = mdl2.generate(**inputs2, max_length=128, num_beams=4)
-    english = tok2.decode(en_tokens[0], skip_special_tokens=True)
-    return english.strip()
+    eu_tokens = mdl2.generate(**inputs2, max_length=128, num_beams=4)
+    euskera = tok2.decode(eu_tokens[0], skip_special_tokens=True)
+    return euskera.strip()
 
 # --- Dispatcher ---
 def polish(sentence: str, choice: str) -> str:
@@ -66,8 +66,8 @@ def polish(sentence: str, choice: str) -> str:
     return "Unknown option."
 
 # --- Gradio UI ---
-with gr.Blocks(title="HizkuntzLagun: English Fixer (CPU enabled)") as demo:
-    gr.Markdown("### HizkuntzLagun: English Fixer\n")
+with gr.Blocks(title="HizkuntzLagun: Euskera Fixer (CPU enabled)") as demo:
+    gr.Markdown("### HizkuntzLagun: Euskera Fixer\n")
     gr.Markdown(
         """
         > ⚡ **Note:**
@@ -76,11 +76,11 @@ with gr.Blocks(title="HizkuntzLagun: English Fixer (CPU enabled)") as demo:
         > Expect quick corrections, not deep grammar analysis.
        > Drop in anytime — a quick fix a day keeps awkward grammar away.
         """)
-    inp = gr.Textbox(lines=3, label="Input (English) E.g. She go tomorrow buy two bread.", placeholder="Type an English sentence to correct.")
-    choice = gr.Dropdown(choices=MODEL_OPTIONS, value="Helsinki-NLP", label="Method")
-    btn = gr.Button("Oxford grammar polish")
-    out = gr.Textbox(label="Output")
+    inp = gr.Textbox(lines=3, label="Sarrera (Euskara)", placeholder="Idatzi zuzentzeko esaldi bat...")
+    choice = gr.Dropdown(choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Metodoa")
+    btn = gr.Button("Euskara zuzen")
+    out = gr.Textbox(label="Irteera")
     btn.click(polish, inputs=[inp, choice], outputs=out)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
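For reference, a minimal standalone sketch (not part of the commit) of the eu→es→eu round trip this change introduces. The model names come from the diff above; the translate() helper and the sample sentence are illustrative only, and the CACHE and Gradio wiring are omitted:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def translate(text: str, model_name: str) -> str:
    # Load an OPUS-MT checkpoint and translate one sentence with beam search.
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
    batch = tok(text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out = mdl.generate(**batch, max_length=128, num_beams=4)
    return tok.decode(out[0], skip_special_tokens=True)

sentence = "gaur ni etorri etxera berandu"                   # illustrative, deliberately rough Euskera input
spanish = translate(sentence, "Helsinki-NLP/opus-mt-eu-es")   # Euskera → Spanish
rewritten = translate(spanish, "Helsinki-NLP/opus-mt-es-eu")  # Spanish → Euskera
print(rewritten)

Round-tripping through Spanish relies on the translation models to smooth the sentence on the way back, which matches the UI note: quick corrections, not deep grammar analysis.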