Spaces:
Sleeping
Sleeping
| """ | |
| Image Tool - OCR với Nanonets model | |
| """ | |
| import os | |
| import tempfile | |
| import requests | |
| from PIL import Image | |
| from typing import Optional | |
# Module-level cache for the Nanonets OCR components. All three start unset
# and are populated by initialize_nanonets_model().
_model = _processor = _tokenizer = None
def initialize_nanonets_model():
    """Load the Nanonets OCR model, tokenizer and processor once.

    The loaded components are cached in the module-level ``_model``,
    ``_processor`` and ``_tokenizer`` globals, so calls after the first
    successful load return immediately.

    Returns:
        True if the components are loaded (now or by an earlier call),
        False if loading failed. Failures are printed, not raised.
    """
    global _model, _processor, _tokenizer
    if _model is not None:
        return True  # Already initialized
    try:
        from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText

        model_path = "nanonets/Nanonets-OCR-s"

        def _load_model(**extra_kwargs):
            """Load the model weights with the shared from_pretrained options."""
            return AutoModelForImageTextToText.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto",
                **extra_kwargs,
            )

        try:
            model = _load_model(attn_implementation="flash_attention_2")
        except (ImportError, ValueError) as e:
            # flash_attention_2 needs the optional flash_attn package and a
            # supported GPU; fall back to the default attention implementation
            # instead of making the whole OCR tool unavailable.
            print(f"⚠️ flash_attention_2 unavailable ({e}); using default attention")
            model = _load_model()
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        processor = AutoProcessor.from_pretrained(model_path)
    except Exception as e:
        print(f"❌ Failed to initialize Nanonets model: {e}")
        return False
    # Publish all three together only after every load succeeded, so a
    # failure part-way through never leaves _model set with no processor.
    _model, _tokenizer, _processor = model, tokenizer, processor
    print("✅ Nanonets OCR model initialized successfully")
    return True
def ocr_page_with_nanonets_s(image_path: str, max_new_tokens: int = 4096) -> str:
    """Extract text from a single page image with the Nanonets OCR model.

    Args:
        image_path: Path to a local image file.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded model output, or an error string starting with
        "Error:" (model could not be loaded) or "OCR Error:" (any failure
        during preprocessing/generation).
    """
    # Make sure the model is loaded (no-op after the first successful call).
    if not initialize_nanonets_model():
        return "Error: Could not initialize Nanonets OCR model"
    try:
        prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
        # Image.open is lazy and keeps the file open; copy() loads the pixels
        # so the handle can be closed immediately by the with-block.
        with Image.open(image_path) as img:
            image = img.copy()
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": prompt},
            ]},
        ]
        text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = _processor(text=[text], images=[image], padding=True, return_tensors="pt")
        inputs = inputs.to(_model.device)
        output_ids = _model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = _processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return output_text[0]
    except Exception as e:
        return f"OCR Error: {str(e)}"
def download_image_file(task_id: str) -> Optional[str]:
    """Download the file attached to *task_id* from the scoring API.

    The response body is written to a new temporary file created with
    ``delete=False``, so the caller owns the returned path and must remove it.

    Args:
        task_id: Task identifier appended to ``/files/`` on the scoring API.

    Returns:
        Path of the temporary file, or None if the request did not return
        HTTP 200 or any error occurred (the reason is printed).
    """
    # Ordered (substring, suffix) pairs matched against an image content type.
    # "png" is checked before "jpeg"/"jpg", as in the original logic.
    suffix_by_subtype = (
        ("png", ".png"),
        ("jpeg", ".jpg"),
        ("jpg", ".jpg"),
        ("gif", ".gif"),
        ("webp", ".webp"),
        ("bmp", ".bmp"),
        ("tiff", ".tiff"),
    )
    try:
        api_url = "https://agents-course-unit4-scoring.hf.space"
        file_url = f"{api_url}/files/{task_id}"
        response = requests.get(file_url, timeout=30)
        if response.status_code != 200:
            print(f"Error downloading image: HTTP {response.status_code} for {file_url}")
            return None
        # Header values are not guaranteed to be lowercase.
        content_type = response.headers.get('content-type', '').lower()
        suffix = '.png'  # Default for non-image or unrecognized content types
        if 'image' in content_type:
            suffix = next(
                (ext for key, ext in suffix_by_subtype if key in content_type),
                '.png',
            )
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(response.content)
            return tmp_file.name
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None
def ocr_image_with_nanonets(task_id: str = "", image_path: str = "", max_new_tokens: int = 15000) -> str:
    """Run Nanonets OCR on a local image or on the file attached to a task.

    An existing local ``image_path`` takes precedence. Otherwise the file is
    downloaded via ``task_id`` into a temporary file, which is always removed
    afterwards, even if OCR fails.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path to a local image file (optional).
        max_new_tokens: Maximum number of tokens for generation.

    Returns:
        Extracted text from the image, or an error string starting with
        "Error:" / "Image processing error:" on failure.
    """
    downloaded_path = None
    try:
        # Resolve which image file to process.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            downloaded_path = download_image_file(task_id)
            if not downloaded_path:
                return "Error: Could not download image file"
            target_image_path = downloaded_path
        elif image_path:
            # A path was given, but it does not exist and there is no task_id
            # to fall back on.
            return "Error: Image file not found"
        else:
            return "Error: No image path or task_id provided"
        if not os.path.exists(target_image_path):
            return "Error: Image file not found"
        return ocr_page_with_nanonets_s(target_image_path, max_new_tokens)
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Best-effort removal of the temporary download; never delete a
        # caller-supplied local file.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass
# Fallback OCR function (used when the Nanonets model is unavailable)
def fallback_ocr_image(task_id: str = "", image_path: str = "") -> str:
    """Report basic image properties when Nanonets OCR is unavailable.

    An existing local ``image_path`` takes precedence. Otherwise the file is
    downloaded via ``task_id`` into a temporary file, which is always removed
    afterwards.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path to a local image file (optional).

    Returns:
        A sentence with the image's format, size and mode, or an error string
        starting with "Error:" / "Image processing error:" on failure.
    """
    downloaded_path = None
    try:
        # Resolve which image file to inspect.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            downloaded_path = download_image_file(task_id)
            if not downloaded_path:
                return "Error: Could not download image file"
            target_image_path = downloaded_path
        elif image_path:
            # A path was given, but it does not exist and there is no task_id
            # to fall back on.
            return "Error: Image file not found"
        else:
            return "Error: No image path or task_id provided"
        # Close the file handle promptly; an open handle would also make the
        # os.unlink in the finally-block fail on Windows.
        with Image.open(target_image_path) as img:
            return (
                f"Image detected - Format: {img.format}, Size: {img.size}, Mode: {img.mode}. "
                "Nanonets OCR not available. Please describe what you see in the image."
            )
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Best-effort removal of the temporary download; never delete a
        # caller-supplied local file.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass
# Manual smoke test: python <this file> [image_path]
if __name__ == "__main__":
    import sys

    # Take the image path from the command line; the default is only a
    # placeholder and normally does not exist.
    test_image = sys.argv[1] if len(sys.argv) > 1 else "/path/to/test/image.jpg"
    if os.path.exists(test_image):
        result = ocr_image_with_nanonets(image_path=test_image)
        print("OCR Result:", result)
    else:
        print("No test image found")