# --- Setup ---
import spaces  # ZeroGPU helper; imported before torch so CUDA is not initialized first

import cv2
import gradio as gr
import numpy as np
import torch
from paddleocr import TextDetection
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# --- Model Load ---
MODEL_HUB_ID = "imperiusrex/Handwritten_model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
model.to(device)
model.eval()

ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")


# --- Core OCR Function ---
def run_trocr(pil_img):
    """Recognize one text image with TrOCR and return the decoded string."""
    pixel_values = processor(images=pil_img, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_new_tokens=64)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def crop_text_box(image_np, box):
    """Warp one detected quadrilateral to an upright crop; return None if it is degenerate."""
    box = np.array(box, dtype=np.float32)
    width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
    height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
    if width < 1 or height < 1:
        return None  # cv2.warpPerspective cannot produce an empty image
    dst_rect = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype=np.float32,
    )
    M = cv2.getPerspectiveTransform(box, dst_rect)
    warped = cv2.warpPerspective(image_np, M, (width, height))
    return Image.fromarray(warped).convert("RGB")


# Request a GPU for each call when running on a ZeroGPU Space
@spaces.GPU
def recognize_handwritten_text_from_npimg(np_img):
    # Flatten any transparency onto white so dark strokes stay visible
    rgba = Image.fromarray(np_img.astype(np.uint8)).convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    pil_img = Image.alpha_composite(white, rgba).convert("RGB")
    image_np = np.array(pil_img)

    # Detect text regions
    detection_results = ocr_det_model.predict(image_np, batch_size=1)
    detected_polys = []
    for res in detection_results:
        polys = res.get("dt_polys")
        if polys is not None:
            detected_polys.extend(np.asarray(polys).tolist())

    # Crop each detected region
    cropped_images = []
    for box in detected_polys:
        crop = crop_text_box(image_np, box)
        if crop is not None:
            cropped_images.append(crop)
    # Reverse the detection order (assumed here to come back bottom-to-top)
    cropped_images.reverse()

    if cropped_images:
        recognized_texts = [run_trocr(crop_img) for crop_img in cropped_images]
    else:
        # Fall back to running recognition on the whole image
        recognized_texts = ["No text boxes detected. Full image OCR:\n" + run_trocr(pil_img)]
    return "\n".join(recognized_texts)


# --- Interface Function ---
def ocr_from_canvas(img):
    # gr.ImageEditor passes a dict with "background", "layers", and "composite"
    if isinstance(img, dict):
        img = img.get("composite")
    if img is None:
        return "Draw something to see OCR output."
    try:
        return recognize_handwritten_text_from_npimg(np.array(img))
    except Exception as e:
        return f"[OCR error: {e}]"


# --- UI Layout ---
with gr.Blocks(css=".gr-textbox textarea { font-family: monospace; font-size: 16px; }") as demo:
    gr.Markdown("# 📝 Real-Time Handwriting OCR Canvas")
    with gr.Row():
        with gr.Column():
            canvas = gr.ImageEditor(
                label="Draw here (freehand)",
                type="numpy",
                width=600,
                height=400,
                brush=gr.Brush(default_size=3, colors=["#000000"], default_color="#000000"),
            )
            gr.Markdown(
                """
                - Use the brush tool to draw on the canvas.
                - You can adjust the brush size and color.
                - OCR runs whenever the canvas content changes.
                """
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="🧠 OCR Output",
                lines=12,
                max_lines=20,
                interactive=False,
            )

    # Trigger OCR on change
    canvas.change(fn=ocr_from_canvas, inputs=canvas, outputs=output_text)

demo.launch()