VanguardAI committed
Commit b4b76e6 · verified · 1 Parent(s): 0961ac0

Update app.py

Files changed (1): app.py (+39 −5)
app.py CHANGED
@@ -13,6 +13,12 @@ import traceback
 # Model configuration
 MODEL_ID = "MBZUAI/AIN"
 
+# Image resolution settings for the processor
+# The default range for the number of visual tokens per image in the model is 4-16384
+# These settings balance speed and memory usage
+MIN_PIXELS = 256 * 28 * 28  # Minimum resolution
+MAX_PIXELS = 1280 * 28 * 28  # Maximum resolution
+
 # Global model and processor
 model = None
 processor = None
@@ -60,9 +66,11 @@ def ensure_model_loaded():
         trust_remote_code=True,
     )
 
-    # Load processor
+    # Load processor with resolution settings
     loaded_processor = AutoProcessor.from_pretrained(
         MODEL_ID,
+        min_pixels=MIN_PIXELS,
+        max_pixels=MAX_PIXELS,
         trust_remote_code=True,
     )
 
@@ -78,7 +86,13 @@ def ensure_model_loaded():
 
 
 @spaces.GPU()
-def extract_text_from_image(image: Image.Image, custom_prompt: str = None, max_new_tokens: int = 2048) -> str:
+def extract_text_from_image(
+    image: Image.Image,
+    custom_prompt: str = None,
+    max_new_tokens: int = 2048,
+    min_pixels: int = None,
+    max_pixels: int = None
+) -> str:
     """
     Extract text from image using AIN VLM model.
 
@@ -86,6 +100,8 @@ def extract_text_from_image(image: Image.Image, custom_prompt: str = None, max_new_tokens: int = 2048) -> str:
         image: PIL Image to process
         custom_prompt: Optional custom prompt (uses default OCR prompt if None)
        max_new_tokens: Maximum tokens to generate
+        min_pixels: Minimum image resolution (optional)
+        max_pixels: Maximum image resolution (optional)
 
     Returns:
         Extracted text as string
@@ -100,6 +116,10 @@ def extract_text_from_image(image: Image.Image, custom_prompt: str = None, max_new_tokens: int = 2048) -> str:
     # Use custom prompt or default OCR prompt
     prompt_to_use = custom_prompt if custom_prompt and custom_prompt.strip() else OCR_PROMPT
 
+    # Use custom resolution settings if provided, otherwise use defaults
+    min_pix = min_pixels if min_pixels else MIN_PIXELS
+    max_pix = max_pixels if max_pixels else MAX_PIXELS
+
     # Prepare messages in the format expected by the model
     messages = [
         {
@@ -272,6 +292,18 @@ def create_gradio_interface():
                 info="Maximum length of extracted text"
             )
 
+            with gr.Row():
+                min_pixels_input = gr.Number(
+                    value=MIN_PIXELS,
+                    label="Min Pixels",
+                    info="Minimum image resolution"
+                )
+                max_pixels_input = gr.Number(
+                    value=MAX_PIXELS,
+                    label="Max Pixels",
+                    info="Maximum image resolution"
+                )
+
             show_prompt_btn = gr.Button("👁️ Show Default Prompt", size="sm")
 
             # Process button
@@ -326,7 +358,7 @@ def create_gradio_interface():
         )
 
         # Event handlers
-        def process_image_handler(image, custom_prompt_text, max_tokens_value):
+        def process_image_handler(image, custom_prompt_text, max_tokens_value, min_pix, max_pix):
             """Handle image processing."""
             if image is None:
                 return "", "⚠️ Please upload an image first."
@@ -336,7 +368,9 @@ def create_gradio_interface():
             extracted_text = extract_text_from_image(
                 image,
                 custom_prompt=custom_prompt_text,
-                max_new_tokens=int(max_tokens_value)
+                max_new_tokens=int(max_tokens_value),
+                min_pixels=int(min_pix) if min_pix else None,
+                max_pixels=int(max_pix) if max_pix else None
             )
 
             if extracted_text and not extracted_text.startswith("❌"):
@@ -361,7 +395,7 @@ def create_gradio_interface():
         # Wire up events
         process_btn.click(
             process_image_handler,
-            inputs=[image_input, custom_prompt, max_tokens],
+            inputs=[image_input, custom_prompt, max_tokens, min_pixels_input, max_pixels_input],
             outputs=[text_output, status_output]
        )
 
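
Note on the new constants: AIN follows the Qwen2-VL processor interface, which maps each 28×28-pixel patch to one visual token, so MIN_PIXELS = 256 * 28 * 28 and MAX_PIXELS = 1280 * 28 * 28 budget roughly 256 and 1280 visual tokens per image (within the model's default 4-16384 range mentioned in the commit's comment). A minimal sketch of the arithmetic and the loader call this commit adds; the standalone-script framing is illustrative, not the app's structure:

```python
# Sketch: pixel budgets for the AIN (Qwen2-VL-style) processor.
# One visual token covers a 28x28-pixel patch, so the bounds below
# correspond to roughly 256 and 1280 visual tokens per image.
from transformers import AutoProcessor

MODEL_ID = "MBZUAI/AIN"
PATCH_AREA = 28 * 28

MIN_PIXELS = 256 * PATCH_AREA    # 200_704 pixels   (~256 tokens)
MAX_PIXELS = 1280 * PATCH_AREA   # 1_003_520 pixels (~1280 tokens)

# Mirrors the loader call in this commit: the processor resizes any
# input image so its area falls within [MIN_PIXELS, MAX_PIXELS].
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    min_pixels=MIN_PIXELS,
    max_pixels=MAX_PIXELS,
    trust_remote_code=True,
)
```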
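The hunk that derives min_pix and max_pix ends before showing where they are consumed. If the rest of extract_text_from_image follows the standard Qwen2-VL inference recipe that AIN inherits, per-call overrides would be attached to the image entry in the chat message and honored by qwen_vl_utils when it resizes the image. The sketch below shows that recipe under those assumptions; build_inputs is a hypothetical helper, not code from app.py:

```python
# Sketch (assumption): per-call resolution overrides via the Qwen2-VL
# message format. qwen_vl_utils.process_vision_info reads optional
# "min_pixels"/"max_pixels" keys on an image entry when resizing it.
from PIL import Image
from qwen_vl_utils import process_vision_info


def build_inputs(processor, image: Image.Image, prompt: str,
                 min_pix: int, max_pix: int):
    """Hypothetical helper: build model inputs with per-image pixel bounds."""
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                    "min_pixels": min_pix,  # per-image lower bound
                    "max_pixels": max_pix,  # per-image upper bound
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
```

Whichever mechanism the truncated code actually uses, the Gradio-side changes in this commit simply convert the two gr.Number values to ints and forward them to extract_text_from_image, falling back to the module-level defaults when a field is left empty.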