File size: 8,794 Bytes
1e4b8a0
1093dfb
1e4b8a0
 
cdceb7f
36961d0
cdceb7f
 
1e4b8a0
36961d0
 
 
1e4b8a0
 
 
cdceb7f
1e4b8a0
 
 
 
 
 
 
1093dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdceb7f
 
 
1093dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdceb7f
1093dfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdceb7f
 
1093dfb
 
cdceb7f
1093dfb
 
 
cdceb7f
 
 
 
 
 
 
1093dfb
cdceb7f
 
1093dfb
cdceb7f
 
 
 
1093dfb
cdceb7f
0316030
1093dfb
cdceb7f
1093dfb
 
1e4b8a0
36961d0
cdceb7f
1e4b8a0
1093dfb
1e4b8a0
 
 
 
 
 
 
 
 
1093dfb
 
1e4b8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdceb7f
1e4b8a0
 
cdceb7f
1e4b8a0
 
 
cdceb7f
1e4b8a0
 
 
 
 
 
 
 
 
cdceb7f
1e4b8a0
 
 
36961d0
cdceb7f
1e4b8a0
 
 
 
1093dfb
 
cdceb7f
 
1e4b8a0
cdceb7f
1e4b8a0
 
 
 
 
cdceb7f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# ==========================================
# image_processing_gpu.py - Version ZeroGPU avec modèles OCR commutables
# ==========================================

"""
Module de traitement d'images GPU-optimisé pour ZeroGPU HuggingFace Spaces
"""

import time
from typing import Any

import torch
import spaces
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset,
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    validate_ocr_result
)

# ==========================================
# Configuration of the available OCR models
# ==========================================

# Registry of selectable OCR models. Each key is the model identifier that
# init_ocr_model() passes to `from_pretrained`; each value carries the
# metadata reported by get_current_model_info():
#   - "description":   human-readable summary of the model
#   - "display_name":  label used in log messages
#   - "optimized_for": tag naming the kind of input the model targets
AVAILABLE_OCR_MODELS = {
    "microsoft/trocr-base-handwritten": {
        "description": "Modèle de base Microsoft pour écriture manuscrite",
        "display_name": "TrOCR Base Handwritten (Microsoft)",
        "optimized_for": "general_handwriting"
    },
    "hoololi/trocr-base-handwritten-calctrainer": {
        "description": "Modèle fine tuné pour les nombres entiers",
        "display_name": "TrOCR CalcTrainer (Hoololi)",
        "optimized_for": "mathematical_numbers"
    }
}

# Default model; also the initial selection until set_ocr_model() or
# init_ocr_model(model_name) picks another one.
DEFAULT_OCR_MODEL = "hoololi/trocr-base-handwritten-calctrainer"
current_ocr_model_name = DEFAULT_OCR_MODEL

# Module-level OCR state. `processor` and `model` stay None until
# init_ocr_model() succeeds; `current_loaded_model` records the identifier
# of the model that was actually loaded (None when nothing is loaded).
processor = None
model = None
current_loaded_model = None

def get_available_models() -> dict:
    """Return the registry of selectable OCR models.

    Returns:
        dict: The module-level ``AVAILABLE_OCR_MODELS`` mapping itself (not a
        copy), keyed by model identifier.
    """
    registry = AVAILABLE_OCR_MODELS
    return registry

