Spaces:
Sleeping
Sleeping
File size: 7,335 Bytes
92d2175 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
"""
Image Tool - OCR with the Nanonets model
"""
import os
import tempfile
import requests
from PIL import Image
from typing import Optional
# Module-level cache for the OCR model and its helpers.
# Each name starts as None and is populated by initialize_nanonets_model().
_model = None
_processor = None
_tokenizer = None
def initialize_nanonets_model():
    """
    Load the Nanonets OCR model, tokenizer and processor once per process.

    The loaded objects are cached in the module-level ``_model``,
    ``_tokenizer`` and ``_processor`` globals. They are published together,
    and only after every component has loaded, so a failure part-way through
    never leaves a half-initialised cache that a later call would mistake
    for a ready model.

    Returns:
        True if the model is ready (already cached or freshly loaded),
        False if loading failed; the failure reason is printed.
    """
    global _model, _processor, _tokenizer

    if _model is not None and _processor is not None and _tokenizer is not None:
        return True  # Already initialised.

    try:
        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText

        model_path = "nanonets/Nanonets-OCR-s"
        load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}

        try:
            model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                attn_implementation="flash_attention_2",
                **load_kwargs,
            )
        except (ImportError, ValueError) as flash_error:
            # FlashAttention-2 is optional: when the flash-attn package or
            # suitable hardware is missing, retry with the default attention
            # implementation instead of giving up on OCR entirely.
            print(f"⚠️ flash_attention_2 unavailable ({flash_error}); using default attention")
            model = AutoModelForImageTextToText.from_pretrained(model_path, **load_kwargs)

        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        processor = AutoProcessor.from_pretrained(model_path)
    except Exception as e:
        print(f"❌ Failed to initialize Nanonets model: {e}")
        return False

    # Publish all three objects at once, only after everything loaded.
    _model, _processor, _tokenizer = model, processor, tokenizer
    print("✅ Nanonets OCR model initialized successfully")
    return True
def ocr_page_with_nanonets_s(image_path: str, max_new_tokens: int = 4096) -> str:
    """
    Extract text from a single image with the Nanonets OCR model.

    Args:
        image_path: Path of the image file to read.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded model output for the image. On failure no exception is
        raised; instead a string starting with "Error:" (model could not be
        initialised) or "OCR Error:" (anything raised during inference) is
        returned.
    """
    # Make sure the model has been loaded before touching the cached objects.
    if not initialize_nanonets_model():
        return "Error: Could not initialize Nanonets OCR model"

    try:
        prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": prompt},
            ]},
        ]

        # Open the image in a context manager so its file handle is released
        # as soon as the processor has built the model inputs.
        with Image.open(image_path) as image:
            text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = _processor(text=[text], images=[image], padding=True, return_tensors="pt")

        inputs = inputs.to(_model.device)
        output_ids = _model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

        # Keep only the ids that follow the first len(prompt_ids) positions
        # of each returned sequence.
        generated_ids = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = _processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return output_text[0]
    except Exception as e:
        return f"OCR Error: {str(e)}"
def download_image_file(task_id: str) -> Optional[str]:
    """
    Fetch the file attached to a task and store it in a temporary image file.

    Args:
        task_id: Identifier appended to the scoring API's ``/files/`` endpoint.

    Returns:
        Path of the temporary file holding the response body, or None when
        the server does not answer with HTTP 200 or any error occurs (the
        error is printed). The temporary file is created with
        ``delete=False``, so removing it is left to the caller.
    """
    try:
        base_url = "https://agents-course-unit4-scoring.hf.space"
        reply = requests.get(f"{base_url}/files/{task_id}", timeout=30)
        if reply.status_code != 200:
            return None

        # '.jpg' is used only for an image content type that mentions
        # jpeg/jpg and does not mention png; everything else gets '.png'.
        mime = reply.headers.get('content-type', '')
        looks_like_jpeg = (
            'image' in mime
            and 'png' not in mime
            and ('jpeg' in mime or 'jpg' in mime)
        )
        extension = '.jpg' if looks_like_jpeg else '.png'

        handle = tempfile.NamedTemporaryFile(delete=False, suffix=extension)
        with handle:
            handle.write(reply.content)
            return handle.name
    except Exception as err:
        print(f"Error downloading image: {err}")
        return None
def ocr_image_with_nanonets(task_id: str = "", image_path: str = "", max_new_tokens: int = 15000) -> str:
    """
    Run Nanonets OCR on a local image or on a task's downloaded attachment.

    An existing ``image_path`` takes precedence. Otherwise, when ``task_id``
    is given, the task's file is downloaded to a temporary file, which is
    removed again before returning, whether OCR succeeded or not.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path of a local image file (if available).
        max_new_tokens: Maximum number of tokens for generation.

    Returns:
        The extracted text, or a string describing the error. Exceptions
        are reported in the return value rather than raised.
    """
    target_image_path = None
    downloaded = False  # True only when this call created a temp file.
    try:
        # Resolve which image to process.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            target_image_path = download_image_file(task_id)
            if not target_image_path:
                return "Error: Could not download image file"
            downloaded = True
        else:
            return "Error: No image path or task_id provided"

        # Make sure the resolved image file exists.
        if not os.path.exists(target_image_path):
            return "Error: Image file not found"

        return ocr_page_with_nanonets_s(target_image_path, max_new_tokens)
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Best-effort removal of the downloaded temp file on every exit path;
        # a caller-supplied image is never deleted.
        if downloaded:
            try:
                os.unlink(target_image_path)
            except OSError:
                pass
# Fallback OCR function (used when Nanonets is not working)
def fallback_ocr_image(task_id: str = "", image_path: str = "") -> str:
    """
    Describe an image's basic properties when the Nanonets model is unavailable.

    No text is extracted; the result only reports the image format, size and
    mode. An existing ``image_path`` takes precedence. Otherwise, when
    ``task_id`` is given, the task's file is downloaded to a temporary file,
    which is removed again before returning.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path of a local image file (if available).

    Returns:
        A message with the image's format, size and mode, or a string
        describing the error. Exceptions are reported in the return value
        rather than raised.
    """
    target_image_path = None
    downloaded = False  # True only when this call created a temp file.
    try:
        # Resolve which image to inspect.
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            target_image_path = download_image_file(task_id)
            if not target_image_path:
                return "Error: Could not download image file"
            downloaded = True
        else:
            return "Error: No image path or task_id provided"

        # Close the image before cleanup runs: an open handle can prevent the
        # temp file from being deleted on some platforms.
        with Image.open(target_image_path) as img:
            return f"Image detected - Format: {img.format}, Size: {img.size}, Mode: {img.mode}. Nanonets OCR not available. Please describe what you see in the image."
    except Exception as e:
        return f"Image processing error: {str(e)}"
    finally:
        # Best-effort removal of the downloaded temp file on every exit path;
        # a caller-supplied image is never deleted.
        if downloaded:
            try:
                os.unlink(target_image_path)
            except OSError:
                pass
# Manual smoke test: run OCR on a local image when executed as a script.
if __name__ == "__main__":
    import sys

    # Use the image path given on the command line, if any; otherwise fall
    # back to the placeholder path.
    test_image = sys.argv[1] if len(sys.argv) > 1 else "/path/to/test/image.jpg"
    if os.path.exists(test_image):
        result = ocr_image_with_nanonets(image_path=test_image)
        print("OCR Result:", result)
    else:
        print("No test image found")