VanguardAI committed on
Commit 0961ac0 · verified · 1 Parent(s): 4a32a6e

Update app.py

Files changed (1)
  1. app.py +225 -1398
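For context, this commit replaces the dots.ocr layout-analysis pipeline with a plain text-extraction flow built on MBZUAI/AIN (a Qwen2-VL-based model). Below is a minimal, hedged usage sketch of the new entry point defined in the diff that follows; it assumes the rewritten app.py is importable as `app`, that model weights can be downloaded, and that a GPU or the `spaces` ZeroGPU context is available. The file name "sample_page.png" is a hypothetical placeholder.

# Hedged sketch: exercises extract_text_from_image() as defined in the new app.py below.
# Assumes app.py is importable as `app`; nothing here is part of the commit itself.
from PIL import Image
import app  # the module rewritten by this commit

page = Image.open("sample_page.png").convert("RGB")  # hypothetical input image

# With custom_prompt=None, the strict default OCR prompt (app.OCR_PROMPT) is used.
text = app.extract_text_from_image(page, custom_prompt=None, max_new_tokens=2048)
print(text)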
app.py CHANGED
@@ -1,425 +1,126 @@
1
  import spaces
2
- import json
3
- import math
4
- import os
5
- import traceback
6
- from io import BytesIO
7
- from typing import Any, Dict, List, Optional, Tuple
8
- import re
9
-
10
- import fitz # PyMuPDF
11
  import gradio as gr
12
- import requests
13
  import torch
14
- from huggingface_hub import snapshot_download
15
- from PIL import Image, ImageDraw, ImageFont
16
  from qwen_vl_utils import process_vision_info
17
- from transformers import AutoModelForCausalLM, AutoProcessor
18
- import numpy as np
19
-
20
- # Import Arabic text correction module
21
- from arabic_corrector import get_corrector
22
 
23
  # ========================================
24
- # DETERMINISTIC SETTINGS FOR CONSISTENCY
25
  # ========================================
26
- # Set seeds for reproducibility - ensures same image always gives same output
27
- torch.manual_seed(42)
28
- torch.cuda.manual_seed_all(42)
29
- np.random.seed(42)
30
-
31
- # Ensure deterministic behavior in PyTorch operations
32
- torch.backends.cudnn.deterministic = True
33
- torch.backends.cudnn.benchmark = False
34
-
35
- # Constants
36
- MIN_PIXELS = 3136
37
- MAX_PIXELS = 11289600
38
- IMAGE_FACTOR = 28
39
 
40
- # Prompts
41
- prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
42
 
43
- 1. Bbox format: [x1, y1, x2, y2]
44
-
45
- 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
46
-
47
- 3. Text Extraction & Formatting Rules:
48
- - Picture: For the 'Picture' category, the text field should be omitted.
49
- - Formula: Format its text as LaTeX.
50
- - Table: Format its text as HTML.
51
- - All Others (Text, Title, etc.): Format their text as Markdown.
52
-
53
- 4. Constraints:
54
- - The output text must be the original text from the image, with no translation.
55
- - All layout elements must be sorted according to human reading order.
56
-
57
- 5. Final Output: The entire output must be a single JSON object.
58
- """
59
-
60
- # Utility functions
61
- def round_by_factor(number: int, factor: int) -> int:
62
- """Returns the closest integer to 'number' that is divisible by 'factor'."""
63
- return round(number / factor) * factor
64
 
 
 
65
 
66
- def smart_resize(
67
- height: int,
68
- width: int,
69
- factor: int = 28,
70
- min_pixels: int = 3136,
71
- max_pixels: int = 11289600,
72
- ):
73
- """Rescales the image so that the following conditions are met:
74
- 1. Both dimensions (height and width) are divisible by 'factor'.
75
- 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
76
- 3. The aspect ratio of the image is maintained as closely as possible.
77
- """
78
- if max(height, width) / min(height, width) > 200:
79
- raise ValueError(
80
- f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
81
- )
82
- h_bar = max(factor, round_by_factor(height, factor))
83
- w_bar = max(factor, round_by_factor(width, factor))
84
 
85
- if h_bar * w_bar > max_pixels:
86
- beta = math.sqrt((height * width) / max_pixels)
87
- h_bar = round_by_factor(height / beta, factor)
88
- w_bar = round_by_factor(width / beta, factor)
89
- elif h_bar * w_bar < min_pixels:
90
- beta = math.sqrt(min_pixels / (height * width))
91
- h_bar = round_by_factor(height * beta, factor)
92
- w_bar = round_by_factor(width * beta, factor)
93
- return h_bar, w_bar
94
 
95
 
96
- def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
97
- """Fetch and process an image"""
98
- if isinstance(image_input, str):
99
- if image_input.startswith(("http://", "https://")):
100
- response = requests.get(image_input)
101
- image = Image.open(BytesIO(response.content)).convert('RGB')
102
- else:
103
- image = Image.open(image_input).convert('RGB')
104
- elif isinstance(image_input, Image.Image):
105
- image = image_input.convert('RGB')
106
- else:
107
- raise ValueError(f"Invalid image input type: {type(image_input)}")
108
-
109
- if min_pixels is not None or max_pixels is not None:
110
- min_pixels = min_pixels or MIN_PIXELS
111
- max_pixels = max_pixels or MAX_PIXELS
112
- height, width = smart_resize(
113
- image.height,
114
- image.width,
115
- factor=IMAGE_FACTOR,
116
- min_pixels=min_pixels,
117
- max_pixels=max_pixels
118
- )
119
- image = image.resize((width, height), Image.LANCZOS)
120
 
121
- return image
122
-
123
-
124
- def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
125
- """Load images from PDF file"""
126
- images = []
127
- try:
128
- pdf_document = fitz.open(pdf_path)
129
- for page_num in range(len(pdf_document)):
130
- page = pdf_document.load_page(page_num)
131
- # Convert page to image
132
- mat = fitz.Matrix(2.0, 2.0) # Increase resolution
133
- pix = page.get_pixmap(matrix=mat)
134
- img_data = pix.tobytes("ppm")
135
- image = Image.open(BytesIO(img_data)).convert('RGB')
136
- images.append(image)
137
- pdf_document.close()
138
- except Exception as e:
139
- print(f"Error loading PDF: {e}")
140
- return []
141
- return images
142
-
143
-
144
- def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
145
- """Draw layout bounding boxes on image"""
146
- img_copy = image.copy()
147
- draw = ImageDraw.Draw(img_copy)
148
 
149
- # Colors for different categories
150
- colors = {
151
- 'Caption': '#FF6B6B',
152
- 'Footnote': '#4ECDC4',
153
- 'Formula': '#45B7D1',
154
- 'List-item': '#96CEB4',
155
- 'Page-footer': '#FFEAA7',
156
- 'Page-header': '#DDA0DD',
157
- 'Picture': '#FFD93D',
158
- 'Section-header': '#6C5CE7',
159
- 'Table': '#FD79A8',
160
- 'Text': '#74B9FF',
161
- 'Title': '#E17055'
162
- }
163
 
164
  try:
165
- # Load a font
166
- try:
167
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
168
- except Exception:
169
- font = ImageFont.load_default()
170
 
171
- for item in layout_data:
172
- if 'bbox' in item and 'category' in item:
173
- bbox = item['bbox']
174
- category = item['category']
175
- color = colors.get(category, '#000000')
176
-
177
- # Draw rectangle
178
- draw.rectangle(bbox, outline=color, width=2)
179
-
180
- # Draw label
181
- label = category
182
- label_bbox = draw.textbbox((0, 0), label, font=font)
183
- label_width = label_bbox[2] - label_bbox[0]
184
- label_height = label_bbox[3] - label_bbox[1]
185
-
186
- # Position label above the box
187
- label_x = bbox[0]
188
- label_y = max(0, bbox[1] - label_height - 2)
189
-
190
- # Draw background for label
191
- draw.rectangle(
192
- [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
193
- fill=color
194
- )
195
-
196
- # Draw text
197
- draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
198
-
199
- except Exception as e:
200
- print(f"Error drawing layout: {e}")
201
-
202
- return img_copy
203
-
204
-
205
- def is_arabic_text(text: str) -> bool:
206
- """Check if text in headers and paragraphs contains mostly Arabic characters"""
207
- if not text:
208
- return False
209
-
210
- # Extract text from headers and paragraphs only
211
- # Match markdown headers (# ## ###) and regular paragraph text
212
- header_pattern = r'^#{1,6}\s+(.+)$'
213
- paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
214
-
215
- content_text = []
216
-
217
- for line in text.split('\n'):
218
- line = line.strip()
219
- if not line:
220
- continue
221
-
222
- # Check for headers
223
- header_match = re.match(header_pattern, line, re.MULTILINE)
224
- if header_match:
225
- content_text.append(header_match.group(1))
226
- continue
227
-
228
- # Check for paragraph text (exclude lists, tables, code blocks, images)
229
- if re.match(paragraph_pattern, line, re.MULTILINE):
230
- content_text.append(line)
231
-
232
- if not content_text:
233
- return False
234
-
235
- # Join all content text and check for Arabic characters
236
- combined_text = ' '.join(content_text)
237
-
238
- # Arabic Unicode ranges
239
- arabic_chars = 0
240
- total_chars = 0
241
-
242
- for char in combined_text:
243
- if char.isalpha():
244
- total_chars += 1
245
- # Arabic script ranges
246
- if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
247
- arabic_chars += 1
248
-
249
- if total_chars == 0:
250
- return False
251
-
252
- # Consider text as Arabic if more than 50% of alphabetic characters are Arabic
253
- return (arabic_chars / total_chars) > 0.5
254
-
255
-
256
- def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
257
- """Convert layout JSON to markdown format"""
258
- import base64
259
- from io import BytesIO
260
-
261
- markdown_lines = []
262
-
263
- try:
264
- # Sort items by reading order (top to bottom, left to right)
265
- sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
266
 
267
- for item in sorted_items:
268
- category = item.get('category', '')
269
- text = item.get(text_key, '')
270
- bbox = item.get('bbox', [])
271
-
272
- if category == 'Picture':
273
- # Extract image region and embed it
274
- if bbox and len(bbox) == 4:
275
- try:
276
- # Extract the image region
277
- x1, y1, x2, y2 = bbox
278
- # Ensure coordinates are within image bounds
279
- x1, y1 = max(0, int(x1)), max(0, int(y1))
280
- x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
281
-
282
- if x2 > x1 and y2 > y1:
283
- cropped_img = image.crop((x1, y1, x2, y2))
284
-
285
- # Convert to base64 for embedding
286
- buffer = BytesIO()
287
- cropped_img.save(buffer, format='PNG')
288
- img_data = base64.b64encode(buffer.getvalue()).decode()
289
-
290
- # Add as markdown image
291
- markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
292
- else:
293
- markdown_lines.append("![Image](Image region detected)\n")
294
- except Exception as e:
295
- print(f"Error processing image region: {e}")
296
- markdown_lines.append("![Image](Image detected)\n")
297
- else:
298
- markdown_lines.append("![Image](Image detected)\n")
299
- elif not text:
300
- continue
301
- elif category == 'Title':
302
- markdown_lines.append(f"# {text}\n")
303
- elif category == 'Section-header':
304
- markdown_lines.append(f"## {text}\n")
305
- elif category == 'Text':
306
- markdown_lines.append(f"{text}\n")
307
- elif category == 'List-item':
308
- markdown_lines.append(f"- {text}\n")
309
- elif category == 'Table':
310
- # If text is already HTML, keep it as is
311
- if text.strip().startswith('<'):
312
- markdown_lines.append(f"{text}\n")
313
- else:
314
- markdown_lines.append(f"**Table:** {text}\n")
315
- elif category == 'Formula':
316
- # If text is LaTeX, format it properly
317
- if text.strip().startswith('$') or '\\' in text:
318
- markdown_lines.append(f"$$\n{text}\n$$\n")
319
- else:
320
- markdown_lines.append(f"**Formula:** {text}\n")
321
- elif category == 'Caption':
322
- markdown_lines.append(f"*{text}*\n")
323
- elif category == 'Footnote':
324
- markdown_lines.append(f"^{text}^\n")
325
- elif category in ['Page-header', 'Page-footer']:
326
- # Skip headers and footers in main content
327
- continue
328
- else:
329
- markdown_lines.append(f"{text}\n")
330
-
331
- markdown_lines.append("") # Add spacing
332
-
333
  except Exception as e:
334
- print(f"Error converting to markdown: {e}")
335
- return str(layout_data)
336
-
337
- return "\n".join(markdown_lines)
338
-
339
- # Initialize model/processor lazily inside GPU context
340
- model_id = "rednote-hilab/dots.ocr"
341
- model_path = "./models/dots-ocr-local"
342
- model = None
343
- processor = None
344
-
345
- def ensure_model_loaded():
346
- """Lazily download and load model/processor using eager attention (no FlashAttention)."""
347
- global model, processor
348
- if model is not None and processor is not None:
349
- return
350
-
351
- # Always use eager attention
352
- attn_impl = "eager"
353
- # Use GPU if available, otherwise CPU
354
- if torch.cuda.is_available():
355
- dtype = torch.bfloat16 # Use bfloat16 on GPU for consistency
356
- device_map = "auto"
357
- else:
358
- dtype = torch.float32
359
- device_map = "cpu"
360
-
361
- # Download snapshot locally (idempotent)
362
- snapshot_download(
363
- repo_id=model_id,
364
- local_dir=model_path,
365
- local_dir_use_symlinks=False,
366
- )
367
-
368
- # Load model/processor
369
- loaded_model = AutoModelForCausalLM.from_pretrained(
370
- model_path,
371
- attn_implementation=attn_impl,
372
- torch_dtype=dtype,
373
- device_map=device_map,
374
- trust_remote_code=True,
375
- low_cpu_mem_usage=True,
376
- )
377
- loaded_processor = AutoProcessor.from_pretrained(
378
- model_path,
379
- trust_remote_code=True,
380
- )
381
-
382
- model = loaded_model
383
- processor = loaded_processor
384
 
385
- # Global state variables
386
- device = "cuda" if torch.cuda.is_available() else "cpu"
387
 
388
- # PDF handling state
389
- pdf_cache = {
390
- "images": [],
391
- "current_page": 0,
392
- "total_pages": 0,
393
- "file_type": None,
394
- "is_parsed": False,
395
- "results": []
396
- }
397
  @spaces.GPU()
398
- def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
399
- """Run inference on an image with the given prompt"""
400
  try:
 
401
  ensure_model_loaded()
 
402
  if model is None or processor is None:
403
- raise RuntimeError("Model not loaded. Please check model initialization.")
 
 
 
404
 
405
- # Prepare messages in the expected format
406
  messages = [
407
  {
408
  "role": "user",
409
  "content": [
410
  {
411
  "type": "image",
412
- "image": image
413
  },
414
- {"type": "text", "text": prompt}
415
- ]
 
 
 
416
  }
417
  ]
418
 
419
  # Apply chat template
420
  text = processor.apply_chat_template(
421
- messages,
422
- tokenize=False,
423
  add_generation_prompt=True
424
  )
425
 
@@ -435,22 +136,16 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
435
  return_tensors="pt",
436
  )
437
 
438
- # Move to the model's primary device (works with device_map as well)
439
- primary_device = next(model.parameters()).device
440
- inputs = inputs.to(primary_device)
441
-
442
- # Generate output - DETERMINISTIC MODE
443
- # Set seed for complete reproducibility
444
- torch.manual_seed(42)
445
- if torch.cuda.is_available():
446
- torch.cuda.manual_seed_all(42)
447
 
 
448
  with torch.no_grad():
449
  generated_ids = model.generate(
450
- **inputs,
451
  max_new_tokens=max_new_tokens,
452
- do_sample=False, # Greedy decoding for deterministic output
453
- # Remove temperature/top_p/top_k when do_sample=False for consistency
454
  )
455
 
456
  # Decode output
@@ -459,1095 +154,226 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
459
  ]
460
 
461
  output_text = processor.batch_decode(
462
- generated_ids_trimmed,
463
- skip_special_tokens=True,
464
  clean_up_tokenization_spaces=False
465
  )
466
 
467
- return output_text[0] if output_text else ""
468
-
469
- except Exception as e:
470
- print(f"Error during inference: {e}")
471
- traceback.print_exc()
472
- return f"Error during inference: {str(e)}"
473
-
474
-
475
- @spaces.GPU()
476
- def _generate_text_and_confidence_for_crop(
477
- image: Image.Image,
478
- max_new_tokens: int = 128,
479
- ) -> Tuple[str, float]:
480
- """Generate text for a cropped region and compute average per-token confidence from model scores.
481
-
482
- Returns (generated_text, average_confidence_percent).
483
- """
484
- try:
485
- ensure_model_loaded()
486
- # Prepare a concise extraction prompt for the crop
487
- messages = [
488
- {
489
- "role": "user",
490
- "content": [
491
- {"type": "image", "image": image},
492
- {
493
- "type": "text",
494
- "text": "Extract the exact text content from this image region. Output text only without translation or additional words.",
495
- },
496
- ],
497
- }
498
- ]
499
-
500
- # Apply chat template
501
- text = processor.apply_chat_template(
502
- messages, tokenize=False, add_generation_prompt=True
503
- )
504
-
505
- # Process vision information
506
- image_inputs, video_inputs = process_vision_info(messages)
507
-
508
- # Prepare inputs
509
- inputs = processor(
510
- text=[text],
511
- images=image_inputs,
512
- videos=video_inputs,
513
- padding=True,
514
- return_tensors="pt",
515
- )
516
- primary_device = next(model.parameters()).device
517
- inputs = inputs.to(primary_device)
518
-
519
- # Set seed for deterministic output
520
- torch.manual_seed(42)
521
- if torch.cuda.is_available():
522
- torch.cuda.manual_seed_all(42)
523
-
524
- # Generate with scores - DETERMINISTIC MODE
525
- with torch.no_grad():
526
- outputs = model.generate(
527
- **inputs,
528
- max_new_tokens=max_new_tokens,
529
- do_sample=False, # Greedy decoding for deterministic output
530
- output_scores=True,
531
- return_dict_in_generate=True,
532
- )
533
-
534
- sequences = outputs.sequences # [batch, seq_len]
535
- input_len = inputs.input_ids.shape[1]
536
- # Trim input prompt ids to isolate generated tokens
537
- generated_ids = sequences[:, input_len:]
538
- generated_text = processor.batch_decode(
539
- generated_ids,
540
- skip_special_tokens=True,
541
- clean_up_tokenization_spaces=False,
542
- )[0].strip()
543
-
544
- # Compute average probability of chosen tokens
545
- confidences: List[float] = []
546
- for step, step_scores in enumerate(outputs.scores or []):
547
- # step_scores: [batch, vocab]
548
- probs = torch.nn.functional.softmax(step_scores, dim=-1)
549
- # token id chosen at this step
550
- if input_len + step < sequences.shape[1]:
551
- chosen_ids = sequences[:, input_len + step].unsqueeze(-1)
552
- chosen_probs = probs.gather(dim=-1, index=chosen_ids) # [batch, 1]
553
- confidences.append(float(chosen_probs[0, 0].item()))
554
-
555
- avg_conf_percent = (sum(confidences) / len(confidences) * 100.0) if confidences else 0.0
556
- return generated_text, avg_conf_percent
557
- except Exception as e:
558
- print(f"Error generating crop confidence: {e}")
559
- traceback.print_exc()
560
- return "", 0.0
561
-
562
-
563
- def estimate_text_density(image: Image.Image) -> float:
564
- """
565
- Estimate text density in image using pixel analysis.
566
-
567
- Returns value between 0.0 (no text) and 1.0 (very dense text).
568
- """
569
- try:
570
- # Convert to grayscale
571
- img_gray = image.convert('L')
572
- img_array = np.array(img_gray)
573
-
574
- # Apply Otsu's thresholding to isolate text-like regions
575
- # Text regions are typically darker than background
576
- threshold = np.mean(img_array) * 0.7 # Adaptive threshold
577
- text_mask = img_array < threshold
578
-
579
- # Calculate text density
580
- text_pixels = np.sum(text_mask)
581
- total_pixels = img_array.size
582
- density = text_pixels / total_pixels
583
-
584
- return min(density, 1.0)
585
- except Exception as e:
586
- print(f"Warning: Could not estimate text density: {e}")
587
- return 0.1 # Default to low density
588
-
589
-
590
- def should_chunk_image(image: Image.Image) -> Tuple[bool, str]:
591
- """
592
- Intelligently determine if image should be chunked for better accuracy.
593
-
594
- Returns (should_chunk, reason).
595
- """
596
- width, height = image.size
597
- total_pixels = width * height
598
- density = estimate_text_density(image)
599
-
600
- # Criteria for chunking (prioritizing ACCURACY)
601
-
602
- # 1. Very large images (>8MP) - model struggles with layout detection
603
- if total_pixels > 8_000_000:
604
- return True, f"Large image ({total_pixels/1_000_000:.1f}MP) - chunking for better layout detection"
605
-
606
- # 2. Dense text (>25% coverage) in large image - overwhelming for single pass
607
- if density > 0.25 and total_pixels > 4_000_000:
608
- return True, f"Dense text ({density*100:.1f}% coverage) in large image - chunking for accuracy"
609
-
610
- # 3. Very dense text (>40%) regardless of size - likely tables/forms
611
- if density > 0.40:
612
- return True, f"Very dense text ({density*100:.1f}% coverage) - likely structured document, chunking"
613
-
614
- # 4. Extreme aspect ratio - likely scrolled document
615
- aspect_ratio = max(width, height) / min(width, height)
616
- if aspect_ratio > 3.0 and total_pixels > 3_000_000:
617
- return True, f"Extreme aspect ratio ({aspect_ratio:.1f}) - chunking vertically"
618
-
619
- return False, "Image size and density within optimal range"
620
-
621
-
622
- def chunk_image_intelligently(image: Image.Image) -> List[Dict[str, Any]]:
623
- """
624
- Chunk image into optimal pieces for processing.
625
- Uses overlap to prevent text cutting and smart sizing for accuracy.
626
-
627
- Returns list of chunks with metadata.
628
- """
629
- width, height = image.size
630
-
631
- # Determine optimal chunk size based on density and dimensions
632
- density = estimate_text_density(image)
633
-
634
- if density > 0.40:
635
- # Very dense - use smaller chunks for better accuracy
636
- chunk_size = 1600
637
- elif density > 0.25:
638
- # Moderate density
639
- chunk_size = 2048
640
- else:
641
- # Lower density - can use larger chunks
642
- chunk_size = 2800
643
-
644
- overlap = 150 # Generous overlap to prevent text cutting
645
-
646
- chunks = []
647
- chunk_id = 0
648
-
649
- # Calculate grid
650
- y_positions = list(range(0, height, chunk_size - overlap))
651
- if y_positions[-1] + chunk_size < height:
652
- y_positions.append(height - chunk_size)
653
-
654
- x_positions = list(range(0, width, chunk_size - overlap))
655
- if x_positions[-1] + chunk_size < width:
656
- x_positions.append(width - chunk_size)
657
-
658
- for y in y_positions:
659
- for x in x_positions:
660
- x1, y1 = max(0, x), max(0, y)
661
- x2 = min(x1 + chunk_size, width)
662
- y2 = min(y1 + chunk_size, height)
663
-
664
- # Skip if chunk is too small (overlap region)
665
- if (x2 - x1) < chunk_size // 2 or (y2 - y1) < chunk_size // 2:
666
- continue
667
-
668
- chunk_img = image.crop((x1, y1, x2, y2))
669
-
670
- chunks.append({
671
- 'id': chunk_id,
672
- 'image': chunk_img,
673
- 'offset': (x1, y1),
674
- 'bbox': (x1, y1, x2, y2),
675
- 'size': (x2 - x1, y2 - y1)
676
- })
677
- chunk_id += 1
678
-
679
- print(f"📐 Chunked into {len(chunks)} pieces (chunk_size={chunk_size}, overlap={overlap})")
680
- return chunks
681
-
682
-
683
- def merge_chunk_results(chunk_results: List[Dict[str, Any]], original_size: Tuple[int, int]) -> Dict[str, Any]:
684
- """
685
- Intelligently merge results from multiple chunks.
686
- Handles overlapping regions and deduplication.
687
- """
688
- merged_layout = []
689
- seen_regions = set()
690
-
691
- for chunk_result in chunk_results:
692
- offset_x, offset_y = chunk_result['offset']
693
-
694
- for item in chunk_result.get('layout_result', []):
695
- bbox = item.get('bbox', [])
696
- if not bbox or len(bbox) != 4:
697
- continue
698
-
699
- # Adjust bbox to original image coordinates
700
- adjusted_bbox = [
701
- bbox[0] + offset_x,
702
- bbox[1] + offset_y,
703
- bbox[2] + offset_x,
704
- bbox[3] + offset_y
705
- ]
706
-
707
- # Simple deduplication: check if similar region already exists
708
- region_key = (
709
- adjusted_bbox[0] // 50, # Grid-based dedup (50px tolerance)
710
- adjusted_bbox[1] // 50,
711
- adjusted_bbox[2] // 50,
712
- adjusted_bbox[3] // 50,
713
- item.get('category', 'Text')
714
- )
715
-
716
- if region_key in seen_regions:
717
- continue
718
-
719
- seen_regions.add(region_key)
720
-
721
- # Create merged item
722
- merged_item = item.copy()
723
- merged_item['bbox'] = adjusted_bbox
724
- merged_layout.append(merged_item)
725
-
726
- # Sort by reading order (top to bottom, left to right)
727
- merged_layout.sort(key=lambda x: (x.get('bbox', [0, 0])[1], x.get('bbox', [0, 0])[0]))
728
-
729
- # Create merged result
730
- merged_result = {
731
- 'layout_result': merged_layout,
732
- 'is_merged': True,
733
- 'num_chunks': len(chunk_results)
734
- }
735
-
736
- return merged_result
737
-
738
-
739
- def process_image(
740
- image: Image.Image,
741
- min_pixels: Optional[int] = None,
742
- max_pixels: Optional[int] = None,
743
- max_new_tokens: int = 24000,
744
- ) -> Dict[str, Any]:
745
- """
746
- Process a single image with intelligent chunking for accuracy.
747
- Automatically detects dense/large images and chunks them for better results.
748
- """
749
- try:
750
- original_image = image.copy()
751
- original_size = image.size
752
-
753
- # Resize image if needed
754
- if min_pixels is not None or max_pixels is not None:
755
- image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
756
-
757
- # 🎯 INTELLIGENT CHUNKING: Check if image needs chunking for better accuracy
758
- needs_chunking, reason = should_chunk_image(image)
759
-
760
- if needs_chunking:
761
- print(f"🔄 {reason}")
762
- print(f" Processing in chunks for maximum accuracy...")
763
-
764
- # Chunk the image
765
- chunks = chunk_image_intelligently(image)
766
-
767
- # Process each chunk
768
- chunk_results = []
769
- for i, chunk_data in enumerate(chunks):
770
- print(f" Processing chunk {i+1}/{len(chunks)}...")
771
-
772
- chunk_img = chunk_data['image']
773
-
774
- # Process this chunk with full quality
775
- chunk_output = inference(chunk_img, prompt, max_new_tokens=max_new_tokens)
776
-
777
- try:
778
- chunk_layout = json.loads(chunk_output)
779
- chunk_results.append({
780
- 'layout_result': chunk_layout,
781
- 'offset': chunk_data['offset'],
782
- 'bbox': chunk_data['bbox']
783
- })
784
- except json.JSONDecodeError:
785
- print(f" ⚠️ Chunk {i+1} failed to parse, skipping")
786
- continue
787
-
788
- # Merge chunk results intelligently
789
- if chunk_results:
790
- merged = merge_chunk_results(chunk_results, original_size)
791
- layout_data = merged['layout_result']
792
- raw_output = json.dumps(layout_data, ensure_ascii=False)
793
- print(f"✅ Merged {len(chunk_results)} chunks into {len(layout_data)} regions")
794
- else:
795
- print(f"⚠️ All chunks failed, falling back to single-pass")
796
- raw_output = inference(image, prompt, max_new_tokens=max_new_tokens)
797
- else:
798
- print(f"✅ {reason} - processing in single pass")
799
- # Standard single-pass processing
800
- raw_output = inference(image, prompt, max_new_tokens=max_new_tokens)
801
 
802
- # Process results based on prompt mode
803
- result = {
804
- 'original_image': image,
805
- 'raw_output': raw_output,
806
- 'processed_image': image,
807
- 'layout_result': None,
808
- 'markdown_content': None
809
- }
810
-
811
- # Try to parse JSON and create visualizations (since we're doing layout analysis)
812
- try:
813
- # Try to parse JSON output
814
- layout_data = json.loads(raw_output)
815
-
816
- # 🎯 INTELLIGENT CONFIDENCE SCORING
817
- # Count text regions to determine if per-region scoring is feasible
818
- num_text_regions = sum(1 for item in layout_data
819
- if item.get('text') and item.get('category') not in ['Picture'])
820
-
821
- # For dense documents (>15 regions), skip expensive per-region scoring
822
- # This prioritizes speed on dense images while maintaining OCR accuracy
823
- if num_text_regions <= 15:
824
- print(f"📊 Computing per-region confidence for {num_text_regions} regions...")
825
- # Compute per-region confidence using the model on each cropped region
826
- for idx, item in enumerate(layout_data):
827
- try:
828
- bbox = item.get('bbox', [])
829
- text_content = item.get('text', '')
830
- category = item.get('category', '')
831
- if (not text_content) or category == 'Picture' or not bbox or len(bbox) != 4:
832
- continue
833
- x1, y1, x2, y2 = bbox
834
- x1, y1 = max(0, int(x1)), max(0, int(y1))
835
- x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
836
- if x2 <= x1 or y2 <= y1:
837
- continue
838
- crop_img = image.crop((x1, y1, x2, y2))
839
- # Generate and score text for this crop; we only keep the confidence
840
- _, region_conf = _generate_text_and_confidence_for_crop(crop_img)
841
- item['confidence'] = region_conf
842
- except Exception as e:
843
- print(f"Error scoring region {idx}: {e}")
844
- # Leave confidence absent if scoring fails
845
- else:
846
- print(f"⚡ Skipping per-region confidence scoring ({num_text_regions} regions - using fast mode)")
847
- print(f" OCR accuracy maintained, confidence estimated from model output")
848
- # Assign reasonable default confidence based on successful parsing
849
- for item in layout_data:
850
- if item.get('text') and item.get('category') not in ['Picture']:
851
- item['confidence'] = 87.5 # Reasonable estimate for successful OCR
852
-
853
- result['layout_result'] = layout_data
854
-
855
- # Create visualization with bounding boxes
856
- try:
857
- processed_image = draw_layout_on_image(image, layout_data)
858
- result['processed_image'] = processed_image
859
- except Exception as e:
860
- print(f"Error drawing layout: {e}")
861
- result['processed_image'] = image
862
-
863
- # Generate markdown from layout data
864
- try:
865
- markdown_content = layoutjson2md(image, layout_data, text_key='text')
866
- result['markdown_content'] = markdown_content
867
- except Exception as e:
868
- print(f"Error generating markdown: {e}")
869
- result['markdown_content'] = raw_output
870
-
871
- # ✨ ARABIC TEXT CORRECTION: Apply intelligent correction to each text region
872
- try:
873
- print("🔧 Applying Arabic text correction...")
874
- corrector = get_corrector()
875
-
876
- for idx, item in enumerate(layout_data):
877
- text_content = item.get('text', '')
878
- category = item.get('category', '')
879
-
880
- # Only correct text regions (skip pictures, formulas, etc.)
881
- if not text_content or category in ['Picture', 'Formula', 'Table']:
882
- continue
883
-
884
- # Apply correction
885
- correction_result = corrector.correct_text(text_content)
886
-
887
- # Store both original and corrected versions
888
- item['text_original'] = text_content
889
- item['text_corrected'] = correction_result['corrected']
890
- item['correction_confidence'] = correction_result['overall_confidence']
891
- item['corrections_made'] = correction_result['corrections_made']
892
- item['word_corrections'] = correction_result['words']
893
-
894
- # Update the text field to use corrected version
895
- item['text'] = correction_result['corrected']
896
-
897
- # Regenerate markdown with corrected text
898
- corrected_markdown = layoutjson2md(image, layout_data, text_key='text')
899
- result['markdown_content_corrected'] = corrected_markdown
900
- result['markdown_content_original'] = markdown_content
901
-
902
- print(f"✅ Correction complete")
903
-
904
- except Exception as e:
905
- print(f"⚠️ Error during Arabic correction: {e}")
906
- traceback.print_exc()
907
- # Fallback: keep original text
908
- result['markdown_content_corrected'] = markdown_content
909
- result['markdown_content_original'] = markdown_content
910
-
911
- except json.JSONDecodeError:
912
- print("Failed to parse JSON output, using raw output")
913
- result['markdown_content'] = raw_output
914
- result['markdown_content_original'] = raw_output
915
- result['markdown_content_corrected'] = raw_output
916
-
917
- return result
918
 
919
  except Exception as e:
920
- print(f"Error processing image: {e}")
 
921
  traceback.print_exc()
922
- return {
923
- 'original_image': image,
924
- 'raw_output': f"Error processing image: {str(e)}",
925
- 'processed_image': image,
926
- 'layout_result': None,
927
- 'markdown_content': f"Error processing image: {str(e)}"
928
- }
929
-
930
-
931
- def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
932
- """Load file for preview (supports PDF and images)"""
933
- global pdf_cache
934
-
935
- if not file_path or not os.path.exists(file_path):
936
- return None, "No file selected"
937
-
938
- file_ext = os.path.splitext(file_path)[1].lower()
939
-
940
- try:
941
- if file_ext == '.pdf':
942
- # Load PDF pages
943
- images = load_images_from_pdf(file_path)
944
- if not images:
945
- return None, "Failed to load PDF"
946
-
947
- pdf_cache.update({
948
- "images": images,
949
- "current_page": 0,
950
- "total_pages": len(images),
951
- "file_type": "pdf",
952
- "is_parsed": False,
953
- "results": []
954
- })
955
-
956
- return images[0], f"Page 1 / {len(images)}"
957
-
958
- elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
959
- # Load single image
960
- image = Image.open(file_path).convert('RGB')
961
-
962
- pdf_cache.update({
963
- "images": [image],
964
- "current_page": 0,
965
- "total_pages": 1,
966
- "file_type": "image",
967
- "is_parsed": False,
968
- "results": []
969
- })
970
-
971
- return image, "Page 1 / 1"
972
- else:
973
- return None, f"Unsupported file format: {file_ext}"
974
-
975
- except Exception as e:
976
- print(f"Error loading file: {e}")
977
- return None, f"Error loading file: {str(e)}"
978
-
979
-
980
- def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, List, Any, Optional[Image.Image], Optional[Dict]]:
981
- """Navigate through PDF pages and update all relevant outputs."""
982
- global pdf_cache
983
-
984
- if not pdf_cache["images"]:
985
- return None, '<div class="page-info">No file loaded</div>', [], "No results yet", None, None
986
-
987
- if direction == "prev":
988
- pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
989
- elif direction == "next":
990
- pdf_cache["current_page"] = min(
991
- pdf_cache["total_pages"] - 1,
992
- pdf_cache["current_page"] + 1
993
- )
994
-
995
- index = pdf_cache["current_page"]
996
- current_image_preview = pdf_cache["images"][index]
997
- page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
998
-
999
- # Initialize default result values
1000
- markdown_content = "Page not processed yet"
1001
- processed_img = None
1002
- layout_json = None
1003
- ocr_table_data = []
1004
-
1005
- # Get results for current page if available
1006
- if (pdf_cache["is_parsed"] and
1007
- index < len(pdf_cache["results"]) and
1008
- pdf_cache["results"][index]):
1009
-
1010
- result = pdf_cache["results"][index]
1011
- markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
1012
- processed_img = result.get('processed_image', None) # Get the processed image
1013
- layout_json = result.get('layout_result', None) # Get the layout JSON
1014
-
1015
- # Generate OCR table for current page
1016
- if layout_json and result.get('original_image'):
1017
- # Need to import the helper here or move it outside
1018
- import base64
1019
- from io import BytesIO
1020
-
1021
- for idx, item in enumerate(layout_json):
1022
- bbox = item.get('bbox', [])
1023
- text = item.get('text', '')
1024
- category = item.get('category', '')
1025
-
1026
- if not text or category == 'Picture':
1027
- continue
1028
-
1029
- img_html = ""
1030
- if bbox and len(bbox) == 4:
1031
- try:
1032
- x1, y1, x2, y2 = bbox
1033
- orig_img = result['original_image']
1034
- x1, y1 = max(0, int(x1)), max(0, int(y1))
1035
- x2, y2 = min(orig_img.width, int(x2)), min(orig_img.height, int(y2))
1036
-
1037
- if x2 > x1 and y2 > y1:
1038
- cropped_img = orig_img.crop((x1, y1, x2, y2))
1039
- buffer = BytesIO()
1040
- cropped_img.save(buffer, format='PNG')
1041
- img_data = base64.b64encode(buffer.getvalue()).decode()
1042
- img_html = f'<img src="data:image/png;base64,{img_data}" style="max-width:200px; max-height:100px; object-fit:contain;" />'
1043
- except Exception as e:
1044
- print(f"Error cropping region {idx}: {e}")
1045
- img_html = f"<div>Region {idx+1}</div>"
1046
- else:
1047
- img_html = f"<div>Region {idx+1}</div>"
1048
-
1049
- # Extract confidence from item if available, otherwise N/A
1050
- confidence = item.get('confidence', 'N/A')
1051
- if isinstance(confidence, (int, float)):
1052
- confidence = f"{confidence:.1f}%"
1053
- elif confidence != 'N/A':
1054
- confidence = str(confidence)
1055
-
1056
- ocr_table_data.append([img_html, text, confidence])
1057
-
1058
- # Check for Arabic text to set RTL property
1059
- if is_arabic_text(markdown_content):
1060
- markdown_update = gr.update(value=markdown_content, rtl=True)
1061
- else:
1062
- markdown_update = markdown_content
1063
-
1064
- return current_image_preview, page_info_html, ocr_table_data, markdown_update, processed_img, layout_json
1065
 
1066
 
1067
  def create_gradio_interface():
1068
- """Create the Gradio interface"""
1069
 
1070
  # Custom CSS
1071
  css = """
1072
  .main-container {
1073
- max-width: 1400px;
1074
  margin: 0 auto;
1075
  }
1076
 
1077
  .header-text {
1078
  text-align: center;
1079
  color: #2c3e50;
1080
- margin-bottom: 20px;
1081
  }
1082
 
1083
  .process-button {
 
1084
  border: none !important;
1085
  color: white !important;
1086
  font-weight: bold !important;
 
 
1087
  }
1088
 
1089
  .process-button:hover {
1090
  transform: translateY(-2px) !important;
1091
- box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
1092
- }
1093
-
1094
- .info-box {
1095
- border: 1px solid #dee2e6;
1096
- border-radius: 8px;
1097
- padding: 15px;
1098
- margin: 10px 0;
1099
- }
1100
-
1101
- .page-info {
1102
- text-align: center;
1103
- padding: 8px 16px;
1104
- border-radius: 20px;
1105
- font-weight: bold;
1106
- margin: 10px 0;
1107
  }
1108
 
1109
- .model-status {
1110
- padding: 10px;
 
1111
  border-radius: 8px;
1112
- margin: 10px 0;
1113
- text-align: center;
1114
- font-weight: bold;
1115
- }
1116
-
1117
- .status-ready {
1118
- background: #d1edff;
1119
- color: #0c5460;
1120
- border: 1px solid #b8daff;
1121
- }
1122
-
1123
- /* Arabic Correction Styling */
1124
- .original-text-box {
1125
- background: #fff5f5 !important;
1126
- border: 2px solid #fc8181 !important;
1127
- border-radius: 8px;
1128
- padding: 15px;
1129
  min-height: 300px;
1130
- direction: rtl;
 
 
1131
  }
1132
 
1133
- .corrected-text-box {
1134
- background: #f0fff4 !important;
1135
- border: 2px solid #68d391 !important;
1136
- border-radius: 8px;
1137
  padding: 15px;
1138
- min-height: 300px;
1139
- direction: rtl;
1140
- }
1141
-
1142
- .correction-high {
1143
- background: #c6f6d5;
1144
- padding: 2px 4px;
1145
- border-radius: 3px;
1146
- }
1147
-
1148
- .correction-medium {
1149
- background: #fef5e7;
1150
- padding: 2px 4px;
1151
- border-radius: 3px;
1152
- }
1153
-
1154
- .correction-low {
1155
- background: #ffe0e0;
1156
- padding: 2px 4px;
1157
- border-radius: 3px;
1158
  }
1159
  """
1160
 
1161
- with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Arabic OCR - Document Text Extraction") as demo:
1162
 
1163
  # Header
1164
  gr.HTML("""
1165
- <div class="title" style="text-align: center">
1166
- <h1>🔍 Arabic OCR - Professional Document Text Extraction</h1>
1167
- <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
1168
- Advanced AI-powered OCR solution for Arabic documents with high accuracy layout detection and text extraction
1169
  </p>
 
1170
  </div>
1171
  """)
1172
 
1173
  # Main interface
1174
  with gr.Row():
1175
- # Left column - Input and controls
1176
  with gr.Column(scale=1):
1177
-
1178
- # File input
1179
- file_input = gr.File(
1180
- label="Upload Image or PDF",
1181
- file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
1182
- type="filepath"
1183
- )
1184
-
1185
- # Image preview
1186
- image_preview = gr.Image(
1187
- label="Preview",
1188
  type="pil",
1189
- interactive=False,
1190
- height=300
1191
  )
1192
 
1193
- # Page navigation for PDFs
1194
- with gr.Row():
1195
- prev_page_btn = gr.Button("◀ Previous", size="md")
1196
- page_info = gr.HTML('<div class="page-info">No file loaded</div>')
1197
- next_page_btn = gr.Button("Next ▶", size="md")
1198
-
1199
  # Advanced settings
1200
- with gr.Accordion("Advanced Settings", open=False):
1201
- max_new_tokens = gr.Slider(
1202
- minimum=1000,
1203
- maximum=32000,
1204
- value=24000,
1205
- step=1000,
1206
- label="Max New Tokens",
1207
- info="Maximum number of tokens to generate"
1208
  )
1209
 
1210
- min_pixels = gr.Number(
1211
- value=MIN_PIXELS,
1212
- label="Min Pixels",
1213
- info="Minimum image resolution"
 
 
 
1214
  )
1215
 
1216
- max_pixels = gr.Number(
1217
- value=MAX_PIXELS,
1218
- label="Max Pixels",
1219
- info="Maximum image resolution"
1220
- )
1221
 
1222
  # Process button
1223
  process_btn = gr.Button(
1224
- "🚀 Process Document",
1225
  variant="primary",
1226
  elem_classes=["process-button"],
1227
  size="lg"
1228
  )
1229
 
1230
  # Clear button
1231
- clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
1232
 
1233
- # Right column - Results
1234
- with gr.Column(scale=2):
1235
 
1236
- # Results tabs
1237
- with gr.Tabs():
1238
- # Processed image tab
1239
- with gr.Tab("🖼️ Processed Image"):
1240
- processed_image = gr.Image(
1241
- label="Image with Layout Detection",
1242
- type="pil",
1243
- interactive=False,
1244
- height=500
1245
- )
1246
- # ✨ NEW: Arabic Text Correction Comparison Tab
1247
- with gr.Tab("✨ Corrected Text (AI)"):
1248
- gr.Markdown("""
1249
- ### 🔧 AI-Powered Arabic Text Correction
1250
- This tab shows **Original OCR** vs **AI-Corrected** text side-by-side.
1251
- Corrections use dictionary matching, context analysis, and linguistic intelligence.
1252
- """)
1253
-
1254
- with gr.Row():
1255
- with gr.Column():
1256
- gr.Markdown("#### 📄 Original OCR Output")
1257
- original_text_output = gr.Markdown(
1258
- value="Original text will appear here...",
1259
- elem_classes=["original-text-box"]
1260
- )
1261
- with gr.Column():
1262
- gr.Markdown("#### ✅ Corrected Text")
1263
- corrected_text_output = gr.Markdown(
1264
- value="Corrected text will appear here...",
1265
- elem_classes=["corrected-text-box"]
1266
- )
1267
-
1268
- correction_stats = gr.Markdown(value="")
1269
-
1270
- # Editable OCR Results Table
1271
- with gr.Tab("📊 OCR Results Table"):
1272
- gr.Markdown("### Editable OCR Results\nReview and edit the extracted text for each detected region")
1273
- ocr_table = gr.Dataframe(
1274
- headers=["Region Image", "Extracted Text", "Confidence"],
1275
- datatype=["html", "str", "str"],
1276
- label="OCR Results",
1277
- interactive=True,
1278
- wrap=True
1279
- )
1280
- # Markdown output tab
1281
- with gr.Tab("📝 Extracted Content"):
1282
- markdown_output = gr.Markdown(
1283
- value="Click 'Process Document' to see extracted content...",
1284
- height=500
1285
- )
1286
- # JSON layout tab
1287
- with gr.Tab("📋 Layout JSON"):
1288
- json_output = gr.JSON(
1289
- label="Layout Analysis Results",
1290
- value=None
1291
- )
1292
 
1293
- # Helper function to create OCR table
1294
- def create_ocr_table(image: Image.Image, layout_data: List[Dict]) -> List[List[str]]:
1295
- """Create table data from layout results with cropped images"""
1296
- import base64
1297
- from io import BytesIO
1298
-
1299
- if not layout_data:
1300
- return []
1301
-
1302
- table_data = []
1303
-
1304
- for idx, item in enumerate(layout_data):
1305
- bbox = item.get('bbox', [])
1306
- text = item.get('text', '')
1307
- category = item.get('category', '')
1308
-
1309
- # Skip items without text or Picture category
1310
- if not text or category == 'Picture':
1311
- continue
1312
-
1313
- # Crop the image region
1314
- img_html = ""
1315
- if bbox and len(bbox) == 4:
1316
- try:
1317
- x1, y1, x2, y2 = bbox
1318
- # Ensure coordinates are within image bounds
1319
- x1, y1 = max(0, int(x1)), max(0, int(y1))
1320
- x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
1321
-
1322
- if x2 > x1 and y2 > y1:
1323
- cropped_img = image.crop((x1, y1, x2, y2))
1324
-
1325
- # Convert to base64 for HTML display
1326
- buffer = BytesIO()
1327
- cropped_img.save(buffer, format='PNG')
1328
- img_data = base64.b64encode(buffer.getvalue()).decode()
1329
-
1330
- # Create HTML img tag
1331
- img_html = f'<img src="data:image/png;base64,{img_data}" style="max-width:200px; max-height:100px; object-fit:contain;" />'
1332
- except Exception as e:
1333
- print(f"Error cropping region {idx}: {e}")
1334
- img_html = f"<div>Region {idx+1}</div>"
1335
- else:
1336
- img_html = f"<div>Region {idx+1}</div>"
1337
-
1338
- # Add confidence score - extract from item if available, otherwise N/A
1339
- confidence = item.get('confidence', 'N/A')
1340
- if isinstance(confidence, (int, float)):
1341
- confidence = f"{confidence:.1f}%"
1342
- elif confidence != 'N/A':
1343
- confidence = str(confidence)
1344
-
1345
- # Add row to table
1346
- table_data.append([img_html, text, confidence])
1347
-
1348
- return table_data
1349
 
1350
  # Event handlers
1351
- @spaces.GPU()
1352
- def process_document(file_path, max_tokens, min_pix, max_pix):
1353
- """Process the uploaded document"""
1354
- global pdf_cache
1355
 
1356
  try:
1357
- # Ensure model/processor are loaded within GPU context
1358
- ensure_model_loaded()
1359
- if not file_path:
1360
- return None, [], "Please upload a file first.", None
 
 
1361
 
1362
- if model is None:
1363
- return None, [], "Model not loaded. Please refresh the page and try again.", None
 
 
1364
 
1365
- # Load and preview file
1366
- image, page_info = load_file_for_preview(file_path)
1367
- if image is None:
1368
- return None, [], page_info, None
1369
 
1370
- # Process the image(s)
1371
- if pdf_cache["file_type"] == "pdf":
1372
- # Process all pages for PDF
1373
- all_results = []
1374
- all_markdown = []
1375
-
1376
- for i, img in enumerate(pdf_cache["images"]):
1377
- result = process_image(
1378
- img,
1379
- min_pixels=int(min_pix) if min_pix else None,
1380
- max_pixels=int(max_pix) if max_pix else None,
1381
- max_new_tokens=int(max_tokens) if max_tokens else 24000,
1382
- )
1383
- all_results.append(result)
1384
- if result.get('markdown_content'):
1385
- all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
1386
-
1387
- pdf_cache["results"] = all_results
1388
- pdf_cache["is_parsed"] = True
1389
-
1390
- # Show results for first page
1391
- first_result = all_results[0]
1392
- combined_markdown = "\n\n---\n\n".join(all_markdown)
1393
-
1394
- # Check if the combined markdown contains mostly Arabic text
1395
- if is_arabic_text(combined_markdown):
1396
- markdown_update = gr.update(value=combined_markdown, rtl=True)
1397
- else:
1398
- markdown_update = combined_markdown
1399
-
1400
- # Create OCR table for first page
1401
- ocr_table_data = []
1402
- if first_result['layout_result']:
1403
- ocr_table_data = create_ocr_table(
1404
- first_result['original_image'],
1405
- first_result['layout_result']
1406
- )
1407
-
1408
- # Prepare correction comparison
1409
- original_text = first_result.get('markdown_content_original', first_result.get('markdown_content', ''))
1410
- corrected_text = first_result.get('markdown_content_corrected', first_result.get('markdown_content', ''))
1411
-
1412
- # Calculate correction statistics
1413
- total_corrections = 0
1414
- if first_result.get('layout_result'):
1415
- for item in first_result['layout_result']:
1416
- total_corrections += item.get('corrections_made', 0)
1417
-
1418
- stats_text = f"### 📊 Correction Statistics\n- **Corrections Made**: {total_corrections}\n- **Method**: Dictionary + Context Analysis"
1419
-
1420
- return (
1421
- first_result['processed_image'],
1422
- original_text if is_arabic_text(original_text) else gr.update(value=original_text, rtl=False),
1423
- corrected_text if is_arabic_text(corrected_text) else gr.update(value=corrected_text, rtl=False),
1424
- stats_text,
1425
- ocr_table_data,
1426
- markdown_update,
1427
- first_result['layout_result']
1428
- )
1429
- else:
1430
- # Process single image
1431
- result = process_image(
1432
- image,
1433
- min_pixels=int(min_pix) if min_pix else None,
1434
- max_pixels=int(max_pix) if max_pix else None,
1435
- max_new_tokens=int(max_tokens) if max_tokens else 24000,
1436
- )
1437
-
1438
- pdf_cache["results"] = [result]
1439
- pdf_cache["is_parsed"] = True
1440
-
1441
- # Check if the content contains mostly Arabic text
1442
- content = result['markdown_content'] or "No content extracted"
1443
- if is_arabic_text(content):
1444
- markdown_update = gr.update(value=content, rtl=True)
1445
- else:
1446
- markdown_update = content
1447
-
1448
- # Create OCR table
1449
- ocr_table_data = []
1450
- if result['layout_result']:
1451
- ocr_table_data = create_ocr_table(
1452
- result['original_image'],
1453
- result['layout_result']
1454
- )
1455
-
1456
- # Prepare correction comparison
1457
- original_text = result.get('markdown_content_original', result.get('markdown_content', ''))
1458
- corrected_text = result.get('markdown_content_corrected', result.get('markdown_content', ''))
1459
-
1460
- # Calculate correction statistics
1461
- total_corrections = 0
1462
- if result.get('layout_result'):
1463
- for item in result['layout_result']:
1464
- total_corrections += item.get('corrections_made', 0)
1465
-
1466
- stats_text = f"### 📊 Correction Statistics\n- **Corrections Made**: {total_corrections}\n- **Method**: Dictionary + Context Analysis"
1467
-
1468
- return (
1469
- result['processed_image'],
1470
- original_text if is_arabic_text(original_text) else gr.update(value=original_text, rtl=False),
1471
- corrected_text if is_arabic_text(corrected_text) else gr.update(value=corrected_text, rtl=False),
1472
- stats_text,
1473
- ocr_table_data,
1474
- markdown_update,
1475
- result['layout_result']
1476
- )
1477
-
1478
  except Exception as e:
1479
- error_msg = f"Error processing document: {str(e)}"
1480
- print(error_msg)
1481
- traceback.print_exc()
1482
- return None, "Error", "Error", "Error occurred", [], error_msg, None
1483
-
1484
- def handle_file_upload(file_path):
1485
- """Handle file upload and show preview"""
1486
- if not file_path:
1487
- return None, "No file loaded"
1488
-
1489
- image, page_info = load_file_for_preview(file_path)
1490
- return image, page_info
1491
-
1492
- def handle_page_turn(direction):
1493
- """Handle page navigation"""
1494
- image, page_info, result = turn_page(direction)
1495
- return image, page_info, result
1496
-
1497
- def clear_all():
1498
- """Clear all data and reset interface"""
1499
- global pdf_cache
1500
-
1501
- pdf_cache = {
1502
- "images": [], "current_page": 0, "total_pages": 0,
1503
- "file_type": None, "is_parsed": False, "results": []
1504
- }
1505
-
1506
- return (
1507
- None, # file_input
1508
- None, # image_preview
1509
- '<div class="page-info">No file loaded</div>', # page_info
1510
- None, # processed_image
1511
- "Original text will appear here...", # original_text_output
1512
- "Corrected text will appear here...", # corrected_text_output
1513
- "", # correction_stats
1514
- [], # ocr_table
1515
- "Click 'Process Document' to see extracted content...", # markdown_output
1516
- None, # json_output
1517
- )
1518
 
1519
- # Wire up event handlers
1520
- file_input.change(
1521
- handle_file_upload,
1522
- inputs=[file_input],
1523
- outputs=[image_preview, page_info]
1524
- )
1525
 
1526
- # The outputs list is now updated to include all components that need to change
1527
- prev_page_btn.click(
1528
- lambda: turn_page("prev"),
1529
- outputs=[image_preview, page_info, ocr_table, markdown_output, processed_image, json_output]
1530
- )
1531
-
1532
- next_page_btn.click(
1533
- lambda: turn_page("next"),
1534
- outputs=[image_preview, page_info, ocr_table, markdown_output, processed_image, json_output]
1535
- )
1536
 
 
1537
  process_btn.click(
1538
- process_document,
1539
- inputs=[file_input, max_new_tokens, min_pixels, max_pixels],
1540
- outputs=[processed_image, original_text_output, corrected_text_output, correction_stats, ocr_table, markdown_output, json_output]
1541
  )
1542
 
1543
- # The outputs list for the clear button is now correct
1544
  clear_btn.click(
1545
- clear_all,
1546
- outputs=[
1547
- file_input, image_preview, page_info, processed_image,
1548
- original_text_output, corrected_text_output, correction_stats,
1549
- ocr_table, markdown_output, json_output
1550
- ]
 
 
1551
  )
1552
 
1553
  return demo
@@ -1563,3 +389,4 @@ if __name__ == "__main__":
1563
  debug=True,
1564
  show_error=True
1565
  )
 
 
1
  import spaces
2
  import gradio as gr
 
3
  import torch
4
+ from PIL import Image
 
5
  from qwen_vl_utils import process_vision_info
6
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
7
+ import traceback
 
 
 
8
 
9
  # ========================================
10
+ # AIN VLM MODEL FOR OCR
11
  # ========================================
12
 
13
+ # Model configuration
14
+ MODEL_ID = "MBZUAI/AIN"
15
 
16
+ # Global model and processor
17
+ model = None
18
+ processor = None
19
 
20
+ # Strict OCR-focused prompt
21
+ OCR_PROMPT = """Extract all text from this image exactly as it appears.
22
 
23
+ Requirements:
24
+ 1. Extract ONLY the text content - do not describe, analyze, or interpret the image
25
+ 2. Maintain the original text structure, layout, and formatting
26
+ 3. Preserve line breaks, paragraphs, and spacing as they appear
27
+ 4. Do not translate the text - keep it in its original language
28
+ 5. Do not add any explanations, descriptions, or additional commentary
29
+ 6. If there are tables, maintain their structure
30
+ 7. If there are headers, titles, or sections, preserve their hierarchy
31
 
32
+ Output only the extracted text, nothing else."""
33
 
34
 
35
+ def ensure_model_loaded():
36
+ """Lazily load the AIN VLM model and processor."""
37
+ global model, processor
38
 
39
+ if model is not None and processor is not None:
40
+ return
41
 
42
+ print("🔄 Loading AIN VLM model...")
43
 
44
  try:
45
+ # Determine device and dtype
46
+ if torch.cuda.is_available():
47
+ device_map = "auto"
48
+ torch_dtype = "auto"
49
+ print("✅ Using GPU (CUDA)")
50
+ else:
51
+ device_map = "cpu"
52
+ torch_dtype = torch.float32
53
+ print("✅ Using CPU")
54
+
55
+ # Load model
56
+ loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(
57
+ MODEL_ID,
58
+ torch_dtype=torch_dtype,
59
+ device_map=device_map,
60
+ trust_remote_code=True,
61
+ )
62
 
63
+ # Load processor
64
+ loaded_processor = AutoProcessor.from_pretrained(
65
+ MODEL_ID,
66
+ trust_remote_code=True,
67
+ )
68
+
69
+ model = loaded_model
70
+ processor = loaded_processor
71
+
72
+ print("✅ Model loaded successfully!")
73
 
74
  except Exception as e:
75
+ print(f"Error loading model: {e}")
76
+ traceback.print_exc()
77
+ raise
78
 
 
 
79
 
80
  @spaces.GPU()
81
+ def extract_text_from_image(image: Image.Image, custom_prompt: str = None, max_new_tokens: int = 2048) -> str:
82
+ """
83
+ Extract text from image using AIN VLM model.
84
+
85
+ Args:
86
+ image: PIL Image to process
87
+ custom_prompt: Optional custom prompt (uses default OCR prompt if None)
88
+ max_new_tokens: Maximum tokens to generate
89
+
90
+ Returns:
91
+ Extracted text as string
92
+ """
93
  try:
94
+ # Ensure model is loaded
95
  ensure_model_loaded()
96
+
97
  if model is None or processor is None:
98
+ return "❌ Error: Model not loaded. Please refresh and try again."
99
+
100
+ # Use custom prompt or default OCR prompt
101
+ prompt_to_use = custom_prompt if custom_prompt and custom_prompt.strip() else OCR_PROMPT
102
 
103
+ # Prepare messages in the format expected by the model
104
  messages = [
105
  {
106
  "role": "user",
107
  "content": [
108
  {
109
  "type": "image",
110
+ "image": image,
111
  },
112
+ {
113
+ "type": "text",
114
+ "text": prompt_to_use
115
+ },
116
+ ],
117
  }
118
  ]
119
 
120
  # Apply chat template
121
  text = processor.apply_chat_template(
122
+ messages,
123
+ tokenize=False,
124
  add_generation_prompt=True
125
  )
126
 
 
136
  return_tensors="pt",
137
  )
138
 
139
+ # Move to device
140
+ device = next(model.parameters()).device
141
+ inputs = inputs.to(device)
142
 
143
+ # Generate output
144
  with torch.no_grad():
145
  generated_ids = model.generate(
146
+ **inputs,
147
  max_new_tokens=max_new_tokens,
148
+ do_sample=False, # Greedy decoding for consistency
 
149
  )
150
 
151
  # Decode output
 
154
  ]
155
 
156
  output_text = processor.batch_decode(
157
+ generated_ids_trimmed,
158
+ skip_special_tokens=True,
159
  clean_up_tokenization_spaces=False
160
  )
161
 
162
+ result = output_text[0] if output_text else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ return result.strip() if result else "No text extracted"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  except Exception as e:
167
+ error_msg = f"Error during text extraction: {str(e)}"
168
+ print(error_msg)
169
  traceback.print_exc()
170
+ return error_msg
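Two runs of unchanged lines inside `extract_text_from_image` are collapsed by the diff: the preprocessing that builds `inputs` after `apply_chat_template` (only its trailing `return_tensors="pt"` is visible) and the prompt-token trimming that produces `generated_ids_trimmed`. The sketch below shows the standard Qwen2-VL pattern those hidden lines most likely follow; it is an assumption for orientation, not the verbatim code from the file.

```python
# Assumed shape of the collapsed preprocessing (standard Qwen2-VL usage).
image_inputs, video_inputs = process_vision_info(messages)  # from qwen_vl_utils
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)

# Assumed shape of the collapsed trimming step before batch_decode:
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
```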
173
  def create_gradio_interface():
174
+ """Create the Gradio interface for AIN OCR."""
175
 
176
  # Custom CSS
177
  css = """
178
  .main-container {
179
+ max-width: 1200px;
180
  margin: 0 auto;
181
  }
182
 
183
  .header-text {
184
  text-align: center;
185
  color: #2c3e50;
186
+ margin-bottom: 30px;
187
  }
188
 
189
  .process-button {
190
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
191
  border: none !important;
192
  color: white !important;
193
  font-weight: bold !important;
194
+ font-size: 1.1em !important;
195
+ padding: 12px 24px !important;
196
  }
197
 
198
  .process-button:hover {
199
  transform: translateY(-2px) !important;
200
+ box-shadow: 0 6px 12px rgba(0,0,0,0.2) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  }
202
 
203
+ .output-text {
204
+ background: #f8f9fa;
205
+ border: 2px solid #dee2e6;
206
  border-radius: 8px;
207
+ padding: 20px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  min-height: 300px;
209
+ font-family: 'Courier New', monospace;
210
+ white-space: pre-wrap;
211
+ direction: auto;
212
  }
213
 
214
+ .info-box {
215
+ background: #e3f2fd;
216
+ border-left: 4px solid #2196f3;
 
217
  padding: 15px;
218
+ margin: 10px 0;
219
+ border-radius: 4px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  }
221
  """
222
 
223
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, title="AIN VLM OCR") as demo:
224
 
225
  # Header
226
  gr.HTML("""
227
+ <div class="header-text">
228
+ <h1>🔍 AIN VLM - Vision Language Model OCR</h1>
229
+ <p style="font-size: 1.1em; color: #6b7280; margin-top: 10px;">
230
+ Advanced OCR using Vision Language Model (VLM) for accurate text extraction
231
  </p>
232
+ <p style="font-size: 0.95em; color: #9ca3af; margin-top: 8px;">
233
+ Powered by <strong>MBZUAI/AIN</strong> - Specialized for understanding and extracting text from images
234
+ </p>
235
+ </div>
236
+ """)
237
+
238
+ # Info box
239
+ gr.Markdown("""
240
+ <div class="info-box">
241
+ <strong>ℹ️ How it works:</strong> Upload an image containing text, click "Process Image", and get the extracted text.
242
+ The VLM model intelligently understands context and can handle handwritten text better than traditional OCR models.
243
  </div>
244
  """)

          # Main interface
          with gr.Row():
+             # Left column - Input
              with gr.Column(scale=1):
+                 # Image input
+                 image_input = gr.Image(
+                     label="📸 Upload Image",
                      type="pil",
+                     height=400
                  )

                  # Advanced settings
+                 with gr.Accordion("⚙️ Advanced Settings", open=False):
+                     custom_prompt = gr.Textbox(
+                         label="Custom Prompt (Optional)",
+                         placeholder="Leave empty to use default OCR prompt...",
+                         lines=4,
+                         info="Customize the prompt if you want specific extraction behavior"
                      )

+                     max_tokens = gr.Slider(
+                         minimum=512,
+                         maximum=4096,
+                         value=2048,
+                         step=128,
+                         label="Max Tokens",
+                         info="Maximum length of extracted text"
                      )

+                     show_prompt_btn = gr.Button("👁️ Show Default Prompt", size="sm")

                  # Process button
                  process_btn = gr.Button(
+                     "🚀 Process Image",
                      variant="primary",
                      elem_classes=["process-button"],
                      size="lg"
                  )

                  # Clear button
+                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="lg")

+             # Right column - Output
+             with gr.Column(scale=1):
+                 # Text output
+                 text_output = gr.Textbox(
+                     label="📝 Extracted Text",
+                     placeholder="Extracted text will appear here...",
+                     lines=20,
+                     max_lines=25,
+                     show_copy_button=True,
+                     interactive=False,
+                     elem_classes=["output-text"]
+                 )

+                 # Status/info
+                 status_output = gr.Markdown(
+                     value="*Ready to process images*",
+                     elem_classes=["info-box"]
+                 )

+         # Examples
+         gr.Markdown("### 📚 Example Images")
+         gr.Examples(
+             examples=[
+                 ["image/app/1762329983969.png"],
+                 ["image/app/1762330009302.png"],
+                 ["image/app/1762330020168.png"],
+             ],
+             inputs=image_input,
+             label="Try these examples"
+         )
+
+         # Default prompt display
+         default_prompt_display = gr.Textbox(
+             label="Default OCR Prompt",
+             value=OCR_PROMPT,
+             lines=10,
+             visible=False,
+             interactive=False
+         )

          # Event handlers
+         def process_image_handler(image, custom_prompt_text, max_tokens_value):
+             """Handle image processing."""
+             if image is None:
+                 return "", "⚠️ Please upload an image first."

              try:
+                 status = "⏳ Processing image..."
+                 extracted_text = extract_text_from_image(
+                     image,
+                     custom_prompt=custom_prompt_text,
+                     max_new_tokens=int(max_tokens_value)
+                 )

+                 if extracted_text and not extracted_text.startswith("❌"):
+                     status = f"✅ Text extracted successfully! ({len(extracted_text)} characters)"
+                 else:
+                     status = "⚠️ No text extracted or error occurred."

+                 return extracted_text, status

              except Exception as e:
+                 error_msg = f"Error: {str(e)}"
+                 return error_msg, "❌ Processing failed."

+         def clear_all_handler():
+             """Clear all inputs and outputs."""
+             return None, "", "", "✨ Ready to process images"

+         def toggle_prompt_display(current_visible):
+             """Toggle the visibility of the default prompt."""
+             return gr.update(visible=not current_visible)

+         # Wire up events
          process_btn.click(
+             process_image_handler,
+             inputs=[image_input, custom_prompt, max_tokens],
+             outputs=[text_output, status_output]
          )

          clear_btn.click(
+             clear_all_handler,
+             outputs=[image_input, text_output, custom_prompt, status_output]
+         )
+
+         # Show/hide default prompt
+         show_prompt_btn.click(
+             lambda: gr.update(visible=True),
+             outputs=[default_prompt_display]
          )
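Note that `toggle_prompt_display` is defined above but never registered as a callback; the show-prompt button is wired to a lambda that can only reveal the default prompt, not hide it again. If a real toggle were wanted, one hypothetical wiring (the `gr.State` flag below is not part of the original code) could look like this:

```python
# Hypothetical alternative wiring: a gr.State flag lets the button toggle
# the default prompt's visibility instead of only showing it.
prompt_visible = gr.State(False)
show_prompt_btn.click(
    lambda visible: (gr.update(visible=not visible), not visible),
    inputs=[prompt_visible],
    outputs=[default_prompt_display, prompt_visible],
)
```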

      return demo

          debug=True,
          show_error=True
      )
+
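The diff also hides the lines between `return demo` and the tail of the launch call, so only its last keyword arguments are visible above. A typical ending for such a Space's app.py, consistent with the visible `debug=True` and `show_error=True` but not necessarily identical to the hidden lines, would be:

```python
# Sketch only; the hidden lines may add queueing or server options.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(
        debug=True,
        show_error=True
    )
```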