Spaces:
Sleeping
Sleeping
# ==========================================
# image_processing_gpu.py - ZeroGPU version with switchable OCR models
# ==========================================
"""
GPU-optimized image processing module for ZeroGPU Hugging Face Spaces.
"""
| import time | |
| import torch | |
| import spaces | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| from utils import ( | |
| optimize_image_for_ocr, | |
| prepare_image_for_dataset, | |
| create_thumbnail_fast, | |
| create_white_canvas, | |
| log_memory_usage, | |
| cleanup_memory, | |
| validate_ocr_result | |
| ) | |
# ==========================================
# Registry of the selectable OCR models
# ==========================================
# Keys are the model identifiers handed to ``from_pretrained``; each entry
# carries the metadata reported by ``get_current_model_info``.
AVAILABLE_OCR_MODELS = {
    "microsoft/trocr-base-handwritten": dict(
        description="Modèle de base Microsoft pour écriture manuscrite",
        display_name="TrOCR Base Handwritten (Microsoft)",
        optimized_for="general_handwriting",
    ),
    "hoololi/trocr-base-handwritten-calctrainer": dict(
        description="Modèle fine tuné pour les nombres entiers",
        display_name="TrOCR CalcTrainer (Hoololi)",
        optimized_for="mathematical_numbers",
    ),
}

# Model used until another one is explicitly selected
DEFAULT_OCR_MODEL = "hoololi/trocr-base-handwritten-calctrainer"
current_ocr_model_name = DEFAULT_OCR_MODEL

# Shared OCR state: nothing is loaded at import time, so the processor, the
# model and the name of the loaded checkpoint all start out as None.
processor = model = current_loaded_model = None
def get_available_models() -> dict:
    """Return the registry of OCR models that can be selected.

    Returns:
        dict: The module-level ``AVAILABLE_OCR_MODELS`` mapping itself (not a
        copy), keyed by model identifier.
    """
    registry = AVAILABLE_OCR_MODELS
    return registry
def get_current_model_info() -> dict:
    """Describe the currently selected OCR model and the device in use.

    Returns:
        dict: Registry metadata of the selected model (``display_name``,
        ``description``, ``optimized_for``), the selected and loaded model
        names, the device label and GPU name, and an ``is_loaded`` flag that
        is True only when both the processor and the model are set.
    """
    # A selection missing from the registry is described with the default
    # model's metadata.
    fallback_config = AVAILABLE_OCR_MODELS[DEFAULT_OCR_MODEL]
    config = AVAILABLE_OCR_MODELS.get(current_ocr_model_name, fallback_config)

    cuda_ready = torch.cuda.is_available()
    device_label = "ZeroGPU" if cuda_ready else "CPU"
    gpu_label = torch.cuda.get_device_name() if cuda_ready else "N/A"

    both_loaded = not (processor is None or model is None)

    return {
        "model_name": current_ocr_model_name,
        "display_name": config["display_name"],
        "description": config["description"],
        "current_loaded": current_loaded_model,
        "device": device_label,
        "gpu_name": gpu_label,
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": config["optimized_for"],
        "is_loaded": both_loaded,
        # Kept for compatibility with older code
        "version": current_ocr_model_name,
    }
def set_ocr_model(model_name: str) -> bool:
    """Switch the active OCR model.

    Args:
        model_name: Exact model identifier, which must be a key of
            ``AVAILABLE_OCR_MODELS`` (e.g. "microsoft/trocr-base-handwritten").

    Returns:
        bool: True if the requested model is loaded when the call returns
        (either it already was, or it was loaded successfully); False if the
        name is unknown or loading failed.
    """
    global current_ocr_model_name

    if model_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{model_name}' non disponible. Modèles disponibles: {list(AVAILABLE_OCR_MODELS.keys())}")
        return False

    # Compare against the model that is really in memory, not against the
    # selected name: the selection can point at a model whose loading failed
    # while the previous weights are still the ones loaded.
    already_loaded = (
        model_name == current_loaded_model
        and processor is not None
        and model is not None
    )
    if already_loaded:
        # Re-align the selection with what is loaded.
        current_ocr_model_name = model_name
        print(f"✅ Modèle '{model_name}' déjà chargé")
        return True

    model_config = AVAILABLE_OCR_MODELS[model_name]
    print(f"🔄 Changement vers le modèle: {model_config['display_name']}")
    current_ocr_model_name = model_name

    # Release the previous model before loading the new one
    cleanup_current_model()

    # Load the newly selected model
    return init_ocr_model()
def cleanup_current_model():
    """Drop the loaded OCR model and processor so their memory can be freed.

    Resets ``model``, ``processor`` and ``current_loaded_model`` to None, then
    empties the CUDA cache when CUDA is available.
    """
    global processor, model, current_loaded_model

    # Rebinding the globals to None releases this module's references.
    model = None
    processor = None
    current_loaded_model = None

    # Free cached GPU memory when a GPU is available
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()

    print("🧹 Modèle précédent nettoyé")
