File size: 10,447 Bytes
0400c24
631b41e
 
 
 
 
 
 
7f16886
631b41e
 
 
 
 
 
7f16886
0400c24
7f16886
 
0400c24
631b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f16886
 
 
631b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f16886
631b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
affec76
631b41e
 
 
7f16886
631b41e
 
 
7f16886
631b41e
 
 
 
 
 
7f16886
 
 
 
 
631b41e
 
 
 
 
 
7f16886
 
 
 
0400c24
7f16886
 
 
 
631b41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f16886
 
 
 
 
0400c24
7f16886
 
 
631b41e
 
 
 
 
 
 
 
 
1ae4d91
 
 
 
 
 
 
 
 
 
 
 
 
 
affec76
1ae4d91
 
 
 
affec76
 
 
 
1ae4d91
affec76
1ae4d91
 
 
 
476a469
affec76
476a469
 
 
affec76
476a469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae4d91
 
 
476a469
 
 
 
1ae4d91
 
476a469
1ae4d91
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# Import the GPU decorator for ZeroGPU Spaces
from spaces import GPU

import os
import cv2
import numpy as np
import torch
import tempfile
import shutil
import gradio as gr
from PIL import Image
from pdf2image import convert_from_path
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from paddleocr import PaddleOCR
import logging

# Silence log output for a cleaner console.
# NOTE: logging.disable(logging.WARNING) is process-wide, not PaddleOCR-specific:
# it drops DEBUG/INFO/WARNING records from *every* logger in the process; only
# ERROR and CRITICAL records still get through.
logging.disable(logging.WARNING)

# Choose the torch device once at import time: CUDA when torch reports a GPU,
# otherwise CPU. The TrOCR model and its input tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

# --- MODEL LOADING ---
# Load models globally so they are only initialized once when the app starts.
# If a loader fails, the error is printed and its globals are set to None so the
# app still starts; process_image_page raises RuntimeError when they are missing.

# PaddleOCR instance used for text-line detection: downstream code keeps only
# the box polygons from its results (recognition is done by TrOCR).
# use_gpu mirrors the same torch CUDA check used for `device` above.
print("Initializing PaddleOCR text detection model...")
try:
    det_model = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=torch.cuda.is_available(), show_log=False)
except Exception as e:
    print(f"Error initializing PaddleOCR: {e}")
    det_model = None

# TrOCR processor + model (the "large-handwritten" checkpoint) that transcribe
# each cropped text line. eval() puts the model in inference mode before it is
# moved to `device`.
print("Initializing TrOCR text recognition model...")
try:
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
    trocr_model.eval()
    trocr_model.to(device)
except Exception as e:
    print(f"Error initializing TrOCR: {e}")
    # Both are reset together so callers can check either one.
    trocr_model = None
    trocr_processor = None

# Helper function to save a temp image
def save_temp_image(img, suffix='.png'):
    """Write an image array to a new temporary file and return its path.

    The caller owns the returned file and is responsible for deleting it.

    Args:
        img: Image array accepted by ``cv2.imwrite`` (BGR channel order for
            colour images, as used throughout this module).
        suffix: File extension of the temporary file. OpenCV chooses the
            encoder from it. Defaults to ``'.png'`` (lossless).

    Returns:
        Path of the written temporary file.

    Raises:
        OSError: if OpenCV reports that the image could not be written. The
            empty temporary file is removed before raising.
    """
    temp_fd, temp_path = tempfile.mkstemp(suffix=suffix)
    # We only need the unique path: close our descriptor before OpenCV opens
    # the file by name, so it cannot leak if imwrite raises.
    os.close(temp_fd)
    try:
        ok = cv2.imwrite(temp_path, img)
    except Exception:
        os.remove(temp_path)
        raise
    if not ok:
        # imwrite signals failure by returning False rather than raising;
        # don't hand an empty file to the caller.
        os.remove(temp_path)
        raise OSError(f"cv2.imwrite failed to write temporary image {temp_path}")
    return temp_path

