CalcTrainer / image_processing_gpu.py
hoololi's picture
Upload 2 files
1093dfb verified
raw
history blame
8.79 kB
# ==========================================
# image_processing_gpu.py - Version ZeroGPU avec modèles OCR commutables
# ==========================================
"""
Module de traitement d'images GPU-optimisé pour ZeroGPU HuggingFace Spaces
"""
import time
from typing import Any

import spaces
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset,
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    validate_ocr_result
)
# ==========================================
# Configuration of the available OCR models
# ==========================================
# Registry keyed by model identifier. Each entry carries the human-readable
# metadata ("description", "display_name", "optimized_for") that the rest of
# this module reports about a model.
AVAILABLE_OCR_MODELS: dict[str, dict[str, str]] = {
    "microsoft/trocr-base-handwritten": {
        "description": "Modèle de base Microsoft pour écriture manuscrite",
        "display_name": "TrOCR Base Handwritten (Microsoft)",
        "optimized_for": "general_handwriting",
    },
    "hoololi/trocr-base-handwritten-calctrainer": {
        "description": "Modèle fine tuné pour les nombres entiers",
        "display_name": "TrOCR CalcTrainer (Hoololi)",
        "optimized_for": "mathematical_numbers",
    },
}

# Model selected when nothing else has been requested.
DEFAULT_OCR_MODEL = "hoololi/trocr-base-handwritten-calctrainer"

# Name of the model currently selected (it may not be loaded yet).
current_ocr_model_name = DEFAULT_OCR_MODEL

# Module-level OCR state: the loaded processor/model pair and the name of the
# model they were loaded from. All three stay None until a model is loaded.
processor = None
model = None
current_loaded_model = None
def get_available_models() -> dict:
    """Return the registry of OCR models that can be selected.

    The returned object is the module-level ``AVAILABLE_OCR_MODELS`` mapping
    itself (not a copy), keyed by model identifier.
    """
    registry = AVAILABLE_OCR_MODELS
    return registry
def get_current_model_info() -> dict:
    """Describe the currently selected OCR model and the runtime it is on.

    Returns:
        dict: the selected model name and its registry metadata, the name of
        the model actually loaded (``current_loaded``), the device label and
        GPU name, a framework tag, an ``is_loaded`` flag, and a ``version``
        key that duplicates the model name for older callers.
    """
    selected_name = current_ocr_model_name

    # A selected name missing from the registry is described with the
    # default model's metadata.
    default_meta = AVAILABLE_OCR_MODELS[DEFAULT_OCR_MODEL]
    meta = AVAILABLE_OCR_MODELS.get(selected_name, default_meta)

    cuda_ready = torch.cuda.is_available()
    device = "ZeroGPU" if cuda_ready else "CPU"
    gpu_name = torch.cuda.get_device_name() if cuda_ready else "N/A"

    both_loaded = processor is not None and model is not None

    info = {
        "model_name": selected_name,
        "display_name": meta["display_name"],
        "description": meta["description"],
        "current_loaded": current_loaded_model,
        "device": device,
        "gpu_name": gpu_name,
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": meta["optimized_for"],
        "is_loaded": both_loaded,
        # Compatibility with the older code
        "version": selected_name,
    }
    return info
def set_ocr_model(model_name: str) -> bool:
    """Switch the active OCR model, loading it if it is not already loaded.

    Args:
        model_name: Exact model identifier
            (e.g. "microsoft/trocr-base-handwritten"); must be a key of
            ``AVAILABLE_OCR_MODELS``.

    Returns:
        bool: True if the requested model is loaded when the call returns,
        False if the name is unknown or loading failed.
    """
    global current_ocr_model_name

    if model_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{model_name}' non disponible. Modèles disponibles: {list(AVAILABLE_OCR_MODELS.keys())}")
        return False

    # Compare against the model that is actually loaded, not the selected
    # name: after a failed load the selected name can point at a model whose
    # weights are not the ones in memory.
    already_loaded = (
        model_name == current_loaded_model
        and processor is not None
        and model is not None
    )
    if already_loaded:
        # Keep the selected name in sync with what is really loaded.
        current_ocr_model_name = model_name
        print(f"✅ Modèle '{model_name}' déjà chargé")
        return True

    model_config = AVAILABLE_OCR_MODELS[model_name]
    print(f"🔄 Changement vers le modèle: {model_config['display_name']}")
    current_ocr_model_name = model_name

    # Release the previous model before loading the new one.
    cleanup_current_model()

    # Load the newly selected model.
    return init_ocr_model()
def cleanup_current_model():
    """Drop the loaded OCR model and processor so their memory can be freed.

    Resets the module-level ``model``, ``processor`` and
    ``current_loaded_model`` to None, then empties the CUDA cache when CUDA
    is available.
    """
    global processor, model, current_loaded_model

    # Rebinding to None releases this module's references to the old objects.
    model = None
    processor = None
    current_loaded_model = None

    # GPU memory cleanup, if a GPU is available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("🧹 Modèle précédent nettoyé")
