Spaces:
Runtime error
# --- Setup ---
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import cv2
from paddleocr import TextDetection  # PaddleOCR 3.x text-detection pipeline
# GPU: Space hardware (e.g. an H200) is selected in the Space settings rather than in code.
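# On a ZeroGPU Space, per-function GPU access would instead be requested with the
# separate `spaces` package (sketch only, not wired into the code below):
#
#   import spaces
#
#   @spaces.GPU
#   def recognize_handwritten_text_from_npimg(np_img):
#       ...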
# --- Model Load ---
MODEL_HUB_ID = "imperiusrex/Handwritten_model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
model.to(device)
model.eval()

ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")
# --- Core OCR Function ---
def recognize_handwritten_text_from_npimg(np_img):
    pil_img = Image.fromarray(np_img.astype(np.uint8)).convert("RGB")
    image_np = np.array(pil_img)

    # Detect text-line polygons with PaddleOCR
    detection_results = ocr_det_model.predict(image_np, batch_size=1)
    detected_polys = []
    for res in detection_results:
        polys = res.get('dt_polys', [])
        if polys is not None and len(polys) > 0:
            detected_polys.extend(np.asarray(polys).tolist())
    cropped_images = []
    if detected_polys:
        for box in detected_polys:
            box = np.array(box, dtype=np.float32)
            width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
            height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
            dst_rect = np.array([
                [0, 0],
                [width - 1, 0],
                [width - 1, height - 1],
                [0, height - 1]
            ], dtype=np.float32)
            M = cv2.getPerspectiveTransform(box, dst_rect)
            warped = cv2.warpPerspective(image_np, M, (width, height))
            cropped_images.append(Image.fromarray(warped).convert("RGB"))
        cropped_images.reverse()
    recognized_texts = []
    if cropped_images:
        for crop_img in cropped_images:
            pixel_values = processor(images=crop_img, return_tensors="pt").pixel_values.to(device)
            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_new_tokens=64)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            recognized_texts.append(generated_text)
    else:
        pixel_values = processor(images=pil_img, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values, max_new_tokens=64)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        recognized_texts.append("No text boxes detected. Full image OCR:\n" + generated_text)

    return "\n".join(recognized_texts)
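# Local quick test (hypothetical file name; assumes a handwriting sample on disk):
#   img = np.array(Image.open("sample.png").convert("RGB"))
#   print(recognize_handwritten_text_from_npimg(img))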
# --- Interface Function ---
def ocr_from_canvas(img):
    # gr.ImageEditor (type="numpy") sends a dict of arrays: background, layers, composite.
    np_img = img.get("composite") if isinstance(img, dict) else img
    if np_img is None:
        return "Draw something to see OCR output."
    np_img = np.asarray(np_img)
    # Flatten transparency onto white so black strokes stay visible to the OCR models.
    if np_img.ndim == 3 and np_img.shape[-1] == 4:
        rgba = Image.fromarray(np_img.astype(np.uint8), "RGBA")
        flat = Image.new("RGB", rgba.size, "white")
        flat.paste(rgba, mask=rgba.split()[3])
        np_img = np.array(flat)
    try:
        result = recognize_handwritten_text_from_npimg(np_img)
    except Exception as e:
        result = f"[OCR error: {e}]"
    return result
# --- UI Layout ---
with gr.Blocks(css=".gr-textbox textarea { font-family: monospace; font-size: 16px; }") as demo:
    gr.Markdown("<h1>Real-Time Handwriting OCR Canvas</h1>")
    with gr.Row():
        with gr.Column():
            canvas = gr.ImageEditor(
                label="Draw here (freehand)",
                type="numpy",
                width=600,
                height=400,
                brush=gr.Brush(colors=["#000000"], default_size=3),
            )
            gr.Markdown(
                """
                - Use the canvas tools to draw freehand strokes.
                - You can adjust the brush size and color, or switch to the eraser.
                - OCR runs automatically whenever the drawing changes.
                """
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Output",
                lines=12,
                max_lines=20,
                interactive=False,
            )

    # Trigger OCR on change
    canvas.change(fn=ocr_from_canvas, inputs=canvas, outputs=output_text)

demo.launch()
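For the Space to build, the file above presumably needs an accompanying requirements.txt along these lines (package set inferred from the imports; exact versions are left unpinned):

gradio
torch
transformers
paddleocr
paddlepaddle
opencv-python-headless
pillow
numpy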