def _detect_line_boxes(img):
    """Run PaddleOCR on a BGR page image and return the detected box polygons."""
    # The page is handed to PaddleOCR as a file path, so write it to a temp PNG.
    temp_image_path = save_temp_image(img)
    try:
        ocr_result = det_model.ocr(temp_image_path)
    finally:
        # Remove the temp file even when detection raises.
        os.remove(temp_image_path)

    # ocr_result[0] holds one entry per detected text line of this image; the
    # entry's first element is the box polygon. It may be empty/None when
    # nothing is detected.
    if not ocr_result or not ocr_result[0]:
        return []
    return [line[0] for line in ocr_result[0]]


def _reading_order(boxes, line_tolerance=10):
    """Return indices of ``boxes`` ordered top-to-bottom, then left-to-right.

    Each box is keyed by its topmost y and leftmost x. Boxes whose tops are
    within ``line_tolerance`` pixels of each other count as the same visual
    line and are ordered by their left edge.
    """
    def top_left(box):
        pts = np.asarray(box, dtype=np.float32)
        return float(pts[:, 1].min()), float(pts[:, 0].min())

    keys = [top_left(b) for b in boxes]
    order = sorted(range(len(boxes)), key=lambda i: keys[i])
    # Adjacent-swap pass: a box that starts slightly lower but further left on
    # the same visual line moves ahead of its neighbour.
    for i in range(len(order) - 1):
        j = i
        while j >= 0:
            a, b = order[j], order[j + 1]
            if abs(keys[b][0] - keys[a][0]) < line_tolerance and keys[b][1] < keys[a][1]:
                order[j], order[j + 1] = b, a
                j -= 1
            else:
                break
    return order


def _crop_box(img, box):
    """Perspective-warp the quadrilateral ``box`` of ``img`` to an upright crop.

    The box's points are mapped, in order, to the top-left, top-right,
    bottom-right and bottom-left corners of the output. Returns None for
    degenerate boxes smaller than 2 px in either dimension. Those contain no
    readable text. A zero-sized output would also make OpenCV fall back to the
    full source size, and a 1-px one yields a singular transform.
    """
    quad = np.array(box, dtype=np.float32)

    # Output size: the longer of each pair of opposite edges.
    width = int(max(np.linalg.norm(quad[0] - quad[1]), np.linalg.norm(quad[2] - quad[3])))
    height = int(max(np.linalg.norm(quad[0] - quad[3]), np.linalg.norm(quad[1] - quad[2])))
    if width < 2 or height < 2:
        return None

    dst_rect = np.array([
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1],
        [0, height - 1]
    ], dtype=np.float32)

    M = cv2.getPerspectiveTransform(quad, dst_rect)
    return cv2.warpPerspective(img, M, (width, height))


def _recognize_crop(crop):
    """Return TrOCR's transcription of a single BGR line crop."""
    image_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
    pixel_values = trocr_processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values, max_new_tokens=64)
    return trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_image_page(img):
    """
    Detect text lines in one page image and recognize each with TrOCR.

    Args:
        img: Page image as an OpenCV-style BGR numpy array.

    Returns:
        ``(results, original_pil_image)``. ``results`` is a list of
        ``[box, text]`` pairs in reading order (top-to-bottom, then
        left-to-right), where ``box`` is the polygon as returned by PaddleOCR.
        ``original_pil_image`` is ``img`` converted to an RGB PIL image.
        Degenerate (sub-2-pixel) detections are skipped.

    Raises:
        RuntimeError: if the detection or recognition model failed to load.
    """
    if det_model is None or trocr_model is None:
        raise RuntimeError("OCR models are not loaded. Please check logs for errors.")

    # Convert OpenCV image (BGR numpy array) to PIL Image (RGB)
    original_pil_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    boxes = _detect_line_boxes(img)
    print(f"Detected {len(boxes)} lines in this page.")

    # Order the lines explicitly from their geometry instead of reversing the
    # detector's output. That output order is an implementation detail, and
    # reversing already-sorted results reads the page bottom-to-top.
    ordered_boxes = [boxes[i] for i in _reading_order(boxes)]

    results = []
    for box in ordered_boxes:
        crop = _crop_box(img, box)
        if crop is None:
            continue
        text = _recognize_crop(crop)
        results.append([box, text])
        print(f"Recognized: {text}")

    return results, original_pil_image

