# Handwritten_OCR / app.py
# Author: imperiusrex
# Last change: "Update app.py" (commit 476a469, verified)
# Import the GPU decorator for ZeroGPU Spaces
from spaces import GPU
import os
import cv2
import numpy as np
import torch
import tempfile
import shutil
import gradio as gr
from PIL import Image
from pdf2image import convert_from_path
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from paddleocr import PaddleOCR
import logging
# Quiet library log chatter (notably PaddleOCR's). NOTE: logging.disable is
# process-wide - it drops every logging call at WARNING level and below from
# ALL libraries, not only PaddleOCR; ERROR and CRITICAL still get through.
logging.disable(logging.WARNING)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- MODEL LOADING ---
# Both models are created once at import time so every request reuses them.
# A model that fails to load is left as None so the app can still start;
# process_image_page refuses to run in that case.
print("Initializing PaddleOCR text detection model...")
try:
    det_model = PaddleOCR(
        use_angle_cls=False,
        lang='en',
        use_gpu=torch.cuda.is_available(),
        show_log=False,
    )
except Exception as err:
    print(f"Error initializing PaddleOCR: {err}")
    det_model = None

print("Initializing TrOCR text recognition model...")
_TROCR_CHECKPOINT = "microsoft/trocr-large-handwritten"
try:
    trocr_processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
    trocr_model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
    trocr_model.eval()  # inference mode: disables dropout
    trocr_model.to(device)
except Exception as err:
    # Any failure (download, weights, device move) disables recognition entirely.
    print(f"Error initializing TrOCR: {err}")
    trocr_model = None
    trocr_processor = None
# Helper function to save a temp image
def save_temp_image(img):
    """Write an image array to a new temporary PNG file and return its path.

    The caller owns the returned file and is responsible for deleting it.

    Args:
        img: Image array accepted by ``cv2.imwrite`` (e.g. a BGR ``uint8`` array).

    Returns:
        Path of the temporary ``.png`` file.

    Raises:
        OSError: if OpenCV reports the image could not be written. The
            temporary file is removed before raising, as it is for any
            exception raised by ``cv2.imwrite`` itself.
    """
    temp_fd, temp_path = tempfile.mkstemp(suffix='.png')
    # OpenCV writes by path, so the descriptor mkstemp opened is never used.
    # Close it first so it cannot leak if imwrite raises.
    os.close(temp_fd)
    try:
        written = cv2.imwrite(temp_path, img)
    except Exception:
        os.remove(temp_path)
        raise
    if not written:
        # imwrite signals most failures by returning False rather than raising.
        os.remove(temp_path)
        raise OSError(f"Failed to write temporary image to {temp_path}")
    return temp_path
def _sort_boxes_reading_order(boxes, line_tolerance=10.0):
    """Return ``boxes`` in reading order: top-to-bottom, then left-to-right.

    Each box is a sequence of (x, y) points. Boxes are first ordered by their
    topmost y coordinate. Consecutive boxes whose top lies within
    ``line_tolerance`` pixels of the first box of the current row are grouped
    into one row and ordered by their leftmost x coordinate. The original box
    objects are returned unchanged, just reordered.
    """
    points = [np.asarray(box, dtype=np.float32).reshape(-1, 2) for box in boxes]
    tops = [float(p[:, 1].min()) for p in points]
    lefts = [float(p[:, 0].min()) for p in points]
    by_top = sorted(range(len(boxes)), key=lambda i: (tops[i], lefts[i]))
    ordered = []
    row = []
    for i in by_top:
        if row and tops[i] - tops[row[0]] > line_tolerance:
            ordered.extend(sorted(row, key=lambda j: lefts[j]))
            row = []
        row.append(i)
    ordered.extend(sorted(row, key=lambda j: lefts[j]))
    return [boxes[i] for i in ordered]


def _crop_quadrilateral(img, box):
    """Perspective-warp the four-point ``box`` region of ``img`` into an upright crop.

    ``box`` points are taken in the order top-left, top-right, bottom-right,
    bottom-left. The output size is the longer of each pair of opposite edges.
    Returns None when the region is under 2 px wide or tall. Such a box
    cannot be warped: an empty output size makes OpenCV fall back to the full
    source size, and duplicate target corners make the transform singular.
    """
    src = np.array(box, dtype=np.float32)
    width = int(max(np.linalg.norm(src[0] - src[1]), np.linalg.norm(src[2] - src[3])))
    height = int(max(np.linalg.norm(src[0] - src[3]), np.linalg.norm(src[1] - src[2])))
    if width < 2 or height < 2:
        return None
    dst = np.array([
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1],
        [0, height - 1]
    ], dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img, matrix, (width, height))


