# final_agent_course/utils/image_tool.py
# Author: tuan3335 · commit 92d2175 ("structure code")
"""
Image Tool - OCR with the Nanonets model
"""
import mimetypes
import os
import tempfile
from typing import Optional

import requests
from PIL import Image
# Lazily-populated handles to the Nanonets OCR model, its processor and its
# tokenizer. All start as None and are filled in by initialize_nanonets_model().
_model = _processor = _tokenizer = None
def initialize_nanonets_model() -> bool:
    """Load the Nanonets OCR model, tokenizer and processor once per process.

    The loaded objects are stored in the module-level ``_model``,
    ``_tokenizer`` and ``_processor`` globals so later OCR calls reuse them.

    Returns:
        True if all three components are available (already loaded, or loaded
        by this call); False if loading failed. Failures are printed, not raised.
    """
    global _model, _processor, _tokenizer
    # Short-circuit only when *every* component is present, so a half-finished
    # earlier attempt is never reported as success.
    if _model is not None and _processor is not None and _tokenizer is not None:
        return True  # Already initialized
    try:
        from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText

        model_path = "nanonets/Nanonets-OCR-s"
        load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
        try:
            model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                attn_implementation="flash_attention_2",
                **load_kwargs,
            )
        except (ImportError, ValueError) as flash_err:
            # FlashAttention-2 needs the optional `flash-attn` package and a
            # supported GPU/dtype. Without them, retry with the library's
            # default attention instead of giving up on OCR entirely.
            print(f"⚠️ flash_attention_2 unavailable ({flash_err}); using default attention")
            model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        processor = AutoProcessor.from_pretrained(model_path)
    except Exception as e:
        print(f"❌ Failed to initialize Nanonets model: {e}")
        return False
    # Publish the globals only after every component loaded successfully.
    _model, _tokenizer, _processor = model, tokenizer, processor
    print("✅ Nanonets OCR model initialized successfully")
    return True
def ocr_page_with_nanonets_s(image_path: str, max_new_tokens: int = 4096) -> str:
    """Extract text from one image file using the Nanonets OCR model.

    Args:
        image_path: Path to a local image file.
        max_new_tokens: Upper bound on the number of tokens the model generates.

    Returns:
        The decoded model output. If the model could not be loaded, returns a
        string starting with "Error:". If preprocessing or generation failed,
        returns a string starting with "OCR Error:".
    """
    # Make sure the module-level model/processor globals are populated.
    if not initialize_nanonets_model():
        return "Error: Could not initialize Nanonets OCR model"
    try:
        prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
        # Image.open is lazy and keeps the file handle open. convert() forces
        # the pixel data to load, so the handle can be released immediately;
        # RGB also normalises palette/alpha images before preprocessing.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": prompt},
            ]},
        ]
        chat_text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = _processor(text=[chat_text], images=[image], padding=True, return_tensors="pt")
        inputs = inputs.to(_model.device)
        output_ids = _model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # generate() returns prompt + completion per row; keep only the completion.
        generated_ids = [
            full_seq[len(prompt_seq):]
            for prompt_seq, full_seq in zip(inputs.input_ids, output_ids)
        ]
        decoded = _processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return decoded[0]
    except Exception as e:
        return f"OCR Error: {str(e)}"
def download_image_file(
    task_id: str,
    api_url: str = "https://agents-course-unit4-scoring.hf.space",
    timeout: float = 30,
) -> Optional[str]:
    """Download the file attached to *task_id* into a temporary file.

    Args:
        task_id: Task identifier; the file is fetched from ``{api_url}/files/{task_id}``.
        api_url: Base URL of the scoring API (defaults to the course endpoint).
        timeout: Request timeout in seconds.

    Returns:
        Path of the temporary file (``delete=False``: the caller must remove
        it), or None if the request failed, the status was not 200, the body
        was empty, or the file could not be written.
    """
    file_url = f"{api_url}/files/{task_id}"
    try:
        response = requests.get(file_url, timeout=timeout)
    except requests.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    if response.status_code != 200 or not response.content:
        return None

    # The header may carry parameters ("image/png; charset=...") and any casing.
    mime_type = response.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if mime_type in ("image/jpeg", "image/jpg", "image/pjpeg"):
        # Spelled out explicitly: mimetypes may answer ".jpe" for image/jpeg.
        suffix = ".jpg"
    elif mime_type.startswith("image/"):
        # gif/webp/bmp/... keep a truthful extension; unknown subtypes fall back to .png.
        suffix = mimetypes.guess_extension(mime_type) or ".png"
    else:
        # Missing or non-image content type: keep the historical .png default.
        suffix = ".png"

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(response.content)
            return tmp_file.name
    except OSError as e:
        print(f"Error downloading image: {e}")
        return None
def ocr_image_with_nanonets(task_id: str = "", image_path: str = "", max_new_tokens: int = 15000) -> str:
    """Main entry point: OCR an image with the Nanonets model.

    A local ``image_path`` is used when it exists. Otherwise the file is
    downloaded for ``task_id`` and the temporary copy is deleted afterwards,
    whether OCR succeeds or fails.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path to a local image file (optional).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        Extracted text from the image, or an error string.
    """
    target_image_path = None
    downloaded = False  # True only when we own a temp file that must be removed
    try:
        # Resolve which image file to use.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            target_image_path = download_image_file(task_id)
            if not target_image_path:
                return "Error: Could not download image file"
            downloaded = True
        else:
            return "Error: No image path or task_id provided"
        # Check the image file exists.
        if not os.path.exists(target_image_path):
            return "Error: Image file not found"
        # Run OCR.
        return ocr_page_with_nanonets_s(target_image_path, max_new_tokens)
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Clean up the downloaded copy on every exit path, not just success.
        if downloaded:
            try:
                os.unlink(target_image_path)
            except OSError:
                pass  # best-effort cleanup; the OCR result is still returned
# Fallback OCR function (used when the Nanonets model is unavailable)
def fallback_ocr_image(task_id: str = "", image_path: str = "") -> str:
    """Report basic image metadata when the Nanonets model is unavailable.

    A local ``image_path`` is used when it exists. Otherwise the file is
    downloaded for ``task_id`` and the temporary copy is deleted on every exit
    path.

    Returns:
        A sentence with the image's format, size and mode, or an error string.
    """
    target_image_path = None
    downloaded = False  # True only when we own a temp file that must be removed
    try:
        # Resolve which image file to use.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            target_image_path = download_image_file(task_id)
            if not target_image_path:
                return "Error: Could not download image file"
            downloaded = True
        else:
            return "Error: No image path or task_id provided"
        # Only header attributes are read, so PIL would otherwise keep the file
        # open; the context manager closes it so the temp copy can be unlinked
        # (Windows refuses to delete files that are still open).
        with Image.open(target_image_path) as img:
            result = f"Image detected - Format: {img.format}, Size: {img.size}, Mode: {img.mode}. Nanonets OCR not available. Please describe what you see in the image."
        return result
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Also runs when Image.open fails, so unreadable downloads don't leak.
        if downloaded:
            try:
                os.unlink(target_image_path)
            except OSError:
                pass  # best-effort cleanup
# Manual smoke test: python image_tool.py [path/to/image]
if __name__ == "__main__":
    import sys

    # Take the image from the command line; the placeholder keeps the old default.
    test_image = sys.argv[1] if len(sys.argv) > 1 else "/path/to/test/image.jpg"
    if os.path.exists(test_image):
        result = ocr_image_with_nanonets(image_path=test_image)
        print("OCR Result:", result)
    else:
        print("No test image found")