Do recognition and decoding in batch to speed up
app.py CHANGED
@@ -15,9 +15,10 @@ def process(path, progress = gr.Progress()):
     # Load the model and processor
     processor = TrOCRProcessor.from_pretrained(OCR_MODEL_PATH)
     model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_PATH)
-
+
     # Open an image of handwritten text
     image = Image.open(path).convert("RGB")
+
     progress(0, desc="Extracting Text Lines")
     try:
         # Load the trained line detection model
@@ -25,23 +26,29 @@ def process(path, progress = gr.Progress()):
         line_model = YOLO(cached_model_path)
     except Exception as e:
         print('Failed to load the line detection model: %s' % e)
-
+
     results = line_model.predict(source = image)[0]
-    full_text = ""
     boxes = results.boxes.xyxy
     indices = boxes[:,1].sort().indices
     boxes = boxes[indices]
-
+    batch = []
+    for box in progress.tqdm(boxes, desc="Preprocessing"):
         #box = box + torch.tensor([-10,0, 10, 0])
         box = [tensor.item() for tensor in box]
         lineImg = image.crop(tuple(list(box)))
 
-        # Preprocess
+        # Preprocess
         pixel_values = processor(lineImg, return_tensors="pt").pixel_values
-
-
-
+        batch.append(pixel_values)
+
+    #Predict and decode the entire batch
+    progress(0, desc="Recognizing..")
+    generated_ids = model.generate(torch.cat(batch))
+    progress(0, desc="Decoding (token -> str)")
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    full_text = " ".join(generated_text)
 
+    progress(0, desc="Correction..")
     fix_spelling = pipeline("text2text-generation",model=CORRECTOR_PATH)
     fixed_text = fix_spelling(full_text, max_new_tokens=len(full_text)+100)
     fixed_text = fixed_text[0]['generated_text']
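The new version collects one pixel_values tensor per detected line and runs a single generate call on the concatenated batch, instead of generating once per line. Below is a minimal standalone sketch of that pattern; it assumes the public "microsoft/trocr-base-handwritten" checkpoint in place of OCR_MODEL_PATH and hypothetical pre-cropped line images, and it joins the full list returned by batch_decode (one string per line).

# Minimal sketch of batched TrOCR recognition (assumptions: public
# "microsoft/trocr-base-handwritten" checkpoint, hypothetical line crops).
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Pre-cropped line images (stand-ins for the YOLO crops in app.py).
line_images = [Image.open(p).convert("RGB") for p in ("line1.png", "line2.png")]

# The processor resizes every crop to the model's fixed input resolution,
# so the per-line tensors share a shape and torch.cat builds one batch.
batch = torch.cat([processor(img, return_tensors="pt").pixel_values
                   for img in line_images])

generated_ids = model.generate(batch)  # one generate call for all lines
lines = processor.batch_decode(generated_ids, skip_special_tokens=True)
full_text = " ".join(lines)            # one decoded string per line image
print(full_text)

Batching pays off because the encoder and the autoregressive decoder run once over the whole stack of line images rather than once per crop; the trade-off is higher peak memory, roughly proportional to the number of detected lines.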
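The final step passes the recognized text through a text2text-generation pipeline for spelling correction, with max_new_tokens sized a little above the input length so the corrected output has room not to be truncated. A hedged sketch of that step, assuming the public "oliverguhr/spelling-correction-english-base" model as a stand-in for CORRECTOR_PATH:

from transformers import pipeline

# Stand-in model for CORRECTOR_PATH (assumption, not the Space's checkpoint).
fix_spelling = pipeline("text2text-generation",
                        model="oliverguhr/spelling-correction-english-base")

full_text = "teh handwriten note was hard to raed"
# len(full_text) + 100 gives the generator headroom beyond the input length.
fixed = fix_spelling(full_text, max_new_tokens=len(full_text) + 100)
print(fixed[0]["generated_text"])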