AItool commited on
Commit
f6944de
·
verified ·
1 Parent(s): 7291080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -46
app.py CHANGED
@@ -1,59 +1,33 @@
1
- # app.py
2
-
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
-
6
- # Translation models for English→English correction
7
- MODEL_OPTIONS = {
8
- "Helsinki-NLP/opus-mt-en-en (light, CPU-friendly)": "Helsinki-NLP/opus-mt-en-en",
9
- "facebook/mbart-large-50-many-to-many-mmt (heavier)": "facebook/mbart-large-50-many-to-many-mmt"
10
- }
11
 
12
- # Cache loaded pipelines
13
- loaded_pipelines = {}
 
14
 
15
- def get_pipeline(model_id: str):
16
- if model_id not in loaded_pipelines:
17
- tokenizer = AutoTokenizer.from_pretrained(model_id)
18
- model = AutoModelForSeq2SeqLM.from_pretrained(
19
- model_id,
20
- low_cpu_mem_usage=True,
21
- torch_dtype="auto"
22
- )
23
- pipe = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
24
- # Warm-up
25
- _ = pipe("This is a test.", max_length=32)
26
- loaded_pipelines[model_id] = pipe
27
- return loaded_pipelines[model_id]
28
 
29
- def polish(sentence: str, model_choice: str) -> str:
30
- model_id = MODEL_OPTIONS[model_choice]
31
- translator = get_pipeline(model_id)
 
 
32
 
33
- # For mbart we need to set language codes
34
- if "mbart" in model_id:
35
- inputs = translator.tokenizer(sentence, return_tensors="pt")
36
- inputs["forced_bos_token_id"] = translator.tokenizer.lang_code_to_id["en_XX"]
37
- out = translator.model.generate(**inputs, max_length=128, num_beams=4)
38
- text = translator.tokenizer.decode(out[0], skip_special_tokens=True)
39
- else:
40
- out = translator(sentence, max_length=128)
41
- text = out[0]["translation_text"]
42
 
43
- return text.strip()
44
 
45
- # Gradio interface
46
  demo = gr.Interface(
47
  fn=polish,
48
- inputs=[
49
- gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
50
- gr.Dropdown(choices=list(MODEL_OPTIONS.keys()),
51
- value="Helsinki-NLP/opus-mt-en-en (light, CPU-friendly)",
52
- label="Choose Model")
53
- ],
54
  outputs=gr.Textbox(label="Corrected English"),
55
- title="English→English Grammar Polisher",
56
- description="Uses translation models (Helsinki-NLP opus-mt-en-en and facebook mbart-large-50) to rewrite English sentences into fluent, corrected English."
57
  )
58
 
59
  if __name__ == "__main__":
 
 
 
1
  import gradio as gr
2
+ from transformers import MarianMTModel, MarianTokenizer
 
 
 
 
 
 
3
 
4
+ # Load English→Spanish
5
+ en_es_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-es")
6
+ en_es_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
7
 
8
+ # Load Spanish→English
9
+ es_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-es-en")
10
+ es_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def polish(sentence: str) -> str:
13
+ # Step 1: English → Spanish
14
+ inputs = en_es_tokenizer(sentence, return_tensors="pt", padding=True)
15
+ translated = en_es_model.generate(**inputs, max_length=128)
16
+ spanish = en_es_tokenizer.decode(translated[0], skip_special_tokens=True)
17
 
18
+ # Step 2: Spanish English
19
+ inputs2 = es_en_tokenizer(spanish, return_tensors="pt", padding=True)
20
+ back_translated = es_en_model.generate(**inputs2, max_length=128)
21
+ english = es_en_tokenizer.decode(back_translated[0], skip_special_tokens=True)
 
 
 
 
 
22
 
23
+ return english
24
 
 
25
  demo = gr.Interface(
26
  fn=polish,
27
+ inputs=gr.Textbox(lines=2, placeholder="Enter a sentence in English..."),
 
 
 
 
 
28
  outputs=gr.Textbox(label="Corrected English"),
29
+ title="Round-trip Grammar Polisher",
30
+ description="Uses Helsinki-NLP MarianMT models (en→es→en) to smooth and correct English sentences."
31
  )
32
 
33
  if __name__ == "__main__":