def get_current_model_info() -> dict:
    """Describe the currently selected OCR model and the available device.

    Returns:
        dict: Metadata of the selected model (falling back to the default
        model's metadata when the selection is not in the registry), the
        identifier of the model actually loaded, device/GPU labels, and an
        ``is_loaded`` flag that is True only when both the processor and the
        model are set.
    """
    # An unknown selection is described with the default model's metadata.
    default_entry = AVAILABLE_OCR_MODELS[DEFAULT_OCR_MODEL]
    entry = AVAILABLE_OCR_MODELS.get(current_ocr_model_name, default_entry)

    cuda_ready = torch.cuda.is_available()
    device_label = "ZeroGPU" if cuda_ready else "CPU"
    gpu_label = torch.cuda.get_device_name() if cuda_ready else "N/A"

    both_loaded = not (processor is None or model is None)

    info = {
        "model_name": current_ocr_model_name,
        "display_name": entry["display_name"],
        "description": entry["description"],
        "current_loaded": current_loaded_model,
        "device": device_label,
        "gpu_name": gpu_label,
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": entry["optimized_for"],
        "is_loaded": both_loaded,
        # Kept for compatibility with older callers
        "version": current_ocr_model_name,
    }
    return info

def set_ocr_model(model_name: str) -> bool:
    """Switch the active OCR model.

    Unknown names are rejected. If the requested model is both the current
    selection and the model actually loaded, nothing is reloaded. Otherwise
    the previous model is released and the requested one is loaded.

    Args:
        model_name: Exact model identifier (e.g. "microsoft/trocr-base-handwritten").

    Returns:
        bool: True if the requested model is loaded when the call returns,
        False if the name is unknown or loading failed.
    """
    global current_ocr_model_name

    if model_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{model_name}' non disponible. Modèles disponibles: {list(AVAILABLE_OCR_MODELS.keys())}")
        return False

    # Only skip the reload when the *loaded* model matches too. The selected
    # name alone is not proof: it is updated before loading is attempted, so
    # after a failed load it can name a model that is not the one in memory.
    already_loaded = (
        model_name == current_ocr_model_name
        and model_name == current_loaded_model
        and processor is not None
        and model is not None
    )
    if already_loaded:
        print(f"✅ Modèle '{model_name}' déjà chargé")
        return True

    model_config = AVAILABLE_OCR_MODELS[model_name]
    print(f"🔄 Changement vers le modèle: {model_config['display_name']}")
    current_ocr_model_name = model_name

    # Release the previous model before loading the new one
    cleanup_current_model()

    # Load the newly selected model
    return init_ocr_model()

def cleanup_current_model():
    """Drop this module's references to the loaded model and processor.

    Resets ``model``, ``processor`` and ``current_loaded_model`` to None,
    then empties the CUDA cache when CUDA is available.
    """
    global processor, model, current_loaded_model

    # Rebinding to None releases the module's references to both objects.
    model = None
    processor = None
    current_loaded_model = None

    # Empty the CUDA cache when a GPU is available
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()

    print("🧹 Modèle précédent nettoyé")

def init_ocr_model(model_name: str | None = None) -> bool:
    """Load the TrOCR processor and model for the selected OCR model.

    The processor and model are loaded into local variables first and only
    published to the module globals once every step has succeeded, so a
    failure part-way through never leaves a new processor paired with a
    previously loaded model.

    Args:
        model_name: Exact identifier of the model to load. When omitted, the
            current selection (``current_ocr_model_name``) is used. When
            given and valid, it becomes the current selection even if the
            load subsequently fails.

    Returns:
        bool: True if the model was loaded, False if the name is unknown or
        loading raised an exception.
    """
    global processor, model, current_ocr_model_name, current_loaded_model

    if model_name is not None:
        if model_name not in AVAILABLE_OCR_MODELS:
            print(f"❌ Modèle '{model_name}' non disponible")
            return False
        current_ocr_model_name = model_name

    target_name = current_ocr_model_name
    model_config = AVAILABLE_OCR_MODELS[target_name]

    try:
        print(f"🔄 Chargement {model_config['display_name']} (ZeroGPU)...")
        print(f"   📍 Modèle: {target_name}")

        new_processor = TrOCRProcessor.from_pretrained(target_name)
        new_model = VisionEncoderDecoderModel.from_pretrained(target_name)

        # Inference only: switch to evaluation mode
        new_model.eval()

        on_gpu = torch.cuda.is_available()
        if on_gpu:
            new_model = new_model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
        else:
            device_info = "CPU (ZeroGPU pas encore alloué)"

    except Exception as e:
        print(f"❌ Erreur lors du chargement {model_config['display_name']}: {e}")
        return False

    # Everything succeeded: publish the new processor/model pair together.
    processor = new_processor
    model = new_model
    current_loaded_model = target_name

    if on_gpu:
        print(f"✅ {model_config['display_name']} prêt sur {device_info} !")
    else:
        print(f"⚠️ {model_config['display_name']} sur CPU - {device_info}")

    return True

