|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
|
|
|
|
|
|
# Dropdown display label -> Hugging Face Hub model id for the selectable checkpoints.
MODEL_OPTIONS = {


    "FLAN-T5-small (Google)": "google/flan-t5-small",


    "FLAN-T5-base (Google)": "google/flan-t5-base"


}








# Cache of already-constructed pipelines keyed by model id, so each model
# is downloaded/loaded at most once per process (populated by get_pipeline).
loaded_pipelines = {}
|
|
|
|
|
def get_pipeline(model_id: str):
    """Return a text2text-generation pipeline for *model_id*, building and
    caching it on first use so each model is loaded only once per process."""
    cached = loaded_pipelines.get(model_id)
    if cached is not None:
        return cached

    tok = AutoTokenizer.from_pretrained(model_id)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    fresh_pipe = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=tok,
        device=-1,  # -1 pins inference to CPU
    )

    # Tiny warm-up generation so the first real user request doesn't pay
    # the one-time graph/weights initialization cost.
    _ = fresh_pipe("Correct the grammar: test", max_new_tokens=8, do_sample=False)

    loaded_pipelines[model_id] = fresh_pipe
    return fresh_pipe
|
|
|
|
|
def oxford_polish(sentence: str, model_choice: str) -> str:
    """Correct the grammar of *sentence* with the selected FLAN-T5 model.

    Parameters
    ----------
    sentence : str
        User-supplied sentence to correct.
    model_choice : str
        A display label that must be a key of MODEL_OPTIONS.

    Returns
    -------
    str
        The corrected sentence, or "" when the input is empty or
        whitespace-only.

    Raises
    ------
    KeyError
        If *model_choice* is not one of the MODEL_OPTIONS keys.
    """
    # Robustness fix: don't run the model on blank input — generating from
    # an effectively empty prompt just produces noise.
    if not sentence or not sentence.strip():
        return ""

    model_id = MODEL_OPTIONS[model_choice]
    polisher = get_pipeline(model_id)

    prompt = f"You are an English grammar corrector and teacher. Return the corrected version: {sentence}"
    out = polisher(prompt,
                   max_new_tokens=60,
                   do_sample=False,
                   num_beams=2)  # small beam search: deterministic, slightly better than greedy
    text = out[0]["generated_text"].strip()

    # Some checkpoints echo the instruction prompt; strip it when present.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    return text
|
|
|
|
|
|
|
|
# --- Gradio UI wiring: (sentence, model choice) -> corrected text ---------

sentence_box = gr.Textbox(lines=2, placeholder="Enter a sentence to correct...")
model_dropdown = gr.Dropdown(
    choices=list(MODEL_OPTIONS.keys()),
    value="FLAN-T5-base (Google)",
    label="Choose Model",
)

demo = gr.Interface(
    fn=oxford_polish,
    inputs=[sentence_box, model_dropdown],
    outputs=gr.Textbox(label="Oxford-grammar Correction"),
    title="Oxford Grammar Polisher",
    description="Compare Google’s official FLAN-T5 small and base models for grammar correction.",
)


demo.launch()
|
|
|