imperiusrex committed on
Commit
631b41e
·
verified ·
1 Parent(s): 5506faf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import the GPU decorator for ZeroGPU Spaces
2
+ # This will be a no-op if the space is not configured for ZeroGPU
3
+ # but it is required for the specified hardware to work correctly.
4
+ from spaces import GPU
5
+
6
+ import os
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import tempfile
11
+ import gradio as gr
12
+ from PIL import Image
13
+ from pdf2image import convert_from_path
14
+ from reportlab.lib.pagesizes import letter
15
+ from reportlab.pdfgen import canvas
16
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
17
+ from paddleocr import PaddleOCR, TextDetection
18
+
19
# Choose the torch device once at import time. The `spaces.GPU` decorator
# handles dynamic GPU allocation, but PyTorch and the other GPU-enabled
# libraries still have to be told which device to use.
device = "cpu" if not torch.cuda.is_available() else "cuda"

print(f"Using device: {device}")

# --- MODEL LOADING ---
# Both models are created at module scope so they are initialized a single
# time when the app starts. A failed load is reported on stdout and leaves
# the corresponding global(s) set to None.

# PaddleOCR instance used for text detection. Angle classification is turned
# off (`use_angle_cls=False`) for efficiency, because the detected regions
# are straightened with a perspective warp before recognition.
print("Initializing PaddleOCR text detection model...")
try:
    det_model = PaddleOCR(
        use_angle_cls=False,
        lang='en',
        use_gpu=torch.cuda.is_available(),
        show_log=False,
    )
except Exception as e:
    print(f"Error initializing PaddleOCR: {e}")
    det_model = None

# TrOCR processor and model used for text recognition. The model is moved to
# the selected device and switched to eval mode.
print("Initializing TrOCR text recognition model...")
try:
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
    trocr_model.to(device)
    trocr_model.eval()
except Exception as e:
    print(f"Error initializing TrOCR: {e}")
    trocr_processor = None
    trocr_model = None