def _invert_if_dark(img_cv, threshold=100):
    """Return ``img_cv`` inverted if its mean gray level is below ``threshold``.

    This is a simple heuristic for light text on a dark background. ``img_cv``
    is a BGR image array and is returned unchanged when it is bright enough.
    """
    gray_image = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    if np.mean(gray_image) < threshold:
        print("Inverting colors for dark background.")
        return cv2.bitwise_not(img_cv)
    return img_cv


def _draw_text_lines(c, lines, y, page_height, continuation_title=None):
    """Draw ``lines`` downward from ``y`` at 15 pt spacing, breaking pages as needed.

    A new page is started only when another line still has to be drawn, so no
    trailing blank page is produced. On each continuation page, an optional
    bold ``continuation_title`` is drawn at the top before body text resumes.
    """
    c.setFont("Helvetica", 12)
    for text in lines:
        if y < 50:
            c.showPage()
            if continuation_title is None:
                y = page_height - 50
            else:
                c.setFont("Helvetica-Bold", 14)
                c.drawString(50, page_height - 40, continuation_title)
                y = page_height - 60
            # Body text is always regular 12 pt, even after a bold header.
            c.setFont("Helvetica", 12)
        c.drawString(50, y, text)
        y -= 15


def _pdf_to_report(pdf_path, output_pdf_path):
    """OCR every page of ``pdf_path`` into ``output_pdf_path``.

    Returns the first page as a PIL image for preview, or None if the PDF
    produced no pages.
    """
    print(f"Converting PDF {pdf_path} to images...")
    images = convert_from_path(pdf_path, dpi=300)

    c = canvas.Canvas(output_pdf_path, pagesize=letter)
    _, height = letter
    for page_num, page in enumerate(images, start=1):
        print(f"\nProcessing page {page_num}")
        img_cv = _invert_if_dark(cv2.cvtColor(np.array(page), cv2.COLOR_RGB2BGR))
        results, _ = process_image_page(img_cv)

        c.setFont("Helvetica-Bold", 14)
        c.drawString(50, height - 40, f"Page {page_num} - OCR Results")
        _draw_text_lines(
            c,
            [text for _, text in results],
            height - 60,
            height,
            continuation_title=f"Page {page_num} (cont.) - OCR Results",
        )
        c.showPage()
    c.save()
    return images[0] if images else None


def _image_to_report(image_path, output_pdf_path, work_dir):
    """OCR a single image file into ``output_pdf_path`` with a thumbnail.

    Returns the loaded image (before any inversion) as a PIL image for preview.
    """
    img_cv = cv2.imread(image_path)
    if img_cv is None:
        raise ValueError("Failed to load image.")
    input_image_for_display = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))

    results, original_image = process_image_page(_invert_if_dark(img_cv))

    c = canvas.Canvas(output_pdf_path, pagesize=letter)
    _, height = letter
    c.setFont("Helvetica-Bold", 14)
    c.drawString(50, height - 40, "Image OCR Results")

    temp_img_path = os.path.join(work_dir, "original_image.png")
    original_image.save(temp_img_path)
    # Give drawImage a full bounding box (200 x 240 pt, just below the title).
    # With only ``width`` set, the box height defaults to the image's own
    # height. The aspect-preserving fit is then centred in that much taller
    # box, which pushes large images off the page.
    c.drawImage(temp_img_path, 50, height - 300, width=200, height=240,
                preserveAspectRatio=True, anchor='nw')
    os.remove(temp_img_path)

    _draw_text_lines(c, [text for _, text in results], height - 350, height)
    c.save()
    return input_image_for_display


