import spaces

import traceback
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
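# Notes on the imports above (assumptions flagged):
# - `spaces` is kept first so ZeroGPU Spaces can patch CUDA initialization
#   before torch loads; on other hardware the import is harmless.
# - `qwen_vl_utils` is the preprocessing helper published with Qwen2-VL
#   (PyPI: qwen-vl-utils); process_vision_info collates image/video inputs.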
# ========================================
# AIN VLM MODEL FOR OCR
# ========================================

# Model configuration
MODEL_ID = "MBZUAI/AIN"

# Image resolution settings for the processor.
# The model's default range is 4-16384 visual tokens per image;
# these bounds balance speed and memory usage against detail.
MIN_PIXELS = 256 * 28 * 28   # minimum resolution
MAX_PIXELS = 1280 * 28 * 28  # maximum resolution
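# Each 28x28-pixel patch corresponds to one visual token in Qwen2-VL-family
# models, so these budgets translate to roughly 256-1280 tokens per image.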
# Global model and processor
model = None
processor = None
# Strict OCR-focused prompt
OCR_PROMPT = """Extract all text from this image exactly as it appears.

Requirements:
1. Extract ONLY the text content - do not describe, analyze, or interpret the image
2. Maintain the original text structure, layout, and formatting
3. Preserve line breaks, paragraphs, and spacing as they appear
4. Do not translate the text - keep it in its original language
5. Do not add any explanations, descriptions, or additional commentary
6. If there are tables, maintain their structure
7. If there are headers, titles, or sections, preserve their hierarchy

Output only the extracted text, nothing else."""
def ensure_model_loaded():
    """Lazily load the AIN VLM model and processor."""
    global model, processor
    if model is not None and processor is not None:
        return

    print("🚀 Loading AIN VLM model...")
    try:
        # Determine device and dtype
        if torch.cuda.is_available():
            device_map = "auto"
            torch_dtype = "auto"
            print("✅ Using GPU (CUDA)")
        else:
            device_map = "cpu"
            torch_dtype = torch.float32
            print("✅ Using CPU")

        # Load model
        loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

        # Load processor with resolution settings
        loaded_processor = AutoProcessor.from_pretrained(
            MODEL_ID,
            min_pixels=MIN_PIXELS,
            max_pixels=MAX_PIXELS,
            trust_remote_code=True,
        )

        model = loaded_model
        processor = loaded_processor
        print("✅ Model loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        traceback.print_exc()
        raise
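# Assumption: on ZeroGPU Spaces a GPU is attached only while a function
# decorated with @spaces.GPU is running (the decorator is effectively a
# no-op on dedicated hardware); `spaces` is imported above but otherwise unused.
@spaces.GPU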
def extract_text_from_image(
    image: Image.Image,
    custom_prompt: Optional[str] = None,
    max_new_tokens: int = 2048,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> str:
    """
    Extract text from an image using the AIN VLM model.

    Args:
        image: PIL Image to process
        custom_prompt: Optional custom prompt (uses the default OCR prompt if None)
        max_new_tokens: Maximum number of tokens to generate
        min_pixels: Minimum image resolution in pixels (optional)
        max_pixels: Maximum image resolution in pixels (optional)

    Returns:
        Extracted text as a string
    """
    try:
        # Ensure the model is loaded
        ensure_model_loaded()
        if model is None or processor is None:
            return "❌ Error: Model not loaded. Please refresh and try again."

        # Use the custom prompt if given, otherwise the default OCR prompt
        prompt_to_use = custom_prompt if custom_prompt and custom_prompt.strip() else OCR_PROMPT

        # Use custom resolution settings if provided, otherwise the defaults
        min_pix = min_pixels if min_pixels else MIN_PIXELS
        max_pix = max_pixels if max_pixels else MAX_PIXELS
        # Prepare messages in the format expected by the model.
        # min_pixels/max_pixels are set per image so process_vision_info
        # actually applies the requested resolution bounds (previously these
        # values were computed but never used).
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                        "min_pixels": min_pix,
                        "max_pixels": max_pix,
                    },
                    {
                        "type": "text",
                        "text": prompt_to_use,
                    },
                ],
            }
        ]
        # Apply the chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process vision information
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the model's device
        device = next(model.parameters()).device
        inputs = inputs.to(device)

        # Generate output
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding for consistency
            )

        # Strip the prompt tokens, then decode only the generated continuation
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        result = output_text[0] if output_text else ""
        return result.strip() if result else "No text extracted"

    except Exception as e:
        error_msg = f"❌ Error during text extraction: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg
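# Example usage outside the UI (a sketch; "sample.png" is a hypothetical
# local file, not shipped with this Space):
#     from PIL import Image
#     print(extract_text_from_image(Image.open("sample.png")))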
def create_gradio_interface():
    """Create the Gradio interface for AIN OCR."""

    # Custom CSS
    css = """
    .main-container {
        max-width: 1200px;
        margin: 0 auto;
    }
    .header-text {
        text-align: center;
        color: #2c3e50;
        margin-bottom: 30px;
    }
    .process-button {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        font-size: 1.1em !important;
        padding: 12px 24px !important;
    }
    .process-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 6px 12px rgba(0,0,0,0.2) !important;
    }
    .output-text {
        background: #f8f9fa;
        border: 2px solid #dee2e6;
        border-radius: 8px;
        padding: 20px;
        min-height: 300px;
        font-family: 'Courier New', monospace;
        white-space: pre-wrap;
        unicode-bidi: plaintext; /* direction follows the content (e.g., Arabic RTL); 'direction: auto' is not valid CSS */
    }
    .info-box {
        background: #e3f2fd;
        border-left: 4px solid #2196f3;
        padding: 15px;
        margin: 10px 0;
        border-radius: 4px;
    }
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=css, title="AIN VLM OCR") as demo:
        # Header
        gr.HTML("""
        <div class="header-text">
            <h1>🔍 AIN VLM - Vision Language Model OCR</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-top: 10px;">
                Advanced OCR using a Vision Language Model (VLM) for accurate text extraction
            </p>
            <p style="font-size: 0.95em; color: #9ca3af; margin-top: 8px;">
                Powered by <strong>MBZUAI/AIN</strong> - specialized in understanding and extracting text from images
            </p>
        </div>
        """)

        # Info box
        gr.Markdown("""
        <div class="info-box">
        <strong>ℹ️ How it works:</strong> Upload an image containing text, click "Process Image", and get the extracted text.
        The vision-language model understands context and can handle handwritten text better than traditional OCR models.
        </div>
        """)
        # Main interface
        with gr.Row():
            # Left column - input
            with gr.Column(scale=1):
                # Image input
                image_input = gr.Image(
                    label="📸 Upload Image",
                    type="pil",
                    height=400,
                )

                # Advanced settings
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    custom_prompt = gr.Textbox(
                        label="Custom Prompt (Optional)",
                        placeholder="Leave empty to use the default OCR prompt...",
                        lines=4,
                        info="Customize the prompt if you want specific extraction behavior",
                    )
                    max_tokens = gr.Slider(
                        minimum=512,
                        maximum=4096,
                        value=2048,
                        step=128,
                        label="Max Tokens",
                        info="Maximum length of the extracted text",
                    )
                    with gr.Row():
                        min_pixels_input = gr.Number(
                            value=MIN_PIXELS,
                            label="Min Pixels",
                            info="Minimum image resolution",
                        )
                        max_pixels_input = gr.Number(
                            value=MAX_PIXELS,
                            label="Max Pixels",
                            info="Maximum image resolution",
                        )
                    show_prompt_btn = gr.Button("👁️ Show Default Prompt", size="sm")

                # Process button
                process_btn = gr.Button(
                    "🚀 Process Image",
                    variant="primary",
                    elem_classes=["process-button"],
                    size="lg",
                )

                # Clear button
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="lg")

            # Right column - output
            with gr.Column(scale=1):
                # Text output
                text_output = gr.Textbox(
                    label="📝 Extracted Text",
                    placeholder="Extracted text will appear here...",
                    lines=20,
                    max_lines=25,
                    show_copy_button=True,
                    interactive=False,
                    elem_classes=["output-text"],
                )

                # Status/info
                status_output = gr.Markdown(
                    value="*Ready to process images*",
                    elem_classes=["info-box"],
                )

        # Examples
        gr.Markdown("### 📚 Example Images")
        gr.Examples(
            examples=[
                ["image/app/1762329983969.png"],
                ["image/app/1762330009302.png"],
                ["image/app/1762330020168.png"],
            ],
            inputs=image_input,
            label="Try these examples",
        )

        # Default prompt display (hidden until requested)
        default_prompt_display = gr.Textbox(
            label="Default OCR Prompt",
            value=OCR_PROMPT,
            lines=10,
            visible=False,
            interactive=False,
        )
        # Event handlers
        def process_image_handler(image, custom_prompt_text, max_tokens_value, min_pix, max_pix):
            """Handle image processing."""
            if image is None:
                return "", "⚠️ Please upload an image first."
            try:
                extracted_text = extract_text_from_image(
                    image,
                    custom_prompt=custom_prompt_text,
                    max_new_tokens=int(max_tokens_value),
                    min_pixels=int(min_pix) if min_pix else None,
                    max_pixels=int(max_pix) if max_pix else None,
                )
                if extracted_text and not extracted_text.startswith("❌"):
                    status = f"✅ Text extracted successfully! ({len(extracted_text)} characters)"
                else:
                    status = "⚠️ No text extracted or an error occurred."
                return extracted_text, status
            except Exception as e:
                error_msg = f"❌ Error: {str(e)}"
                return error_msg, "❌ Processing failed."

        def clear_all_handler():
            """Clear all inputs and outputs."""
            return None, "", "", "✨ Ready to process images"
        # Wire up events
        process_btn.click(
            process_image_handler,
            inputs=[image_input, custom_prompt, max_tokens, min_pixels_input, max_pixels_input],
            outputs=[text_output, status_output],
        )

        clear_btn.click(
            clear_all_handler,
            outputs=[image_input, text_output, custom_prompt, status_output],
        )

        # Reveal the default prompt (the button only shows it; there is no hide)
        show_prompt_btn.click(
            lambda: gr.update(visible=True),
            outputs=[default_prompt_display],
        )

    return demo
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_gradio_interface()
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True,
    )