def init_ocr_model(model_name: str | None = None) -> bool:
    """Load the TrOCR processor and model for the selected OCR model.

    The new processor/model pair is built in local variables and only
    published to the module-level globals once everything has succeeded, so
    a failure part-way through never leaves a processor from one model
    paired with the weights of another.

    Args:
        model_name: Exact identifier of the model to load. Optional; when
            omitted, ``current_ocr_model_name`` is used. When given, it also
            becomes the selected model name.

    Returns:
        bool: True if the model was loaded, False if the name is unknown or
        loading raised an exception.
    """
    global processor, model, current_ocr_model_name, current_loaded_model

    if model_name is not None:
        if model_name not in AVAILABLE_OCR_MODELS:
            print(f"❌ Modèle '{model_name}' non disponible")
            return False
        current_ocr_model_name = model_name

    target_name = current_ocr_model_name
    model_config = AVAILABLE_OCR_MODELS[target_name]

    try:
        print(f"🔄 Chargement {model_config['display_name']} (ZeroGPU)...")
        print(f" 📍 Modèle: {target_name}")

        new_processor = TrOCRProcessor.from_pretrained(target_name)
        new_model = VisionEncoderDecoderModel.from_pretrained(target_name)

        # Inference only: switch to evaluation mode.
        new_model.eval()

        on_gpu = torch.cuda.is_available()
        if on_gpu:
            new_model = new_model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
        else:
            device_info = "CPU (ZeroGPU pas encore alloué)"
    except Exception as e:
        print(f"❌ Erreur lors du chargement {model_config['display_name']}: {e}")
        return False

    # Everything succeeded: publish the new pair in one go.
    processor = new_processor
    model = new_model
    current_loaded_model = target_name

    if on_gpu:
        print(f"✅ {model_config['display_name']} prêt sur {device_info} !")
    else:
        print(f"⚠️ {model_config['display_name']} sur CPU - {device_info}")
    return True
# Alias kept for compatibility with the older code
def get_ocr_model_info() -> dict:
    """Return the same dictionary as ``get_current_model_info()``."""
    info = get_current_model_info()
    return info
@spaces.GPU
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run TrOCR on a drawn image and return the recognised text.

    Args:
        image_dict: Input image payload, passed as-is to
            ``optimize_image_for_ocr``. None is treated as a missing image.
        debug: When True, print progress and timing information.

    Returns:
        tuple: ``(text, optimized_image, dataset_image_data)`` where ``text``
        is the output of ``validate_ocr_result``, ``optimized_image`` is the
        image returned by ``optimize_image_for_ocr`` and
        ``dataset_image_data`` is the result of ``prepare_image_for_dataset``.
        On any failure (missing image, optimisation failure, model not
        initialised, or an exception) the result is ``("0", None, None)``.
    """
    if image_dict is None:
        if debug:
            print(" ❌ Image manquante")
        return "0", None, None

    try:
        start_time = time.time()
        if debug:
            model_info = get_current_model_info()
            print(f" 🔄 Début OCR {model_info['display_name']} ZeroGPU...")

        # Optimise the image for OCR
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print(" ❌ Échec optimisation image")
            return "0", None, None

        # TrOCR must have been initialised beforehand
        if processor is None or model is None:
            if debug:
                print(" ❌ TrOCR non initialisé")
            return "0", None, None

        if debug:
            print(" 🤖 Lancement TrOCR ZeroGPU...")

        with torch.no_grad():
            # Preprocessing
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values

            # Send the input to wherever the model's weights live. Keying
            # this on torch.cuda.is_available() breaks when the model was
            # loaded on CPU but CUDA has since become available: input and
            # weights would sit on different devices.
            pixel_values = pixel_values.to(model.device)

            # Greedy decoding with a short output budget
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Decoding
        result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        final_result = validate_ocr_result(result, max_length=4)

        # Prepare the image for the dataset
        dataset_image_data = prepare_image_for_dataset(optimized_image)

        if debug:
            total_time = time.time() - start_time
            # Report the device inference actually ran on.
            device = "ZeroGPU" if model.device.type == "cuda" else "CPU"
            model_name = get_current_model_info()['display_name']
            print(f" ✅ {model_name} ({device}) terminé en {total_time:.1f}s → '{final_result}'")
            if dataset_image_data:
                print(f" 🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")

        return final_result, optimized_image, dataset_image_data
    except Exception as e:
        # Best-effort boundary: any OCR failure degrades to the "0" result.
        print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}")
        return "0", None, None
def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Run OCR and return the recognised text with the optimised image.

    Thin wrapper around ``recognize_number_fast_with_image`` that drops the
    dataset payload.

    Args:
        image_dict: Input image payload, forwarded unchanged.

    Returns:
        tuple: ``(text, optimized_image)``.
    """
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Run OCR on the image and return only the recognised text."""
    outcome = recognize_number_fast(image_dict)
    text, _image = outcome
    return text