Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from huggingface_hub import hf_hub_download | |
| from transformers import pipeline | |
| from ultralytics import YOLO | |
| from PIL import Image | |
def process(path, progress = gr.Progress(), device = 'cpu'):
    """Transcribe a handwritten page image into spell-corrected text.

    The pipeline runs in four stages:
      1. A YOLO model detects text-line bounding boxes on the page.
      2. Each line is cropped (top-to-bottom order) and preprocessed.
      3. TrOCR recognizes all line crops in a single batch.
      4. The joined text is passed through a spelling-correction model.

    Args:
        path: Filesystem path of the input image (Gradio supplies this
            because the input component uses ``type="filepath"``).
        progress: Gradio progress tracker used to report stage names.
        device: Torch device on which the OCR model and batch are placed.

    Returns:
        The corrected transcription as a single string, or an empty string
        when no text lines are detected in the image.

    Raises:
        gr.Error: If no image was provided, or if the line detection model
            cannot be downloaded or loaded.
    """
    # Fail fast before loading any large model when there is no input.
    if path is None:
        raise gr.Error("Please provide an image.")

    progress(0, desc="Starting")
    LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
    OCR_MODEL_PATH = "microsoft/trocr-large-handwritten"
    CORRECTOR_PATH = "oliverguhr/spelling-correction-english-base"

    # Load the OCR model and its processor.
    processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
    model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
    model.to(device)

    # Open the image of handwritten text; convert() returns an independent
    # copy, so the underlying file handle can be closed right away.
    with Image.open(path) as source_image:
        image = source_image.convert("RGB")

    progress(0, desc="Extracting Text Lines")
    try:
        # Load the trained line detection model.
        cached_model_path = hf_hub_download(repo_id=LINE_MODEL_PATH, filename="lines_20240827.pt")
        line_model = YOLO(cached_model_path)
    except Exception as e:
        # Without the line model nothing below can run, so surface the
        # failure instead of continuing into a NameError on `line_model`.
        print('Failed to load the line detection model: %s' % e)
        raise gr.Error('Failed to load the line detection model: %s' % e) from e

    results = line_model.predict(source=image)[0]
    boxes = results.boxes.xyxy
    if len(boxes) == 0:
        # No text lines found: torch.cat() would fail on an empty batch.
        return ""

    # Sort boxes by their top edge (y1) so lines are read top to bottom.
    indices = boxes[:, 1].sort().indices
    boxes = boxes[indices]

    batch = []
    for box in progress.tqdm(boxes, desc="Preprocessing"):
        # Crop the line using plain Python numbers (x1, y1, x2, y2).
        line_img = image.crop(tuple(box.tolist()))
        # Preprocess the crop into model-ready pixel values.
        pixel_values = processor(line_img, return_tensors="pt").pixel_values
        batch.append(pixel_values)

    # Predict and decode the entire batch.
    progress(0, desc="Recognizing..")
    batch = torch.cat(batch).to(device)
    print("batch.shape", batch.shape)
    generated_ids = model.generate(batch)

    progress(0, desc="Decoding (token -> str)")
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    print(generated_text)
    full_text = " ".join(generated_text)
    print(full_text)

    progress(0, desc="Correction..")
    fix_spelling = pipeline("text2text-generation", model=CORRECTOR_PATH)
    # The character count plus a margin is used as the new-token budget.
    fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
    fixed_text = fixed_text[0]['generated_text']
    return fixed_text
if __name__ == "__main__":
    # Build the web UI: one image upload (handed to `process` as a file
    # path) mapped to a plain-text output, then start the Gradio server.
    demo = gr.Interface(
        fn=process,
        inputs=gr.Image(type="filepath"),
        outputs="text",
    )
    demo.launch()