# DrawOCR / app.py (author: imperiusrex; commit d0a0585, "Create app.py")
# --- Setup ---
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import cv2
from paddleocr import TextDetection
from huggingface_hub import spaces
import time
# Request H200 GPU.
# NOTE(review): the documented ZeroGPU API of the `spaces` package is the
# `@spaces.GPU` decorator; a `spaces.GPU.require(...)` call does not appear in
# it, and Space hardware is normally chosen in the Space settings. Calling a
# missing attribute here would abort the whole app at import time with
# AttributeError, so the request is only made if this installed version
# actually provides it. (Also verify the `spaces` import above: the ZeroGPU
# helper is normally imported as `import spaces`.)
_gpu_require = getattr(getattr(spaces, "GPU", None), "require", None)
if callable(_gpu_require):
    _gpu_require("H200")
# --- Model Load ---
# Hub repository holding both the TrOCR processor and the encoder-decoder weights.
MODEL_HUB_ID = "imperiusrex/Handwritten_model"
# Recognition runs on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
# nn.Module.to() and .eval() both return the module itself, so they chain:
# load, move to `device`, and switch to inference mode in one expression.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID).to(device).eval()
# PaddleOCR detector; its predictions supply the `dt_polys` regions that are
# cropped and fed to TrOCR.
ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")
# --- Core OCR Function ---
def _collect_quads(detection_results):
    """Gather every detected region as a float32 (4, 2) quad.

    Four-point polygons are kept as-is; the warp below assumes their corner
    order is top-left, top-right, bottom-right, bottom-left (TODO confirm for
    the detector's `box_type`). Polygons with another point count fall back to
    their axis-aligned bounding box, because the perspective warp needs exactly
    four corners. Works whether `dt_polys` is an ndarray or a plain list.
    """
    quads = []
    for res in detection_results:
        polys = res.get("dt_polys")
        if polys is None:
            continue
        for poly in polys:
            pts = np.asarray(poly, dtype=np.float32).reshape(-1, 2)
            if len(pts) == 4:
                quads.append(pts)
            elif len(pts) >= 3:
                (x0, y0), (x1, y1) = pts.min(axis=0), pts.max(axis=0)
                quads.append(
                    np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.float32)
                )
    return quads


def _sort_reading_order(quads, line_tolerance=10.0):
    """Order quads top-to-bottom, and left-to-right within a text line.

    Quads are sorted by their first (top-left) corner; neighbours whose
    top-left y differs by less than `line_tolerance` pixels are treated as the
    same line and re-ordered by x (the same scheme PaddleOCR's own pipeline
    uses). This replaces relying on the detector's internal output order.
    """
    ordered = sorted(quads, key=lambda q: (float(q[0][1]), float(q[0][0])))
    for i in range(len(ordered) - 1):
        for j in range(i, -1, -1):
            upper, lower = ordered[j], ordered[j + 1]
            same_line = abs(lower[0][1] - upper[0][1]) < line_tolerance
            if same_line and lower[0][0] < upper[0][0]:
                ordered[j], ordered[j + 1] = lower, upper
            else:
                break
    return ordered


def _warp_quad(image, quad):
    """Rectify one quad of `image` into an axis-aligned crop.

    Returns None for degenerate regions (side shorter than 2 px), which would
    otherwise make getPerspectiveTransform/warpPerspective fail and abort the
    whole request.
    """
    width = int(max(np.linalg.norm(quad[0] - quad[1]), np.linalg.norm(quad[2] - quad[3])))
    height = int(max(np.linalg.norm(quad[0] - quad[3]), np.linalg.norm(quad[1] - quad[2])))
    if width < 2 or height < 2:
        return None
    dst_rect = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype=np.float32,
    )
    matrix = cv2.getPerspectiveTransform(quad, dst_rect)
    return cv2.warpPerspective(image, matrix, (width, height))


def _run_trocr(images, max_new_tokens):
    """Recognize a list of PIL images with TrOCR in one batch; return one string per image.

    Uses the module-level `processor`, `model` and `device`.
    """
    pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_new_tokens=max_new_tokens)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)