def process_file_and_create_pdf(file):
    """
    OCR an image or PDF file and write the recognized text to a new PDF.

    Args:
        file: The upload to process. Either an object with a ``.name`` path
            attribute (tempfile-like, as older Gradio ``gr.File`` passes) or a
            plain path string / ``os.PathLike`` (as newer Gradio passes).
            Paths ending in ``.pdf`` (case-insensitive) are rasterized page by
            page; anything else is read as a single image with OpenCV.

    Returns:
        ``(output_pdf_path, preview_image)`` on success. ``preview_image`` is
        the first PDF page or the loaded image, as a PIL image. Returns
        ``(None, None)`` if ``file`` is None or processing fails.

    The output PDF is written to a fresh temporary directory. That directory is
    deliberately NOT deleted on success, because the caller (Gradio) reads the
    file after this function returns; deleting it here would hand back a
    dangling path. The directory is removed only when processing fails.
    """
    if file is None:
        return None, None

    file_path = os.fspath(file if isinstance(file, (str, os.PathLike)) else file.name)

    temp_output_dir = tempfile.mkdtemp()
    output_pdf_path = os.path.join(temp_output_dir, "ocr_results.pdf")

    try:
        if file_path.lower().endswith('.pdf'):
            input_image_for_display = _pdf_to_report(file_path, output_pdf_path)
        else:
            input_image_for_display = _image_to_report(file_path, output_pdf_path, temp_output_dir)

        print(f"Generated PDF path: {output_pdf_path}")
        return output_pdf_path, input_image_for_display

    except Exception as e:
        import traceback

        print(f"An error occurred: {e}")
        traceback.print_exc()
        # Nothing useful was produced, so clean up the scratch directory now.
        print(f"Cleaning up temporary directory: {temp_output_dir}")
        shutil.rmtree(temp_output_dir, ignore_errors=True)
        return None, None

# Gradio Interface
@GPU
def process_file_for_gradio(image_file, pdf_file):
    """
    Gradio entry point: send whichever input was provided to the OCR pipeline.

    The image input takes precedence when both inputs are provided.

    Args:
        image_file: PIL image from the ``gr.Image(type="pil")`` input, or None.
        pdf_file: Value of the ``gr.File`` input, or None. It is passed through
            unchanged to ``process_file_and_create_pdf``.

    Returns:
        ``(output_pdf_path, preview_image)`` as returned by
        ``process_file_and_create_pdf``, or ``(None, None)`` when neither input
        was given.
    """
    if image_file is not None:
        # process_file_and_create_pdf reads the input path from a ``.name``
        # attribute, so wrap the saved image path in a minimal object.
        class MockFile:
            def __init__(self, name):
                self.name = name

        # The pipeline needs a file on disk. The context manager removes the
        # scratch directory even if saving or processing raises. The output
        # PDF lives in its own directory, so returning from inside is safe.
        with tempfile.TemporaryDirectory() as temp_dir:
            image_path = os.path.join(temp_dir, "uploaded_image.png")
            image_file.save(image_path)
            return process_file_and_create_pdf(MockFile(image_path))

    if pdf_file is not None:
        return process_file_and_create_pdf(pdf_file)

    return None, None


# Two optional inputs (image or PDF) map to process_file_for_gradio's two
# parameters; its (pdf_path, preview_image) result fills the two outputs.
demo = gr.Interface(
    fn=process_file_for_gradio,
    inputs=[
        gr.Image(label="Upload an Image", type="pil"),
        gr.File(label="Upload a PDF", file_types=['.pdf'])
    ],
    outputs=[
        gr.File(label="Download OCR Results PDF", interactive=False),
        gr.Image(label="Uploaded File Preview", interactive=False)
    ],
    title="OCR App with PaddleOCR and TrOCR",
    # gr.File output shows a download link; nothing is downloaded automatically.
    description=(
        "Upload an image or a multi-page PDF to get an output PDF with the recognized "
        "text from each page. If both are provided, the image is processed. "
        "Download the resulting PDF from the output panel."
    ),
)

if __name__ == "__main__":
    demo.launch()