def process_image_page(img):
    """
    Detect text lines in one page image and recognize each one with TrOCR.

    Args:
        img: Page image as an OpenCV BGR numpy array.

    Returns:
        ``(results, original_pil_image)``.

        - ``results`` is a list of ``[box, text]`` pairs in reading order
          (top-to-bottom, then left-to-right). ``box`` is the polygon reported
          by PaddleOCR and ``text`` is TrOCR's transcription. Boxes too small
          to crop are skipped.
        - ``original_pil_image`` is ``img`` converted to an RGB PIL image.

    Raises:
        RuntimeError: if the detection or recognition models failed to load.
    """
    if det_model is None or trocr_model is None or trocr_processor is None:
        raise RuntimeError("OCR models are not loaded. Please check logs for errors.")
    # OpenCV arrays are BGR; PIL expects RGB.
    original_pil_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # The page is handed to PaddleOCR by file path. Always delete the temp
    # file, even if detection raises.
    temp_image_path = save_temp_image(img)
    try:
        # NOTE(review): presumably ocr() with default arguments also runs
        # PaddleOCR's own recognizer, whose text is discarded here (only
        # line[0] is used). A detection-only call may be cheaper; confirm the
        # result format for the installed PaddleOCR version before changing.
        ocr_result = det_model.ocr(temp_image_path)
    finally:
        os.remove(temp_image_path)

    # ocr_result[0] holds one entry per detected line, whose first element is
    # the line's polygon. It can be empty/None when nothing is detected.
    boxes = []
    if ocr_result and ocr_result[0]:
        boxes = [line[0] for line in ocr_result[0]]
    print(f"Detected {len(boxes)} lines in this page.")

    # Sort explicitly instead of trusting (or blindly reversing) whatever order
    # the detector happened to return.
    results = []
    for box in _sort_boxes_reading_order(boxes):
        crop = _crop_quadrilateral(img, box)
        if crop is None:
            continue
        image_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        pixel_values = trocr_processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = trocr_model.generate(pixel_values, max_new_tokens=64)
        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        results.append([box, generated_text])
        print(f"Recognized: {generated_text}")
    return results, original_pil_image
def _invert_if_dark(img_cv, threshold=100):
    """Return ``img_cv`` color-inverted if its mean gray level is below ``threshold``.

    This is a simple heuristic for light text on a dark background: such
    images are inverted to dark-on-light, and all others are returned
    unchanged.
    """
    gray_image = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    if np.mean(gray_image) < threshold:
        print("Inverting colors for dark background.")
        return cv2.bitwise_not(img_cv)
    return img_cv


def _draw_text_lines(c, texts, y, page_height, continuation_title=None):
    """Draw ``texts`` one per line downward from ``y`` at x=50 in Helvetica 12.

    Lines are 15 pt apart. When ``y`` drops below the 50 pt bottom margin, a
    new page is started before the next line is drawn, so no empty trailing
    page is produced. On a new page, ``continuation_title`` (if given) is drawn
    in Helvetica-Bold 14 and text resumes 60 pt from the top. Otherwise text
    resumes 50 pt from the top. The body font is re-selected after every page
    break so continuation text is not left in the heading font.

    Returns the y coordinate where the next line would be drawn.
    """
    c.setFont("Helvetica", 12)
    for text in texts:
        if y < 50:
            c.showPage()
            if continuation_title is not None:
                c.setFont("Helvetica-Bold", 14)
                c.drawString(50, page_height - 40, continuation_title)
                y = page_height - 60
            else:
                y = page_height - 50
            c.setFont("Helvetica", 12)
        c.drawString(50, y, text)
        y -= 15
    return y