# Alias kept for compatibility with older code
def get_ocr_model_info() -> dict:
    """Return the same dictionary as ``get_current_model_info()`` (legacy name)."""
    info = get_current_model_info()
    return info

@spaces.GPU
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run TrOCR on an input image and return the recognized number.

    Args:
        image_dict: Input handed to ``optimize_image_for_ocr``; ``None``
            short-circuits to the failure result.
        debug: When True, print progress and timing information.

    Returns:
        tuple: ``(text, optimized_image, dataset_image_data)`` where ``text``
        is the output of ``validate_ocr_result``, ``optimized_image`` is the
        image returned by ``optimize_image_for_ocr`` and
        ``dataset_image_data`` is the result of ``prepare_image_for_dataset``.
        On any failure (missing image, optimization failure, model not
        initialized, or an exception) the result is ``("0", None, None)``.
    """
    if image_dict is None:
        if debug:
            print("  ❌ Image manquante")
        return "0", None, None

    try:
        start_time = time.time()
        if debug:
            model_info = get_current_model_info()
            print(f"  🔄 Début OCR {model_info['display_name']} ZeroGPU...")

        # Optimize the image for OCR
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print("  ❌ Échec optimisation image")
            return "0", None, None

        # TrOCR must have been initialized by init_ocr_model()
        if processor is None or model is None:
            if debug:
                print("  ❌ TrOCR non initialisé")
            return "0", None, None

        if debug:
            print("  🤖 Lancement TrOCR ZeroGPU...")

        with torch.no_grad():
            # Preprocessing
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values

            # Put the input on the device the model actually lives on. Keying
            # this on torch.cuda.is_available() breaks when the model was
            # loaded on CPU (no GPU at init time) but CUDA is available here:
            # generate() would then fail on a CPU/GPU device mismatch.
            target_device = model.device
            pixel_values = pixel_values.to(target_device)

            # Greedy decoding, capped at a few tokens
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id
            )

            # Decoding
            result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            final_result = validate_ocr_result(result, max_length=4)

        # Prepare the image payload for the dataset
        dataset_image_data = prepare_image_for_dataset(optimized_image)

        if debug:
            total_time = time.time() - start_time
            # Report the device inference actually ran on
            device = "ZeroGPU" if target_device.type == "cuda" else "CPU"
            model_name = get_current_model_info()['display_name']
            print(f"  ✅ {model_name} ({device}) terminé en {total_time:.1f}s → '{final_result}'")
            if dataset_image_data:
                print(f"  🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")

        return final_result, optimized_image, dataset_image_data

    except Exception as e:
        # Best-effort boundary: any failure is reported and mapped to the
        # neutral result so callers always receive a tuple.
        print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}")
        return "0", None, None

def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Recognize a number and return it with the optimized image.

    Thin wrapper around ``recognize_number_fast_with_image`` that drops the
    dataset payload.

    Args:
        image_dict: Input forwarded unchanged to
            ``recognize_number_fast_with_image``.

    Returns:
        tuple: ``(text, optimized_image)``, the first two elements of the
        wrapped function's result.
    """
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image

def recognize_number(image_dict) -> str:
    """Recognize a number and return only the recognized text.

    Args:
        image_dict: Input forwarded unchanged to ``recognize_number_fast``.

    Returns:
        str: The first element of the ``recognize_number_fast`` result.
    """
    text_and_image = recognize_number_fast(image_dict)
    text, _ = text_and_image
    return text