Spaces:
Runtime error
Runtime error
| import uvicorn | |
| from fastapi import File | |
| from fastapi import FastAPI | |
| from fastapi import UploadFile | |
| import torch | |
| import os | |
| import sys | |
| import glob | |
| import transformers | |
| from transformers import AutoTokenizer | |
| from transformers import AutoModelForSeq2SeqLM | |
| print("Loading models...") | |
| app = FastAPI() | |
| device = "cpu" | |
| correction_model_tag = "prithivida/grammar_error_correcter_v1" | |
| correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag) | |
| correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag) | |
| def set_seed(seed): | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| print("Models loaded !") | |
| def read_root(): | |
| return {"Gramformer !"} | |
| def get_correction(input_sentence): | |
| set_seed(1212) | |
| scored_corrected_sentence = correct(input_sentence) | |
| return {"scored_corrected_sentence": scored_corrected_sentence} | |
| def correct(input_sentence, max_candidates=1): | |
| correction_prefix = "gec: " | |
| input_sentence = correction_prefix + input_sentence | |
| input_ids = correction_tokenizer.encode(input_sentence, return_tensors='pt') | |
| input_ids = input_ids.to(device) | |
| preds = correction_model.generate( | |
| input_ids, | |
| do_sample=True, | |
| max_length=128, | |
| top_k=50, | |
| top_p=0.95, | |
| early_stopping=True, | |
| num_return_sequences=max_candidates) | |
| corrected = set() | |
| for pred in preds: | |
| corrected.add(correction_tokenizer.decode(pred, skip_special_tokens=True).strip()) | |
| corrected = list(corrected) | |
| return (corrected[0], 0) #Corrected Sentence, Dummy score | |