File size: 7,335 Bytes
92d2175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Image Tool - OCR với Nanonets model
"""

import os
import tempfile
import requests
from PIL import Image
from typing import Optional

# Module-level cache for the OCR model, its processor and its tokenizer.
# All three start as None and are assigned by initialize_nanonets_model().
_model = None
_processor = None
_tokenizer = None

def initialize_nanonets_model():
    """
    Load the Nanonets OCR model, tokenizer and processor into module globals.

    The ``transformers`` import and the loading work run only until the first
    successful call; later calls return immediately.

    Returns:
        True if the model is ready to use, False if loading failed. Failures
        are printed, never raised.
    """
    global _model, _processor, _tokenizer

    if _model is not None:
        return True  # Already initialized

    try:
        from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText

        model_path = "nanonets/Nanonets-OCR-s"

        try:
            model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto",
                attn_implementation="flash_attention_2",
            )
        except (ImportError, ValueError) as attn_error:
            # NOTE(review): transformers is expected to raise ImportError or
            # ValueError when flash-attn is not installed or not usable on
            # this hardware - confirm against the installed version. Retry
            # with the library's default attention instead of giving up.
            print(f"⚠️ flash_attention_2 unavailable ({attn_error}); using default attention")
            model = AutoModelForImageTextToText.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto",
            )
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        processor = AutoProcessor.from_pretrained(model_path)

        # Publish the globals only after every component loaded. _model is
        # the "already initialized" flag checked above, so it is assigned
        # last: a failed tokenizer/processor load can no longer leave the
        # module in a half-initialized state that reports success.
        _tokenizer = tokenizer
        _processor = processor
        _model = model

        print("✅ Nanonets OCR model initialized successfully")
        return True

    except Exception as e:
        print(f"❌ Failed to initialize Nanonets model: {e}")
        return False

def ocr_page_with_nanonets_s(image_path: str, max_new_tokens: int = 4096) -> str:
    """
    Extract text from an image using the Nanonets OCR model.

    Args:
        image_path: Path of the image file to read.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The first decoded model output. If the model cannot be initialized
        or anything fails during OCR, an error string is returned instead
        ("Error: ..." or "OCR Error: ..."); exceptions are not propagated.
    """
    # Make sure the model has been initialized
    if not initialize_nanonets_model():
        return "Error: Could not initialize Nanonets OCR model"

    try:
        prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": prompt},
            ]},
        ]
        text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Close the image file once the processor has consumed it, so no
        # file handle stays open (callers may delete the file afterwards).
        with Image.open(image_path) as image:
            inputs = _processor(text=[text], images=[image], padding=True, return_tensors="pt")
        inputs = inputs.to(_model.device)

        output_ids = _model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Each output sequence starts with its prompt tokens; slice them off
        # so only the newly generated tokens are decoded.
        generated_ids = [
            sequence[len(prompt_ids):]
            for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
        ]

        output_text = _processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return output_text[0]

    except Exception as e:
        return f"OCR Error: {str(e)}"

def download_image_file(task_id: str) -> Optional[str]:
    """
    Fetch the file attached to a task and store it in a temporary file.

    Args:
        task_id: Identifier appended to the scoring API's ``/files/`` URL.

    Returns:
        Path of the temporary file holding the response body, or None when
        the server does not answer with HTTP 200 or any exception occurs.
        The temporary file is created with ``delete=False``, so removing it
        is left to the caller.
    """
    try:
        base_url = "https://agents-course-unit4-scoring.hf.space"
        reply = requests.get(f"{base_url}/files/{task_id}", timeout=30)
        if reply.status_code != 200:
            return None

        # '.jpg' is chosen only for an image content type that names
        # jpeg/jpg and does not name png; every other case uses '.png'.
        mime = reply.headers.get('content-type', '')
        is_jpeg = (
            'image' in mime
            and 'png' not in mime
            and ('jpeg' in mime or 'jpg' in mime)
        )
        extension = '.jpg' if is_jpeg else '.png'

        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as handle:
            handle.write(reply.content)
            return handle.name
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None

def ocr_image_with_nanonets(task_id: str = "", image_path: str = "", max_new_tokens: int = 15000) -> str:
    """
    Main entry point: OCR an image with the Nanonets model.

    An existing local ``image_path`` takes precedence; otherwise the file
    for ``task_id`` is downloaded to a temporary file, which is removed
    again before returning.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path of a local image file (if available).
        max_new_tokens: Maximum number of tokens for generation.

    Returns:
        Text extracted from the image, or an error string ("Error: ..." or
        "Image processing error: ..."). Exceptions are not propagated.
    """
    downloaded_path = None  # Set only when this call created a temp file

    try:
        # Resolve which image file to read
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            downloaded_path = download_image_file(task_id)
            if not downloaded_path:
                return "Error: Could not download image file"
            target_image_path = downloaded_path
        else:
            return "Error: No image path or task_id provided"

        # Make sure the image file exists
        if not os.path.exists(target_image_path):
            return "Error: Image file not found"

        # Run OCR
        return ocr_page_with_nanonets_s(target_image_path, max_new_tokens)

    except Exception as e:
        return f"Image processing error: {str(e)}"

    finally:
        # Remove the downloaded temp file on every exit path, including
        # failures; a caller-supplied image_path is never deleted.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass  # Best-effort cleanup: a leftover temp file is harmless

# Fallback OCR function (used when Nanonets is not working)
def fallback_ocr_image(task_id: str = "", image_path: str = "") -> str:
    """
    Fallback that reports basic image info when the Nanonets model is unavailable.

    No text is extracted; only the image's format, size and mode are
    reported. An existing local ``image_path`` takes precedence; otherwise
    the file for ``task_id`` is downloaded to a temporary file, which is
    removed again before returning.

    Args:
        task_id: ID used to download the file from the API.
        image_path: Path of a local image file (if available).

    Returns:
        A description string with the image's format, size and mode, or an
        error string ("Error: ..." or "Image processing error: ...").
        Exceptions are not propagated.
    """
    downloaded_path = None  # Set only when this call created a temp file

    try:
        # Resolve which image file to read
        if image_path and os.path.exists(image_path):
            target_image_path = image_path
        elif task_id:
            downloaded_path = download_image_file(task_id)
            if not downloaded_path:
                return "Error: Could not download image file"
            target_image_path = downloaded_path
        else:
            return "Error: No image path or task_id provided"

        # Basic image info. The context manager closes the file before the
        # cleanup below runs, so the temp file is no longer held open when
        # it is deleted.
        with Image.open(target_image_path) as img:
            return f"Image detected - Format: {img.format}, Size: {img.size}, Mode: {img.mode}. Nanonets OCR not available. Please describe what you see in the image."

    except Exception as e:
        return f"Image processing error: {str(e)}"

    finally:
        # Remove the downloaded temp file on every exit path, including
        # failures; a caller-supplied image_path is never deleted.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass  # Best-effort cleanup: a leftover temp file is harmless

# Manual smoke test, run only when this file is executed as a script
if __name__ == "__main__":
    # Try OCR on a local image when one is present
    test_image = "/path/to/test/image.jpg"
    if not os.path.exists(test_image):
        print("No test image found")
    else:
        result = ocr_image_with_nanonets(image_path=test_image)
        print("OCR Result:", result)