| """ | |
| TextLens - AI-Powered OCR Application | |
| Main entry point for the application. | |
| """ | |
| import gradio as gr | |
| import torch | |
| import time | |
| import logging | |
| from threading import Thread | |
| from PIL import Image | |
| from transformers import ( | |
| AutoProcessor, | |
| AutoModelForCausalLM, | |
| TextIteratorStreamer, | |
| Qwen2VLForConditionalGeneration, | |
| ) | |
| from transformers import Qwen2_5_VLForConditionalGeneration | |

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configurations
QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
ROLMOCR_MODEL_ID = "reducto/RolmOCR"
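# QV_MODEL_ID is the fast, general-purpose default; ROLMOCR_MODEL_ID is the
# document-specialized alternative selected via the "Use RolmOCR" checkbox in the UI below.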


def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """Returns an HTML snippet for a thin animated progress bar with a label."""
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
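
# For illustration: progress_bar_html("Processing with Qwen2VL OCR") returns a
# self-contained <div>/<style> snippet; the streaming handler below yields it as its
# first chunk so the user sees activity before any model output arrives.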

# Load models at startup
logger.info("🚀 Loading OCR models...")
logger.info("This may take a few minutes on first run...")

try:
    # Load Qwen2VL OCR model (primary fast model)
    logger.info(f"Loading Qwen2VL OCR model: {QV_MODEL_ID}")
    qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        QV_MODEL_ID,
        trust_remote_code=True,
        # Use fp16 only when a GPU is present; fall back to fp32 on CPU (matches the RolmOCR load below)
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    ).to("cuda" if torch.cuda.is_available() else "cpu").eval()
    logger.info("✅ Qwen2VL OCR model loaded successfully!")

    # Load RolmOCR model (specialized document model)
    logger.info(f"Loading RolmOCR model: {ROLMOCR_MODEL_ID}")
    rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
    rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        ROLMOCR_MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
    ).to("cuda" if torch.cuda.is_available() else "cpu").eval()
    logger.info("✅ RolmOCR model loaded successfully!")

    MODELS_LOADED = True
    logger.info("🎉 All models loaded and ready!")
except Exception as e:
    logger.error(f"❌ Failed to load models: {str(e)}")
    MODELS_LOADED = False
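
# Note: both checkpoints are loaded eagerly at module import, so the Gradio handlers
# below only need to check the MODELS_LOADED flag before using them.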


def extract_text_from_image(image, text_query, use_rolmocr=False):
    """Extract text from image using selected OCR model with streaming response."""
    if not MODELS_LOADED:
        yield "❌ Error: OCR models failed to load. Please check your setup and try again."
        return

    if image is None:
        yield "❌ No image provided. Please upload an image to extract text."
        return

    try:
        # Ensure image is in RGB format
        if not isinstance(image, Image.Image):
            yield "❌ Invalid image format. Please upload a valid image file."
            return
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare text query
        if not text_query.strip():
            text_query = "Extract all text from this image"

        # Select model and processor
        if use_rolmocr:
            processor = rolmocr_processor
            model = rolmocr_model
            model_name = "RolmOCR"
            logger.info("Using RolmOCR for specialized document processing")
        else:
            processor = qwen_processor
            model = qwen_model
            model_name = "Qwen2VL OCR"
            logger.info("Using Qwen2VL OCR for fast text extraction")

        # Build messages for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text_query},
                    {"type": "image", "image": image}
                ]
            }
        ]

        # Apply chat template and prepare inputs
        prompt_full = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True,
        ).to("cuda" if torch.cuda.is_available() else "cpu")

        # Set up streaming
        streamer = TextIteratorStreamer(
            processor,
            skip_prompt=True,
            skip_special_tokens=True
        )
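
        # TextIteratorStreamer exposes generated tokens as a Python iterator, so the
        # loop below can yield partial text to the UI while model.generate runs in a
        # background thread.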
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=False  # greedy decoding; a temperature setting would be ignored here
        )

        # Start generation in separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield progress bar first
        yield progress_bar_html(f"🔍 Processing with {model_name}")

        # Stream the response
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            # Clean up any special tokens that might leak through
            clean_buffer = buffer.replace("<|im_end|>", "").replace("<|endoftext|>", "").strip()
            if clean_buffer:
                time.sleep(0.01)  # Small delay for smooth streaming
                yield clean_buffer

        # Ensure thread completes
        thread.join()

        # Final clean response
        final_response = buffer.replace("<|im_end|>", "").replace("<|endoftext|>", "").strip()
        if not final_response:
            yield "⚠️ No text was detected in the image. Please try with a clearer image or a different model."
        else:
            logger.info(f"✅ Successfully extracted text: {len(final_response)} characters")
            yield final_response

    except Exception as e:
        error_msg = f"❌ Error processing image: {str(e)}"
        logger.error(f"OCR processing failed: {str(e)}")
        yield error_msg
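
# Minimal local sketch (hypothetical, not executed at import time): drive the generator
# directly, outside Gradio, to smoke-test OCR on a local file.
#
#   if MODELS_LOADED:
#       sample = Image.open("sample.png")  # hypothetical test image path
#       result = ""
#       for chunk in extract_text_from_image(sample, "Extract all text from this image"):
#           result = chunk  # first chunk is the progress-bar HTML, later chunks are text
#       print(result)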


def get_model_status():
    """Get current model status information."""
    if MODELS_LOADED:
        device = "🟢 GPU (CUDA)" if torch.cuda.is_available() else "🟡 CPU"
        return f"""
**🤖 Model Status: ✅ Ready**

**Primary Model:** Qwen2VL-OCR-2B (Fast general OCR)
**Secondary Model:** RolmOCR (Specialized documents)
**Device:** {device}
**Memory:** Optimized for streaming inference

✨ Both models loaded and ready for OCR processing!
"""
    else:
        return """
**🤖 Model Status: ❌ Failed to Load**

Please check your internet connection and GPU setup.
Models need to be downloaded on first run.
"""


# Create Gradio Interface
def create_interface():
    """Create the streamlined OCR interface."""
    with gr.Blocks(
        title="TextLens - Fast AI OCR",
        theme=gr.themes.Soft(),
        css="""
        .container { max-width: 1200px; margin: auto; }
        .header { text-align: center; padding: 20px; }
        .model-status { background: #f0f0f0; padding: 15px; border-radius: 8px; margin: 10px 0; }
        """
    ) as interface:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🔍 TextLens - AI-Powered OCR</h1>
            <p style="font-size: 16px; color: #666;">
                Fast and accurate text extraction using modern AI models
            </p>
        </div>
        """)

        # Model Status
        with gr.Row():
            with gr.Column():
                status_display = gr.Markdown(
                    value=get_model_status(),
                    elem_classes=["model-status"]
                )
                refresh_btn = gr.Button("🔄 Refresh Status", size="sm")

        # Main Interface
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 Upload Image")
                image_input = gr.Image(
                    label="Upload image for OCR",
                    type="pil",
                    sources=["upload", "clipboard"]
                )
                text_query = gr.Textbox(
                    label="📝 OCR Instructions (optional)",
                    placeholder="Extract all text from this image",
                    value="Extract all text from this image",
                    lines=2
                )
                use_rolmocr = gr.Checkbox(
                    label="🎯 Use RolmOCR (specialized for documents)",
                    value=False,
                    info="Check for complex documents/tables, uncheck for general text"
                )
                extract_btn = gr.Button(
                    "🔍 Extract Text",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 📄 Extracted Text")
                text_output = gr.Textbox(
                    label="OCR Results",
                    lines=15,
                    max_lines=25,
                    placeholder="Extracted text will appear here...\n\n• Upload an image to get started\n• Choose between fast OCR or specialized document processing\n• Results will stream in real-time",
                    show_copy_button=True
                )

        # Event handlers
        extract_btn.click(
            fn=extract_text_from_image,
            inputs=[image_input, text_query, use_rolmocr],
            outputs=text_output,
            show_progress="hidden"  # We handle progress with custom HTML
        )

        # Auto-extract on image upload
        image_input.upload(
            fn=extract_text_from_image,
            inputs=[image_input, text_query, use_rolmocr],
            outputs=text_output,
            show_progress="hidden"
        )

        refresh_btn.click(
            fn=get_model_status,
            outputs=status_display
        )

    return interface


if __name__ == "__main__":
    logger.info("🚀 Starting TextLens OCR application...")

    try:
        interface = create_interface()

        # Launch configuration
        interface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            debug=False
        )
    except Exception as e:
        logger.error(f"Failed to start application: {str(e)}")
        raise