51
+
52
+ # Helper function to save a temp image
53
def save_temp_image(img):
    """Save an image array to a new temporary PNG file and return the path.

    Args:
        img: Image array to be written with ``cv2.imwrite``.

    Returns:
        Path of the newly created ``.png`` file. The caller is responsible
        for deleting it.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    temp_fd, temp_path = tempfile.mkstemp(suffix='.png')
    # mkstemp returns an already-open descriptor. Only the path is needed, so
    # close it before writing; that way it cannot leak if the write fails.
    os.close(temp_fd)
    try:
        if not cv2.imwrite(temp_path, img):
            raise OSError(f"Failed to write temporary image to {temp_path}")
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        os.remove(temp_path)
        raise
    return temp_path
59
+
60
def process_image_page(img):
    """
    Detect text regions on one page image, crop them, and recognize the text.

    Args:
        img: Page as an OpenCV image (BGR numpy array).

    Returns:
        A ``(results, original_pil_image)`` tuple. ``results`` is a list of
        ``[box, text]`` pairs, one per usable detected region, in the reverse
        of the order the detector reported them. ``original_pil_image`` is
        the input page converted to an RGB PIL image.

    Raises:
        RuntimeError: If the detection or recognition model is not loaded.
    """
    if det_model is None or trocr_model is None:
        raise RuntimeError("OCR models are not loaded. Please check logs for errors.")

    # Convert OpenCV image (BGR numpy array) to PIL Image (RGB)
    original_pil_image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # The detector is given a file path, so the page goes through a temp file.
    # try/finally guarantees the file is removed even if detection raises.
    temp_image_path = save_temp_image(img)
    try:
        # `ocr` returns detection and recognition results; only the
        # detection polygons are used here.
        ocr_result = det_model.ocr(temp_image_path)
    finally:
        os.remove(temp_image_path)

    # The per-image entry can be None (or the result empty) when no text is
    # detected; treat that as "zero lines" instead of crashing.
    page_lines = ocr_result[0] if ocr_result else None
    boxes = [line[0] for line in page_lines or []]

    print(f"Detected {len(boxes)} lines in this page.")

    # Keep each box next to its crop so the two can never get out of step.
    regions = []
    for box in boxes:
        # NOTE(review): assumes four corner points ordered top-left,
        # top-right, bottom-right, bottom-left - confirm against the detector.
        quad = np.array(box, dtype=np.float32)

        # Compute width and height of the straightened image
        width = int(max(np.linalg.norm(quad[0] - quad[1]),
                        np.linalg.norm(quad[2] - quad[3])))
        height = int(max(np.linalg.norm(quad[0] - quad[3]),
                         np.linalg.norm(quad[1] - quad[2])))

        # A zero-sized target cannot be warped; skip degenerate boxes.
        if width < 1 or height < 1:
            continue

        dst_rect = np.array([
            [0, 0],
            [width - 1, 0],
            [width - 1, height - 1],
            [0, height - 1],
        ], dtype=np.float32)

        # Perspective transform
        transform = cv2.getPerspectiveTransform(quad, dst_rect)
        warped = cv2.warpPerspective(img, transform, (width, height))
        regions.append((box, warped))

    # Results are produced in reverse detection order (as before).
    regions.reverse()

    # Text recognition with TrOCR
    results = []
    for box, crop in regions:
        image_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        pixel_values = trocr_processor(images=image_pil, return_tensors="pt").pixel_values.to(device)

        with torch.no_grad():
            generated_ids = trocr_model.generate(pixel_values, max_new_tokens=64)
        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        results.append([box, generated_text])
        print(f"Recognized: {generated_text}")

    return results, original_pil_image
131
+
132
def _draw_text_lines(pdf, lines, y, page_top_y, draw_page_header=None):
    """Draw text lines downward from ``y``, starting a new page below y=50.

    Args:
        pdf: ReportLab canvas to draw on.
        lines: Iterable of strings, one per output line.
        y: Baseline of the first line.
        page_top_y: Baseline of the first line on every continuation page.
        draw_page_header: Optional zero-argument callable invoked right after
            each page break, before any text is drawn on the new page.
    """
    pdf.setFont("Helvetica", 12)
    for text in lines:
        # Break *before* drawing, so a break is only made when there is
        # another line to place and no empty trailing page is produced.
        if y < 50:
            pdf.showPage()
            if draw_page_header is not None:
                draw_page_header()
            # A new page starts with fresh canvas state, and the header may
            # have switched fonts, so select the body font again.
            pdf.setFont("Helvetica", 12)
            y = page_top_y
        pdf.drawString(50, y, text)
        y -= 15


def process_file_and_create_pdf(file):
    """
    Process an uploaded file (image or PDF) and write the OCR text to a new PDF.

    A ``.pdf`` upload is rasterized at 300 dpi and every page is OCR'd; any
    other file is read as a single image, which is also shown as a small
    preview in the output PDF.

    Args:
        file: Uploaded file, either a path string or an object whose
            ``name`` attribute is the path. ``None`` means nothing was
            uploaded.

    Returns:
        Path to the generated PDF, or ``None`` when no file was given or
        processing failed (the error is printed).
    """
    if file is None:
        # Return a single value like every other path of this function; the
        # caller has exactly one output slot.
        print("Please upload a file.")
        return None

    import shutil

    # Accept both a plain path and a file-like wrapper exposing `.name`.
    file_path = getattr(file, "name", file)

    temp_output_dir = tempfile.mkdtemp()
    output_pdf_path = os.path.join(temp_output_dir, "ocr_results.pdf")

    try:
        _, height = letter

        if file_path.lower().endswith('.pdf'):
            # Convert PDF to images. This assumes poppler is available on
            # the system (it may need to be installed via `packages.txt`).
            print(f"Converting PDF {file_path} to images...")
            images = convert_from_path(file_path, dpi=300)

            c = canvas.Canvas(output_pdf_path, pagesize=letter)

            for page_num, page in enumerate(images):
                print(f"\nProcessing page {page_num + 1}")
                img_cv = cv2.cvtColor(np.array(page), cv2.COLOR_RGB2BGR)
                results, _ = process_image_page(img_cv)

                c.setFont("Helvetica-Bold", 14)
                c.drawString(50, height - 40, f"Page {page_num + 1} - OCR Results")

                def draw_continuation_header(page_label=page_num + 1):
                    c.setFont("Helvetica-Bold", 14)
                    c.drawString(50, height - 40, f"Page {page_label} (cont.) - OCR Results")

                _draw_text_lines(
                    c,
                    [text for _, text in results],
                    y=height - 60,
                    page_top_y=height - 60,
                    draw_page_header=draw_continuation_header,
                )
                c.showPage()
            c.save()

        else:  # Handle single image file
            img_cv = cv2.imread(file_path)
            if img_cv is None:
                raise ValueError("Failed to load image.")

            results, original_image = process_image_page(img_cv)

            c = canvas.Canvas(output_pdf_path, pagesize=letter)
            c.setFont("Helvetica-Bold", 14)
            c.drawString(50, height - 40, "Image OCR Results")

            # The preview is drawn from a copy saved in our own temp
            # directory rather than from the upload's path.
            temp_img_path = os.path.join(temp_output_dir, "original_image.png")
            original_image.save(temp_img_path)

            # Give the preview an explicit 200x240 box between the header and
            # the text. With only a width, the box height falls back to the
            # image's own pixel height and the aspect-preserving placement
            # can push the picture off the page.
            c.drawImage(temp_img_path, 50, height - 300, width=200, height=240,
                        preserveAspectRatio=True)

            _draw_text_lines(
                c,
                [text for _, text in results],
                y=height - 350,
                page_top_y=height - 50,
            )
            c.save()
            os.remove(temp_img_path)

        return output_pdf_path

    except Exception as e:
        print(f"An error occurred: {e}")
        # Clean up the temporary directory on error so failed runs do not
        # accumulate files on disk.
        shutil.rmtree(temp_output_dir, ignore_errors=True)
        return None
216
+
217
+ # Gradio Interface
218
+ # The `@GPU` decorator is used here to ensure this function runs on a GPU.
219
@GPU
def process_file_for_gradio(file):
    """Gradio callback: OCR the uploaded file and return the result PDF path.

    This is the function wired into the interface and the one wrapped by the
    `GPU` decorator. It forwards the upload to `process_file_and_create_pdf`
    and hands back whatever that returns (the output PDF path, or None when
    no PDF was produced).
    """
    return process_file_and_create_pdf(file)
227
+
228
# Gradio UI: a single file upload goes in, a single downloadable PDF comes out.
demo = gr.Interface(
    title="OCR App with PaddleOCR and TrOCR",
    description="Upload an image or a multi-page PDF to get an output PDF with the recognized text from each page.",
    fn=process_file_for_gradio,
    inputs=gr.File(
        file_types=['.png', '.jpg', '.jpeg', '.pdf'],
        label="Upload an Image (PNG, JPG) or a PDF",
    ),
    outputs=gr.File(label="Download OCR Results PDF"),
    # No examples are bundled. To add some, list paths to example files in
    # the repo here, e.g. "example.png" or "example.pdf".
    examples=[],
)

if __name__ == "__main__":
    # For a GPU to be available, the hardware configuration has to be set in
    # the `README.md` file of the Hugging Face Space.
    demo.launch()