# app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Only the official Google FLAN-T5 models
# Maps the human-readable dropdown label -> Hugging Face Hub model id.
MODEL_OPTIONS: dict[str, str] = {
    "FLAN-T5-small (Google)": "google/flan-t5-small",
    "FLAN-T5-base (Google)": "google/flan-t5-base"
}
# Cache loaded pipelines
# Keyed by model id; populated lazily by get_pipeline() so each model is
# downloaded/loaded at most once per process.
loaded_pipelines: dict = {}
def get_pipeline(model_id: str):
    """Return a cached text2text-generation pipeline for *model_id*.

    On the first request for a given model id the tokenizer and model are
    loaded, wrapped in a CPU pipeline, warmed up with one tiny generation
    (so the first real user request is not slow), and cached in the
    module-level ``loaded_pipelines`` dict. Later calls are cache hits.
    """
    try:
        return loaded_pipelines[model_id]
    except KeyError:
        pass  # first request for this model — load it below

    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,  # CPU optimization
        torch_dtype="auto",
    )
    pipe = pipeline(
        "text2text-generation",
        model=mdl,
        tokenizer=tok,
        device=-1,  # force CPU
    )
    # Warm-up to avoid first-call lag
    _ = pipe("Correct the grammar: test", max_new_tokens=8, do_sample=False)
    loaded_pipelines[model_id] = pipe
    return pipe
def oxford_polish(sentence: str, model_choice: str) -> str:
    """Correct the grammar of *sentence* using the selected FLAN-T5 model.

    Parameters
    ----------
    sentence : str
        User-supplied sentence to correct.
    model_choice : str
        A key of ``MODEL_OPTIONS`` selecting which model to run.

    Returns
    -------
    str
        The model's corrected text, or "" for empty/whitespace-only input.
    """
    # Robustness fix: don't trigger a (potentially slow) model load and
    # generation for an empty textbox — just return an empty result.
    if not sentence or not sentence.strip():
        return ""
    model_id = MODEL_OPTIONS[model_choice]
    polisher = get_pipeline(model_id)
    # Minimal prompt for FLAN-T5
    prompt = f"You are an English grammar corrector and teacher. Return the corrected version: {sentence}"
    out = polisher(
        prompt,
        max_new_tokens=60,
        do_sample=False,  # deterministic (greedy/beam) decoding
        num_beams=2,      # small beam search for slightly better corrections
    )
    text = out[0]["generated_text"].strip()
    # Strip accidental echo of the prompt, if the model repeats it.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    return text
# Gradio interface
demo = gr.Interface(
    fn=oxford_polish,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="FLAN-T5-base (Google)",
            label="Choose Model",
        ),
    ],
    outputs=gr.Textbox(label="Oxford-grammar Correction"),
    title="Oxford Grammar Polisher",
    description="Compare Google’s official FLAN-T5 small and base models for grammar correction.",
)

# Launch only when executed as a script: importing this module (e.g. by
# tests or Hugging Face Spaces tooling that looks for `demo`) must not
# start the web server as an import side effect.
if __name__ == "__main__":
    demo.launch()