def init_ocr_model(model_name: str | None = None) -> bool:
    """Load the TrOCR processor and model for the requested checkpoint.

    The new processor and model are built in local variables and the module
    state (``processor``, ``model``, ``current_ocr_model_name``,
    ``current_loaded_model``) is only updated once everything has loaded, so
    a failed load never leaves a new processor paired with an old model or a
    selection name that does not match what is in memory.

    Args:
        model_name: Exact model identifier to load. When omitted, the
            currently selected ``current_ocr_model_name`` is used.

    Returns:
        bool: True if the model was loaded, False if the name is not in
        ``AVAILABLE_OCR_MODELS`` or if loading raised an exception.
    """
    global processor, model, current_ocr_model_name, current_loaded_model

    target_name = current_ocr_model_name if model_name is None else model_name
    if target_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{target_name}' non disponible")
        return False

    model_config = AVAILABLE_OCR_MODELS[target_name]

    try:
        print(f"🔄 Chargement {model_config['display_name']} (ZeroGPU)...")
        print(f" 📍 Modèle: {target_name}")

        new_processor = TrOCRProcessor.from_pretrained(target_name)
        new_model = VisionEncoderDecoderModel.from_pretrained(target_name)

        # Inference only: switch to evaluation mode
        new_model.eval()

        if torch.cuda.is_available():
            new_model = new_model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            status_message = f"✅ {model_config['display_name']} prêt sur {device_info} !"
        else:
            device_info = "CPU (ZeroGPU pas encore alloué)"
            status_message = f"⚠️ {model_config['display_name']} sur CPU - {device_info}"
    except Exception as e:
        # Best effort: report the failure and leave the previous state untouched.
        print(f"❌ Erreur lors du chargement {model_config['display_name']}: {e}")
        return False

    # Everything loaded: publish the new state in one go.
    processor = new_processor
    model = new_model
    current_ocr_model_name = target_name
    current_loaded_model = target_name
    print(status_message)
    return True
# Backward-compatibility alias for older callers
def get_ocr_model_info() -> dict:
    """Return the same dictionary as ``get_current_model_info()``."""
    model_info = get_current_model_info()
    return model_info
| def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]: | |
| """ | |
| OCR avec TrOCR ZeroGPU - Version simplifiée avec modèle commutable | |
| """ | |
| if image_dict is None: | |
| if debug: | |
| print(" ❌ Image manquante") | |
| return "0", None, None | |
| try: | |
| start_time = time.time() | |
| if debug: | |
| model_info = get_current_model_info() | |
| print(f" 🔄 Début OCR {model_info['display_name']} ZeroGPU...") | |
| # Optimiser image | |
| optimized_image = optimize_image_for_ocr(image_dict, max_size=384) | |
| if optimized_image is None: | |
| if debug: | |
| print(" ❌ Échec optimisation image") | |
| return "0", None, None | |
| # TrOCR - traitement ZeroGPU | |
| if processor is None or model is None: | |
| if debug: | |
| print(" ❌ TrOCR non initialisé") | |
| return "0", None, None | |
| if debug: | |
| print(" 🤖 Lancement TrOCR ZeroGPU...") | |
| with torch.no_grad(): | |
| # Preprocessing | |
| pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values | |
| # GPU transfer si disponible | |
| if torch.cuda.is_available(): | |
| pixel_values = pixel_values.cuda() | |
| # Génération optimisée | |
| generated_ids = model.generate( | |
| pixel_values, | |
| max_length=4, | |
| num_beams=1, | |
| do_sample=False, | |
| early_stopping=True, | |
| pad_token_id=processor.tokenizer.pad_token_id | |
| ) | |
| # Décodage | |
| result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| final_result = validate_ocr_result(result, max_length=4) | |
| # Préparer pour dataset | |
| dataset_image_data = prepare_image_for_dataset(optimized_image) | |
| if debug: | |
| total_time = time.time() - start_time | |
| device = "ZeroGPU" if torch.cuda.is_available() else "CPU" | |
| model_name = get_current_model_info()['display_name'] | |
| print(f" ✅ {model_name} ({device}) terminé en {total_time:.1f}s → '{final_result}'") | |
| if dataset_image_data: | |
| print(f" 🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}") | |
| return final_result, optimized_image, dataset_image_data | |
| except Exception as e: | |
| print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}") | |
| return "0", None, None | |
def recognize_number_fast(image_dict) -> tuple[str, object | None]:
    """Run OCR and return the recognized text with the optimized image.

    Thin wrapper around ``recognize_number_fast_with_image`` that drops the
    dataset image data.

    Args:
        image_dict: Image input forwarded unchanged.

    Returns:
        tuple: ``(text, optimized_image)``.
    """
    # The annotation previously used the builtin function `any`, which is not a type.
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Run OCR on the image and return only the recognized text."""
    text_and_image = recognize_number_fast(image_dict)
    recognized_text, _unused_image = text_and_image
    return recognized_text