AItool committed on
Commit
e179439
·
verified ·
1 Parent(s): f6944de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -30
app.py CHANGED
@@ -1,34 +1,81 @@
 
 
 
 
1
  import gradio as gr
2
- from transformers import MarianMTModel, MarianTokenizer
3
-
4
- # Load English→Spanish
5
- en_es_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-es")
6
- en_es_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
7
-
8
- # Load Spanish→English
9
- es_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-es-en")
10
- es_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")
11
-
12
- def polish(sentence: str) -> str:
13
- # Step 1: English → Spanish
14
- inputs = en_es_tokenizer(sentence, return_tensors="pt", padding=True)
15
- translated = en_es_model.generate(**inputs, max_length=128)
16
- spanish = en_es_tokenizer.decode(translated[0], skip_special_tokens=True)
17
-
18
- # Step 2: Spanish → English
19
- inputs2 = es_en_tokenizer(spanish, return_tensors="pt", padding=True)
20
- back_translated = es_en_model.generate(**inputs2, max_length=128)
21
- english = es_en_tokenizer.decode(back_translated[0], skip_special_tokens=True)
22
-
23
- return english
24
-
25
- demo = gr.Interface(
26
- fn=polish,
27
- inputs=gr.Textbox(lines=2, placeholder="Enter a sentence in English..."),
28
- outputs=gr.Textbox(label="Corrected English"),
29
- title="Round-trip Grammar Polisher",
30
- description="Uses Helsinki-NLP MarianMT models (en→es→en) to smooth and correct English sentences."
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  if __name__ == "__main__":
34
  demo.launch()
 
1
+ # app.py
2
+ # Requirements: transformers, torch, sentencepiece, sacremoses, gradio
3
+
4
+ import torch
5
  import gradio as gr
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
7
+
8
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Labels shown in the "Method" dropdown. The dispatcher selects a backend by
# label prefix ("FLAN" / "Round-trip"), so keep those prefixes intact.
MODEL_OPTIONS = [
    "FLAN-T5-base (Google en→en)",
    "Round-trip OPUS-MT en→es→en (Helsinki-NLP)",
]

# Lazily populated store of loaded (model, tokenizer, ...) tuples, keyed by
# backend name, so each checkpoint is loaded at most once per process.
CACHE = {}
17
+
18
+ # --- FLAN loader ---
19
def load_flan():
    """Return the ``(model, tokenizer)`` pair for ``google/flan-t5-base``.

    The pair is built on the first call, with the model moved to ``DEVICE``,
    and stored in ``CACHE`` under the ``"flan"`` key. Later calls return the
    stored pair without loading anything again.
    """
    cached = CACHE.get("flan")
    if cached is not None:
        return cached

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-base",
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    model = model.to(DEVICE)
    CACHE["flan"] = (model, tokenizer)
    return CACHE["flan"]
29
+
30
def run_flan(sentence: str) -> str:
    """Rewrite *sentence* with FLAN-T5 using a grammar-correction prompt.

    Args:
        sentence: English text to correct.

    Returns:
        The decoded top beam with special tokens removed and surrounding
        whitespace stripped.
    """
    model, tok = load_flan()
    prompt = f"Correct grammar and rewrite in fluent British English: {sentence}"
    # Cap the encoded prompt at 512 tokens so that a very long paste cannot
    # drive up memory use and generation time without bound.
    inputs = tok(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=96, num_beams=4)
    return tok.decode(out[0], skip_special_tokens=True).strip()
37
+
38
+ # --- Marian round-trip loader ---
39
def load_marian():
    """Return the cached OPUS-MT round-trip models and tokenizers.

    The result is the 4-tuple ``(en_es_model, en_es_tokenizer, es_en_model,
    es_en_tokenizer)``. Both models are moved to ``DEVICE``. Everything is
    loaded on the first call and kept in ``CACHE`` under the ``"en_es"`` key.
    """
    try:
        return CACHE["en_es"]
    except KeyError:
        pass

    forward_name = "Helsinki-NLP/opus-mt-en-es"
    backward_name = "Helsinki-NLP/opus-mt-es-en"

    forward_tok = MarianTokenizer.from_pretrained(forward_name)
    forward_model = MarianMTModel.from_pretrained(forward_name).to(DEVICE)
    backward_tok = MarianTokenizer.from_pretrained(backward_name)
    backward_model = MarianMTModel.from_pretrained(backward_name).to(DEVICE)

    CACHE["en_es"] = (forward_model, forward_tok, backward_model, backward_tok)
    return CACHE["en_es"]
47
+
48
def _translate(model, tok, text: str) -> str:
    """Translate *text* with one Marian model/tokenizer pair; return the top beam."""
    # Truncate the input: these translation models only accept a bounded
    # number of input positions (assumed to be 512 for the OPUS-MT
    # checkpoints used here; confirm against the model config).
    inputs = tok(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        tokens = model.generate(**inputs, max_length=128, num_beams=4)
    return tok.decode(tokens[0], skip_special_tokens=True)


def run_roundtrip(sentence: str) -> str:
    """Polish *sentence* by translating English → Spanish → English.

    Args:
        sentence: English text to polish.

    Returns:
        The back-translated English text with surrounding whitespace stripped.
    """
    mdl1, tok1, mdl2, tok2 = load_marian()
    spanish = _translate(mdl1, tok1, sentence)
    english = _translate(mdl2, tok2, spanish)
    return english.strip()
59
+
60
+ # --- Dispatcher ---
61
def polish(sentence: str, choice: str) -> str:
    """Dispatch *sentence* to the backend selected by *choice*.

    Args:
        sentence: Text entered by the user. ``None``, empty, or
            whitespace-only input yields an empty result.
        choice: Label of the selected method. Labels starting with
            ``"FLAN"`` use FLAN-T5; labels starting with ``"Round-trip"``
            use the OPUS-MT round trip.

    Returns:
        The polished sentence, ``""`` for blank input, or
        ``"Unknown option."`` when *choice* matches no backend.
    """
    # Guard against None as well as blank text so a missing value from the
    # UI cannot raise AttributeError on .strip().
    if not sentence or not sentence.strip():
        return ""
    # A cleared dropdown may deliver a non-string value (e.g. None); treat
    # it like any other unrecognised option instead of crashing.
    if not isinstance(choice, str):
        return "Unknown option."
    if choice.startswith("FLAN"):
        return run_flan(sentence)
    if choice.startswith("Round-trip"):
        return run_roundtrip(sentence)
    return "Unknown option."
70
+
71
+ # --- Gradio UI ---
72
# Build the UI: one input box, a method selector, a button, and an output box.
with gr.Blocks(title="English Grammar Polisher") as demo:
    gr.Markdown("### English Grammar Polisher\nChoose FLAN-T5 (Google) or OPUS-MT round-trip (Helsinki-NLP).")
    inp = gr.Textbox(lines=3, label="Input (English)", placeholder="Type a sentence…")
    # Take the default from MODEL_OPTIONS instead of repeating the label text,
    # so the preselected value always matches one of the offered choices.
    choice = gr.Dropdown(choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Method")
    btn = gr.Button("Polish")
    out = gr.Textbox(label="Output")
    # Clicking the button sends the text and selected method to polish().
    btn.click(polish, inputs=[inp, choice], outputs=out)

if __name__ == "__main__":
    demo.launch()