def process_file_and_create_pdf(file):
    """
    Run OCR on an uploaded image or PDF and write the recognized text to a new PDF.

    Args:
        file: The upload. Either a path (``str`` / ``os.PathLike``) or an
            object whose ``name`` attribute is the path, such as a tempfile
            wrapper. Names ending in ``.pdf`` (case-insensitive) are rasterized
            at 300 DPI and OCR'd page by page. Anything else is read as a
            single image with OpenCV.

    Returns:
        ``(output_pdf_path, preview_image)`` on success. ``preview_image`` is
        a PIL image of the input, or of the first page for PDFs.
        ``(None, None)`` if ``file`` is None or processing fails; the error is
        printed with its traceback.

    The output PDF is written to a fresh temporary directory. On success that
    directory is deliberately NOT deleted, because the caller (e.g. Gradio)
    still has to read the file after this function returns. On failure it is
    removed. This function is not itself decorated with ``@GPU``.
    """
    import traceback

    if file is None:
        return None, None
    # Accept a plain path as well as an object exposing the path as ``.name``.
    file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name

    temp_output_dir = tempfile.mkdtemp()
    output_pdf_path = os.path.join(temp_output_dir, "ocr_results.pdf")
    width, height = letter
    try:
        if file_path.lower().endswith('.pdf'):
            print(f"Converting PDF {file_path} to images...")
            images = convert_from_path(file_path, dpi=300)
            if not images:
                raise ValueError("PDF contains no pages.")
            input_image_for_display = images[0]
            c = canvas.Canvas(output_pdf_path, pagesize=letter)
            for page_num, page in enumerate(images, start=1):
                print(f"\nProcessing page {page_num}")
                # pdf2image yields RGB PIL images; the OCR pipeline expects BGR.
                img_cv = _invert_if_dark(cv2.cvtColor(np.array(page), cv2.COLOR_RGB2BGR))
                results, _ = process_image_page(img_cv)
                c.setFont("Helvetica-Bold", 14)
                c.drawString(50, height - 40, f"Page {page_num} - OCR Results")
                _draw_text_lines(
                    c,
                    [text for _, text in results],
                    height - 60,
                    height,
                    continuation_title=f"Page {page_num} (cont.) - OCR Results",
                )
                c.showPage()
            c.save()
        else:  # Handle single image file
            img_cv = cv2.imread(file_path)
            if img_cv is None:
                raise ValueError("Failed to load image.")
            # Preview shows the upload as-is, before any color inversion.
            input_image_for_display = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
            img_cv = _invert_if_dark(img_cv)
            results, original_image = process_image_page(img_cv)
            c = canvas.Canvas(output_pdf_path, pagesize=letter)
            c.setFont("Helvetica-Bold", 14)
            c.drawString(50, height - 40, "Image OCR Results")
            temp_img_path = os.path.join(temp_output_dir, "original_image.png")
            original_image.save(temp_img_path)
            # Fit the image inside a 200x240 pt box whose top-left corner sits
            # just below the title. Giving only a width would make ReportLab use
            # the pixel height as the box height and push large images off-page.
            c.drawImage(
                temp_img_path,
                50,
                height - 290,
                width=200,
                height=240,
                preserveAspectRatio=True,
                anchor='nw',
            )
            _draw_text_lines(c, [text for _, text in results], height - 350, height)
            c.save()
            os.remove(temp_img_path)
        print(f"Generated PDF path: {output_pdf_path}")
        return output_pdf_path, input_image_for_display
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()
        # Only clean up on failure; on success the returned PDF must survive.
        shutil.rmtree(temp_output_dir, ignore_errors=True)
        return None, None
# Gradio Interface
class _NamedFile:
    """Minimal stand-in for an uploaded-file object: exposes the path as ``.name``."""

    def __init__(self, name):
        self.name = name


@GPU
def process_file_for_gradio(image_file, pdf_file):
    """
    Gradio entry point: OCR whichever input was provided and return the results.

    If both inputs are given, the image wins. The image (a PIL image from
    ``gr.Image(type="pil")``) is saved as a PNG in a temporary directory,
    which is always removed afterwards. The PDF upload is passed on as an
    object with a ``.name`` path; a plain path string (what newer Gradio
    ``gr.File`` passes by default) is wrapped first.

    Returns:
        ``(output_pdf_path, preview_image)``, or ``(None, None)`` when no input
        was provided or processing failed.
    """
    if image_file is not None:
        temp_dir = tempfile.mkdtemp()
        try:
            image_path = os.path.join(temp_dir, "uploaded_image.png")
            image_file.save(image_path)
            return process_file_and_create_pdf(_NamedFile(image_path))
        finally:
            # The uploaded copy is only needed during processing.
            shutil.rmtree(temp_dir, ignore_errors=True)
    if pdf_file is not None:
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_file = _NamedFile(os.fspath(pdf_file))
        return process_file_and_create_pdf(pdf_file)
    return None, None
# Two optional inputs (image or PDF) map to process_file_for_gradio's
# (image_file, pdf_file) parameters. Its two return values fill the output
# PDF file and the preview image.
demo = gr.Interface(
    fn=process_file_for_gradio,
    inputs=[
        gr.Image(label="Upload an Image", type="pil"),
        gr.File(label="Upload a PDF", file_types=['.pdf']),
    ],
    outputs=[
        gr.File(label="Download OCR Results PDF", interactive=False, visible=True),
        gr.Image(label="Uploaded File Preview", interactive=False),
    ],
    title="OCR App with PaddleOCR and TrOCR",
    # The output is a download link, not an automatic download. When both
    # inputs are filled, the handler processes the image.
    description=(
        "Upload an image or a multi-page PDF to get an output PDF with the "
        "recognized text from each page. If both are provided, the image is used. "
        "Download the result from the output file panel."
    ),
)

if __name__ == "__main__":
    demo.launch()