# --- Setup ---
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import cv2
from paddleocr import TextDetection
import spaces  # Hugging Face Spaces helper; GPU-bound functions are decorated with @spaces.GPU

# Note: the GPU hardware tier (e.g. H200) is selected in the Space settings, not in code.

# --- Model Load ---
MODEL_HUB_ID = "imperiusrex/Handwritten_model"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
model.to(device)
model.eval()
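# PaddleOCR text detection model (PP-OCRv5) used to locate text regions before TrOCR recognition.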
ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")

# --- Core OCR Function ---
@spaces.GPU  # run this function on the Space's GPU (allocated per call on ZeroGPU)
def recognize_handwritten_text_from_npimg(np_img):
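    """Detect text regions with PaddleOCR, recognize each crop with TrOCR, and return the lines joined by newlines."""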
    pil_img = Image.fromarray(np_img.astype(np.uint8)).convert("RGB")
    image_np = np.array(pil_img)
    detection_results = ocr_det_model.predict(image_np, batch_size=1)

    detected_polys = []
    for res in detection_results:
        polys = res.get('dt_polys', None)
        if polys is not None and len(polys) > 0:
            detected_polys.extend(np.asarray(polys).tolist())

    cropped_images = []
    if detected_polys:
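        # Rectify each detected quadrilateral into an axis-aligned crop via a perspective warp.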
        for box in detected_polys:
            box = np.array(box, dtype=np.float32)
            width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
            height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
            dst_rect = np.array([
                [0, 0],
                [width - 1, 0],
                [width - 1, height - 1],
                [0, height - 1]
            ], dtype=np.float32)
            M = cv2.getPerspectiveTransform(box, dst_rect)
            warped = cv2.warpPerspective(image_np, M, (width, height))
            cropped_images.append(Image.fromarray(warped).convert("RGB"))
        cropped_images.reverse()

    recognized_texts = []
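    # Run TrOCR on each rectified crop; fall back to whole-image OCR when nothing was detected.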
    if cropped_images:
        for crop_img in cropped_images:
            pixel_values = processor(images=crop_img, return_tensors="pt").pixel_values.to(device)
            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_new_tokens=64)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                recognized_texts.append(generated_text)
    else:
        pixel_values = processor(images=pil_img, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values, max_new_tokens=64)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            recognized_texts.append("No text boxes detected. Full image OCR:\n" + generated_text)

    return "\n".join(recognized_texts)


# --- Interface Function ---
def ocr_from_canvas(editor_value):
    # gr.ImageEditor with type="numpy" passes a dict of numpy arrays:
    # {"background": ..., "layers": [...], "composite": ...}; the composite is the flattened drawing.
    if editor_value is None:
        return "Draw something to see OCR output."
    np_img = editor_value.get("composite") if isinstance(editor_value, dict) else np.array(editor_value)
    if np_img is None:
        return "Draw something to see OCR output."
    try:
        result = recognize_handwritten_text_from_npimg(np_img)
    except Exception as e:
        result = f"[OCR error: {e}]"
    return result


# --- UI Layout ---
with gr.Blocks(css=".gr-textbox textarea { font-family: monospace; font-size: 16px; }") as demo:
    gr.Markdown("<h1>πŸ“ Real-Time Handwriting OCR Canvas</h1>")
    
    with gr.Row():
        with gr.Column():
            canvas = gr.ImageEditor(
                label="Draw here (freehand)",
                type="numpy",
                width=600,
                height=400,
                brush=gr.Brush(default_size=3, colors=["#000000"], default_color="#000000"),
                image_mode="RGB",
                # Start from a white canvas so black strokes stay readable after the RGB conversion.
                value=np.full((400, 600, 3), 255, dtype=np.uint8),
            )
            gr.Markdown(
                """
                - Use the editor tools to draw freehand strokes.
                - Adjust the brush size and color from the editor toolbar.
                - OCR runs automatically whenever the canvas content changes.
                """
            )

        with gr.Column():
            output_text = gr.Textbox(
                label="🧠 OCR Output",
                lines=12,
                max_lines=20,
                interactive=False,
            )

    # Trigger OCR on change
    canvas.change(fn=ocr_from_canvas, inputs=canvas, outputs=output_text)

demo.launch()