CalcTrainer / utils.py
hoololi's picture
Upload 6 files
36961d0 verified
raw
history blame
5.61 kB
# ==========================================
# utils.py - Fonctions communes GPU simplifiées
# ==========================================
"""
Utilitaires partagés pour le traitement d'images OCR GPU
"""
from PIL import Image, ImageEnhance
import numpy as np
import gc
import os
def create_white_canvas(width: int = 300, height: int = 300) -> Image.Image:
"""Crée un canvas blanc pour le dessin de calculs"""
return Image.new('RGB', (width, height), 'white')
def log_memory_usage(context: str = "") -> None:
"""Log l'usage mémoire actuel"""
try:
import psutil
process = psutil.Process(os.getpid())
memory_mb = process.memory_info().rss / 1024 / 1024
print(f"🔍 Mémoire {context}: {memory_mb:.1f}MB")
except:
pass
def cleanup_memory() -> None:
"""Force le nettoyage mémoire GPU"""
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except:
pass
def optimize_image_for_ocr(image_dict: dict | np.ndarray | Image.Image | None, max_size: int = 300) -> Image.Image | None:
"""
Optimisation image pour OCR GPU
Args:
image_dict: Image d'entrée (format Gradio, numpy ou PIL)
max_size: Taille maximale pour le redimensionnement
Returns:
Image PIL optimisée ou None si erreur
"""
if image_dict is None:
return None
try:
# Gérer les formats Gradio
if isinstance(image_dict, dict):
if 'composite' in image_dict and image_dict['composite'] is not None:
image = image_dict['composite']
elif 'background' in image_dict and image_dict['background'] is not None:
image = image_dict['background']
else:
return None
elif isinstance(image_dict, np.ndarray):
image = image_dict
elif isinstance(image_dict, Image.Image):
image = image_dict
else:
return None
# Conversion vers PIL
if isinstance(image, np.ndarray):
pil_image = Image.fromarray(image).convert('RGB')
else:
pil_image = image.convert('RGB')
# Redimensionnement si nécessaire
if pil_image.size[0] > max_size or pil_image.size[1] > max_size:
pil_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
return pil_image
except Exception as e:
print(f"❌ Erreur optimisation image: {e}")
return None
def prepare_image_for_dataset(image: Image.Image, max_size: tuple[int, int] = (100, 100)) -> dict[str, Image.Image | tuple | str] | None:
"""
Prépare une image pour l'inclusion dans le dataset (FORMAT IMAGE NATIF)
Args:
image: Image PIL à traiter
max_size: Taille maximale (largeur, hauteur)
Returns:
Dictionnaire avec image PIL native, taille, etc. ou None
"""
try:
if image is None:
return None
# Copier et redimensionner
dataset_image = image.copy()
dataset_image.thumbnail(max_size, Image.Resampling.LANCZOS)
compressed_size = dataset_image.size
# Structure propre pour dataset avec IMAGE NATIVE
result = {
"handwriting_image": dataset_image, # Image PIL native
"compressed_size": compressed_size,
"format": "PIL_Image",
"width": compressed_size[0],
"height": compressed_size[1]
}
return result
except Exception as e:
print(f"❌ Erreur préparation image dataset: {e}")
return None
def create_thumbnail_fast(optimized_image: Image.Image | None, size: tuple[int, int] = (40, 40)) -> str:
"""
Création miniature rapide pour affichage dans les résultats
Args:
optimized_image: Image PIL source
size: Taille de la miniature (largeur, hauteur)
Returns:
HTML img tag avec image base64 ou icône par défaut
"""
try:
if optimized_image is None:
return "📝"
# Pour l'affichage dans l'interface, on garde le base64 temporairement
import base64
from io import BytesIO
thumbnail = optimized_image.copy()
thumbnail.thumbnail(size, Image.Resampling.LANCZOS)
buffer = BytesIO()
thumbnail.save(buffer, format='PNG', optimize=True, quality=70)
img_str = base64.b64encode(buffer.getvalue()).decode()
thumbnail.close()
buffer.close()
return f'<img src="data:image/png;base64,{img_str}" width="{size[0]}" height="{size[1]}" style="border: 1px solid #ccc; border-radius: 3px;" alt="Réponse calcul">'
except Exception:
return "📝"
def validate_ocr_result(raw_result: str, max_length: int = 4) -> str:
"""
Valide et nettoie un résultat OCR
Args:
raw_result: Résultat brut de l'OCR
max_length: Longueur maximale autorisée
Returns:
Résultat nettoyé (chiffres uniquement)
"""
if not raw_result:
return "0"
# Extraire uniquement les chiffres
cleaned_result = ''.join(filter(str.isdigit, str(raw_result)))
# Valider la longueur
if cleaned_result and len(cleaned_result) <= max_length:
return cleaned_result
elif cleaned_result:
# Si trop long, prendre les premiers chiffres
return cleaned_result[:max_length]
else:
return "0"