def recognize_handwritten_text_from_npimg(np_img, max_new_tokens=64):
    """Detect text regions in an image array and recognize each with TrOCR.

    Args:
        np_img: image array accepted by PIL's `Image.fromarray` after a cast
            to uint8 (grayscale or RGB; an alpha channel is dropped).
        max_new_tokens: generation budget per text region (default 64).

    Returns:
        The recognized region texts in reading order joined by newlines, or,
        when no usable region is detected, "No text boxes detected. Full image
        OCR:\\n" followed by TrOCR run on the whole image.

    Raises:
        Whatever the detector, OpenCV or TrOCR raise; callers such as the UI
        callback are expected to catch and report them.
    """
    pil_img = Image.fromarray(np_img.astype(np.uint8)).convert("RGB")
    image_np = np.array(pil_img)

    detection_results = ocr_det_model.predict(image_np, batch_size=1)
    quads = _sort_reading_order(_collect_quads(detection_results))

    crops = []
    for quad in quads:
        warped = _warp_quad(image_np, quad)
        if warped is not None:
            crops.append(Image.fromarray(warped).convert("RGB"))

    if crops:
        return "\n".join(_run_trocr(crops, max_new_tokens))

    # No usable region: fall back to recognizing the full image.
    (full_text,) = _run_trocr([pil_img], max_new_tokens)
    return "No text boxes detected. Full image OCR:\n" + full_text
# --- Interface Function ---
def _canvas_to_rgb(img):
    """Extract the drawing from a canvas value as an ndarray, or None if absent.

    In Gradio 4+, `gr.ImageEditor(type="numpy")` delivers a dict with
    "background", "layers" and "composite" entries; the flattened drawing is
    "composite". Plain arrays/images are passed through. RGBA input is
    flattened onto white: the editor's background is transparent, and simply
    dropping alpha would turn it black and hide black strokes. Assumes 8-bit
    channels (alpha in 0..255).
    """
    if isinstance(img, dict):
        img = img.get("composite")
    if img is None:
        return None
    arr = np.asarray(img)
    if arr.ndim == 3 and arr.shape[2] == 4:
        rgb = arr[..., :3].astype(np.float32)
        alpha = arr[..., 3:4].astype(np.float32) / 255.0
        arr = np.rint(rgb * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)
    return arr


def ocr_from_canvas(img):
    """Gradio callback: OCR the current canvas and return text for the output box.

    Args:
        img: the canvas value (ImageEditor dict, ndarray, or None).

    Returns:
        The recognized text; "Draw something to see OCR output." when the
        canvas is empty or blank (a single uniform colour), so the models are
        not run on nothing; or "[OCR error: ...]" if recognition fails. This
        function does not raise, so a failure never breaks the live UI.
    """
    try:
        np_img = _canvas_to_rgb(img)
        if np_img is None or np_img.size == 0 or np_img.min() == np_img.max():
            return "Draw something to see OCR output."
        return recognize_handwritten_text_from_npimg(np_img)
    except Exception as e:  # UI boundary: report the failure in the output box
        return f"[OCR error: {e}]"
# --- UI Layout ---
# Style the OCR output box by its elem_id; the old `.gr-textbox` class
# selector is not emitted by current Gradio versions.
with gr.Blocks(css="#ocr-output textarea { font-family: monospace; font-size: 16px; }") as demo:
    # Title emoji restored from mojibake ("πŸ“" was a garbled UTF-8 emoji;
    # NOTE(review): 📝 assumed from the surviving F0 9F 93 prefix).
    gr.Markdown("<h1>📝 Real-Time Handwriting OCR Canvas</h1>")
    with gr.Row():
        with gr.Column():
            # gr.ImageEditor accepts no `tool`/`background` arguments, and
            # gr.Brush takes `default_size`/`colors`/`default_color` rather than
            # `size`/`color`; the previous keywords raised TypeError on startup.
            canvas = gr.ImageEditor(
                label="Draw here (freehand)",
                type="numpy",
                width=600,
                height=400,
                brush=gr.Brush(
                    default_size=3,
                    colors=["#000000"],
                    default_color="#000000",
                ),
            )
            # Help text describes only what this app actually does: there is no
            # timer, OCR is driven purely by canvas change events.
            gr.Markdown(
                "- Draw with the brush; use the eraser to remove strokes.\n"
                "- Stroke width and brush color can be changed from the brush tool.\n"
                "- OCR runs every time the canvas changes."
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="🧠 OCR Output",
                lines=12,
                max_lines=20,
                interactive=False,
                elem_id="ocr-output",
            )
    # Re-run OCR on every canvas edit.
    canvas.change(fn=ocr_from_canvas, inputs=canvas, outputs=output_text)
demo.launch()