Spaces:
Sleeping
Sleeping
| # ========================================== | |
| # game_engine.py - Avec métriques OCR et dataset optimisé + modèles commutables | |
| # ========================================== | |
| """ | |
| Moteur de jeu avec tracking complet des performances OCR et support modèles commutables | |
| """ | |
| import random | |
| import time | |
| import datetime | |
| import gradio as gr | |
| import os | |
| import uuid | |
| import gc | |
| import numpy as np | |
| from PIL import Image | |
| # Import GPU uniquement | |
| from image_processing_gpu import ( | |
| recognize_number_fast_with_image, | |
| create_thumbnail_fast, | |
| create_white_canvas, | |
| cleanup_memory, | |
| get_ocr_model_info, | |
| get_available_models, | |
| set_ocr_model, | |
| get_current_model_info | |
| ) | |
| print("✅ Game Engine: Mode GPU avec métriques OCR et modèles commutables") | |
| # Imports dataset | |
| try: | |
| from datasets import Dataset, Image as DatasetImage, load_dataset | |
| DATASET_AVAILABLE = True | |
| print("✅ Modules dataset disponibles") | |
| except ImportError as e: | |
| DATASET_AVAILABLE = False | |
| print(f"⚠️ Modules dataset non disponibles: {e}") | |
# Hugging Face Hub dataset repository used as the default export target
# (see export_to_optimized_dataset).
DATASET_NAME = "hoololi/CalcTrainer_dataset"
# Difficulty configuration per operation: operation symbol -> difficulty
# label -> inclusive (min, max) bounds handed to random.randint.
# For "×", "+" and "-" the bounds apply to the operands; for "÷" they bound
# the quotient (the divisor range is chosen in MathGame.generate_division).
DIFFICULTY_RANGES = {
    "×": {"Facile": (2, 9), "Difficile": (4, 12)},
    "+": {"Facile": (1, 50), "Difficile": (10, 100)},
    "-": {"Facile": (1, 50), "Difficile": (10, 100)},
    "÷": {"Facile": (1, 10), "Difficile": (2, 12)}
}
def get_ocr_models_info() -> dict:
    """Describe the OCR models that can be selected.

    Returns:
        A dict with ``available_models`` (as returned by
        ``get_available_models``), ``current_model`` (as returned by
        ``get_current_model_info``) and ``model_names`` (list of the keys of
        ``available_models``). If any lookup fails, the error is printed and
        a fallback naming the default model with no available models is
        returned instead.
    """
    try:
        models = get_available_models()
        active = get_current_model_info()
        names = list(models.keys())
    except Exception as e:
        print(f"❌ Erreur get_ocr_models_info: {e}")
        fallback = {
            "available_models": {},
            "current_model": {"model_name": "hoololi/trocr-base-handwritten-calctrainer"},
            "model_names": [],
        }
        return fallback

    return {
        "available_models": models,
        "current_model": active,
        "model_names": names,
    }
def switch_ocr_model(model_name: str) -> str:
    """Switch the active OCR model and report the outcome.

    Args:
        model_name: Identifier handed to ``set_ocr_model``.

    Returns:
        A human-readable status message: success (with the display name and
        description of the now-current model), failure, or the text of any
        exception raised along the way.
    """
    try:
        if not set_ocr_model(model_name):
            return f"❌ Échec du changement vers: {model_name}"
        current = get_current_model_info()
        display_name = current['display_name']
        description = current['description']
        return f"✅ Modèle changé vers: {display_name}\n📍 {description}"
    except Exception as e:
        return f"❌ Erreur lors du changement: {e!s}"
def create_result_row_with_metrics(i: int, image: dict | np.ndarray | Image.Image, expected: int, operation_data: tuple[int, int, str, int]) -> dict:
    """Run OCR on one answer image and build its result row with timing.

    Args:
        i: Zero-based question index (displayed as ``i + 1``).
        image: The stored drawing, passed unchanged to
            ``recognize_number_fast_with_image``.
        expected: The correct answer for this question.
        operation_data: ``(a, b, operation, correct_result)`` of the question.

    Returns:
        A dict with the keys ``html_row`` (a ``<tr>`` fragment),
        ``is_correct``, ``recognized`` (raw OCR output), ``recognized_num``
        (parsed integer, 0 when the output is not a number),
        ``dataset_image_data`` and ``ocr_processing_time`` (seconds).
    """
    print(f"🔍 Traitement OCR image #{i+1}...")

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    ocr_start_time = time.perf_counter()
    recognized, optimized_image, dataset_image_data = recognize_number_fast_with_image(image, debug=False)
    ocr_processing_time = time.perf_counter() - ocr_start_time
    print(f" ⏱️ OCR temps: {ocr_processing_time:.3f}s → '{recognized}'")

    # Parse the OCR text. Track whether a number was actually read: an
    # unreadable answer must never count as correct, even when the expected
    # answer happens to be 0 (e.g. "7 - 7").
    recognized_num = 0
    has_number = False
    if isinstance(recognized, str) and recognized.isdigit():
        try:
            recognized_num = int(recognized)
            has_number = True
        except ValueError:
            # Some characters (e.g. superscript digits) satisfy isdigit()
            # but are rejected by int(); treat them as unreadable.
            pass
    is_correct = has_number and recognized_num == expected

    a, b, operation, _ = operation_data
    status_icon = "✅" if is_correct else "❌"
    status_text = "Correct" if is_correct else "Incorrect"
    row_color = "#e8f5e8" if is_correct else "#ffe8e8"

    # Thumbnail for display in the results table
    image_thumbnail = create_thumbnail_fast(optimized_image, size=(50, 50))

    # Release the intermediate image. Compare against None instead of using
    # truthiness, which raises for multi-element numpy arrays.
    if optimized_image is not None and hasattr(optimized_image, 'close'):
        try:
            optimized_image.close()
        except Exception:
            # Best-effort cleanup: a failed close must not lose the result.
            pass

    return {
        'html_row': f"""
        <tr style="background-color: {row_color};">
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; color: #333;">{i+1}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; font-weight: bold; color: #333;">{a}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; font-weight: bold; color: #333;">{operation}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; font-weight: bold; color: #333;">{b}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; font-weight: bold; color: #333;">{expected}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd;">{image_thumbnail}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; font-weight: bold; color: #333;">{recognized_num}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; color: #333;">{status_icon} {status_text}</td>
            <td style="text-align: center; padding: 8px; border: 1px solid #ddd; color: #666; font-size: 0.9em;">{ocr_processing_time:.3f}s</td>
        </tr>
        """,
        'is_correct': is_correct,
        'recognized': recognized,
        'recognized_num': recognized_num,
        'dataset_image_data': dataset_image_data,
        'ocr_processing_time': ocr_processing_time
    }
class MathGame:
    """Timed arithmetic game engine with deferred OCR scoring.

    While a session runs, the player's drawings are only stored. OCR is run
    once, in ``end_game``, which also fills ``session_data`` (one dataset
    entry per answered question) and renders the HTML results report.

    Invariant maintained during a session: ``operations_history`` holds one
    entry per stored image plus one for the question currently displayed, so
    ``operations_history[i]`` always describes ``user_images[i]``.
    """

    # Emoji shown in the status line for each selectable operation type.
    _OPERATION_EMOJI = {
        "×": "✖️", "+": "➕", "-": "➖", "÷": "➗", "Aléatoire": "🎲"
    }
    # Placeholder used when no history entry exists for an image.
    _UNKNOWN_OPERATION = (0, 0, "×", 0)

    def __init__(self):
        self.is_running = False
        self.start_time = 0
        self.current_operation = ""
        self.correct_answer = 0
        self.user_images = []
        self.expected_answers = []
        self.operations_history = []
        self.question_count = 0
        self.time_remaining = 30
        self.session_data = []
        # Session configuration
        self.duration = 30
        self.operation_type = "×"
        self.difficulty = "Facile"
        # Export bookkeeping
        self.export_status = "not_exported"
        self.export_timestamp = None
        self.export_result = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _banner(content: str) -> str:
        """Wrap ``content`` in the large gradient banner used for questions."""
        return f'<div style="font-size: 3em; font-weight: bold; text-align: center; padding: 20px; background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px;">{content}</div>'

    @staticmethod
    def _parse_operation(operation_str: str, answer: int) -> tuple[int, int, str, int]:
        """Turn ``"a op b"`` into the history tuple ``(a, b, op, answer)``.

        Raises:
            ValueError: If the string is not exactly ``<int> <op> <int>``.
        """
        left, op, right = operation_str.split()
        return int(left), int(right), op, answer

    def _close_images(self) -> None:
        """Best-effort release of every stored image that exposes ``close``."""
        for img in getattr(self, 'user_images', None) or []:
            close = getattr(img, 'close', None)
            if callable(close):
                try:
                    close()
                except Exception:
                    # Cleanup only: a failing close must not abort the game.
                    pass

    def _new_question(self) -> str:
        """Generate the next question, record it, and return its text."""
        operation_str, answer = self.generate_operation(self.operation_type, self.difficulty)
        self.current_operation = operation_str
        self.correct_answer = answer
        self.operations_history.append(self._parse_operation(operation_str, answer))
        return operation_str

    # ------------------------------------------------------------------
    # Export state
    # ------------------------------------------------------------------
    def get_export_status(self) -> dict[str, str | bool | None]:
        """Return the export state.

        ``can_export`` is true only when nothing has been exported yet and
        the last session produced data.
        """
        return {
            "status": self.export_status,
            "timestamp": self.export_timestamp,
            "result": self.export_result,
            "can_export": self.export_status == "not_exported" and len(self.session_data) > 0
        }

    def mark_export_in_progress(self) -> None:
        """Flag the session as being exported and stamp the start time."""
        self.export_status = "exporting"
        self.export_timestamp = datetime.datetime.now().isoformat()

    def mark_export_completed(self, result: str) -> None:
        """Flag the session as exported and remember the result message."""
        self.export_status = "exported"
        self.export_result = result

    # ------------------------------------------------------------------
    # Question generation
    # ------------------------------------------------------------------
    def generate_multiplication(self, difficulty: str) -> tuple[str, int]:
        """Generate a multiplication; returns ``(text, answer)``."""
        min_val, max_val = DIFFICULTY_RANGES["×"][difficulty]
        a = random.randint(min_val, max_val)
        b = random.randint(min_val, max_val)
        return f"{a} × {b}", a * b

    def generate_addition(self, difficulty: str) -> tuple[str, int]:
        """Generate an addition; returns ``(text, answer)``."""
        min_val, max_val = DIFFICULTY_RANGES["+"][difficulty]
        a = random.randint(min_val, max_val)
        b = random.randint(min_val, max_val)
        return f"{a} + {b}", a + b

    def generate_subtraction(self, difficulty: str) -> tuple[str, int]:
        """Generate a subtraction whose result is never negative.

        The second operand is drawn from ``[min_val, a]``, so the answer can
        be 0 (when both operands are equal).
        """
        min_val, max_val = DIFFICULTY_RANGES["-"][difficulty]
        a = random.randint(min_val, max_val)
        b = random.randint(min_val, a)
        return f"{a} - {b}", a - b

    def generate_division(self, difficulty: str) -> tuple[str, int]:
        """Generate an exact division by building it from quotient × divisor."""
        min_result, max_result = DIFFICULTY_RANGES["÷"][difficulty]
        result = random.randint(min_result, max_result)
        if difficulty == "Facile":
            divisor = random.randint(2, 9)
        else:
            divisor = random.randint(2, 12)
        dividend = result * divisor
        return f"{dividend} ÷ {divisor}", result

    def generate_operation(self, operation_type: str, difficulty: str) -> tuple[str, int]:
        """Generate a question for the given type and difficulty.

        ``"Aléatoire"`` picks one of the four operations at random; any
        unknown type falls back to multiplication.
        """
        if operation_type == "Aléatoire":
            operation_type = random.choice(["×", "+", "-", "÷"])
        generators = {
            "×": self.generate_multiplication,
            "+": self.generate_addition,
            "-": self.generate_subtraction,
            "÷": self.generate_division,
        }
        generate = generators.get(operation_type, self.generate_multiplication)
        return generate(difficulty)

    # ------------------------------------------------------------------
    # Game flow
    # ------------------------------------------------------------------
    def start_game(self, duration: str, operation: str, difficulty: str) -> tuple[str, Image.Image, str, str, gr.update, gr.update, str]:
        """Start a new session with the chosen configuration.

        Args:
            duration: ``"60 secondes"`` selects 60 s; anything else 30 s.
            operation: Operation symbol or ``"Aléatoire"``.
            difficulty: Difficulty label (a key of ``DIFFICULTY_RANGES``).

        Returns:
            The 7-tuple of UI outputs: question banner HTML, blank canvas,
            status text, timer text, two button updates, and the (empty)
            results HTML.
        """
        # Configuration
        self.duration = 60 if duration == "60 secondes" else 30
        self.operation_type = operation
        self.difficulty = difficulty

        # Release images left over from a previous session
        self._close_images()

        # Full reset
        self.is_running = True
        self.start_time = time.time()
        self.user_images = []
        self.expected_answers = []
        self.operations_history = []
        self.question_count = 0
        self.time_remaining = self.duration
        self.session_data = []

        # Reset export state
        self.export_status = "not_exported"
        self.export_timestamp = None
        self.export_result = None
        gc.collect()

        # First question
        operation_str = self._new_question()

        emoji = self._OPERATION_EMOJI.get(self.operation_type, "🔢")
        return (
            self._banner(operation_str),
            create_white_canvas(),
            f"🎯 {emoji} {self.operation_type} • {self.difficulty} • Écrivez votre réponse !",
            f"⏱️ Temps restant: {self.time_remaining}s",
            gr.update(interactive=False),
            gr.update(interactive=True),
            ""
        )

    def next_question(self, image_data: dict | np.ndarray | Image.Image | None) -> tuple[str, Image.Image, str, str, gr.update, gr.update, str]:
        """Store the current answer (no OCR yet) and show the next question.

        Args:
            image_data: The player's drawing for the current question, or
                ``None`` if nothing was submitted.

        Returns:
            The same 7-tuple of UI outputs as ``start_game``; when time is
            up, the outputs of ``end_game`` instead.
        """
        if not self.is_running:
            return (
                self._banner(self.current_operation),
                image_data,
                "❌ Le jeu n'est pas en cours !",
                "⏱️ Temps: 0s",
                gr.update(interactive=True),
                gr.update(interactive=False),
                ""
            )

        elapsed_time = time.time() - self.start_time
        if elapsed_time >= self.duration:
            return self.end_game(image_data)

        # Storage only - OCR is deferred to end_game
        if image_data is not None:
            self.user_images.append(image_data)
            self.expected_answers.append(self.correct_answer)
            self.question_count += 1
            print(f"📝 Image {self.question_count} stockée (pas d'OCR pendant le jeu)")
        elif self.operations_history:
            # The displayed question was skipped without an image: drop its
            # history entry, otherwise every later image would be paired
            # with the wrong operands in the results and the dataset.
            self.operations_history.pop()

        time_remaining = max(0, self.duration - int(elapsed_time))
        self.time_remaining = time_remaining
        if time_remaining <= 0:
            # The image (if any) is already stored above; passing it again
            # would record it twice.
            return self.end_game(None)

        # Next question
        operation_str = self._new_question()

        emoji = self._OPERATION_EMOJI.get(self.operation_type, "🔢")
        return (
            self._banner(operation_str),
            create_white_canvas(),
            f"🎯 {emoji} Question {self.question_count + 1} • {self.difficulty}",
            f"⏱️ Temps restant: {time_remaining}s",
            gr.update(interactive=False),
            gr.update(interactive=True),
            ""
        )

    def end_game(self, final_image: dict | np.ndarray | Image.Image | None) -> tuple[str, Image.Image, str, str, gr.update, gr.update, str]:
        """Finish the session: run OCR on every stored image and report.

        Args:
            final_image: Drawing for the question on screen when the game
                ended, or ``None`` if there is none to record.

        Returns:
            The 7-tuple of UI outputs; the last element is the full HTML
            results report. ``session_data`` is rebuilt as a side effect.
        """
        self.is_running = False
        print("🏁 Fin de jeu - Début OCR avec métriques détaillées...")

        # Record the last image if present
        if final_image is not None:
            self.user_images.append(final_image)
            self.expected_answers.append(self.correct_answer)
            self.question_count += 1

        # Add the final operation to the history if it is missing
        if len(self.operations_history) < len(self.user_images):
            try:
                self.operations_history.append(
                    self._parse_operation(self.current_operation, self.correct_answer)
                )
            except ValueError:
                # No parsable current operation (e.g. no game was started).
                self.operations_history.append(self._UNKNOWN_OPERATION)

        total_questions = len(self.user_images)
        correct_answers = 0
        row_fragments = []
        now = datetime.datetime.now()
        session_timestamp = now.isoformat()
        session_id = f"session_{int(now.timestamp())}_{str(uuid.uuid4())[:8]}"

        # Global OCR metrics
        total_ocr_start_time = time.time()
        ocr_times = []
        self.session_data = []
        images_saved = 0
        print(f"🔄 Traitement OCR avec métriques de {total_questions} images...")

        # Fetch OCR model info once for the whole session
        try:
            ocr_model_info = get_ocr_model_info()
            model_name = ocr_model_info.get("model_name", "hoololi/trocr-base-handwritten-calctrainer")
            hardware = f"{ocr_model_info.get('device', 'Unknown')}-{ocr_model_info.get('gpu_name', 'Unknown')}"
        except Exception as e:
            print(f"❌ Erreur get_ocr_model_info: {e}")
            model_name = "hoololi/trocr-base-handwritten-calctrainer"
            hardware = "ZeroGPU-Unknown"

        # Sequential OCR pass
        for i, (image, expected) in enumerate(zip(self.user_images, self.expected_answers)):
            print(f"📷 OCR image {i+1}/{total_questions}...")
            if i < len(self.operations_history):
                operation_data = self.operations_history[i]
            else:
                operation_data = self._UNKNOWN_OPERATION

            row_data = create_result_row_with_metrics(i, image, expected, operation_data)
            row_fragments.append(row_data['html_row'])
            ocr_times.append(row_data['ocr_processing_time'])
            if row_data['is_correct']:
                correct_answers += 1

            a, b, operation, _ = operation_data
            entry = {
                # Identification
                "session_id": session_id,
                "question_id": f"{session_id}_q{i+1:02d}",
                "timestamp": session_timestamp,
                # Math data
                "operand_a": a,
                "operand_b": b,
                "operation": operation,
                "correct_answer": expected,
                "difficulty": self.difficulty,
                # OCR data
                "ocr_prediction": row_data['recognized'],
                "ocr_parsed_number": row_data['recognized_num'],
                "is_correct": row_data['is_correct'],
                # OCR model metrics
                "ocr_model_name": model_name,
                "ocr_processing_time": row_data['ocr_processing_time'],
                "ocr_confidence": 0.0,  # Not available with the current TrOCR setup
                # Session metrics (the remaining ones are added after the loop)
                "session_duration": self.duration,
                "session_total_questions": total_questions,
                # Technical metadata
                "app_version": "3.2_with_switchable_models",
                "hardware": hardware
            }
            # Image for the dataset, when the OCR step provided one
            if row_data['dataset_image_data']:
                entry["handwriting_image"] = row_data['dataset_image_data']["handwriting_image"]
                images_saved += 1
            self.session_data.append(entry)

        # Final metrics
        total_ocr_time = time.time() - total_ocr_start_time
        avg_ocr_time = sum(ocr_times) / len(ocr_times) if ocr_times else 0.0
        accuracy = (correct_answers / total_questions * 100) if total_questions > 0 else 0

        # Session-wide metrics are only known now; add them to every entry
        for entry in self.session_data:
            entry["session_accuracy"] = accuracy
            entry["session_total_ocr_time"] = total_ocr_time
            entry["session_avg_ocr_time"] = avg_ocr_time

        print("📊 === MÉTRIQUES OCR COMPLÈTES ===")
        print(f"📷 Images traitées: {total_questions}")
        print(f"⏱️ Temps total OCR: {total_ocr_time:.2f}s")
        print(f"⚡ Temps moyen/image: {avg_ocr_time:.3f}s")
        print(f"🎯 Précision: {accuracy:.1f}%")
        print(f"🤖 Modèle: {model_name}")
        print(f"💻 Hardware: {hardware}")
        self._print_operation_stats()

        # Memory cleanup
        self._close_images()
        cleanup_memory()

        config_display = f"{self.operation_type} • {self.difficulty} • {self.duration}s"
        final_results = self._render_final_results(
            table_rows_html="".join(row_fragments),
            total_questions=total_questions,
            correct_answers=correct_answers,
            accuracy=accuracy,
            total_ocr_time=total_ocr_time,
            avg_ocr_time=avg_ocr_time,
            images_saved=images_saved,
            model_name=model_name,
            hardware=hardware,
            config_display=config_display,
        )

        return (
            self._banner("🏁 C'est fini !"),
            create_white_canvas(),
            f"✨ Session {config_display} terminée !",
            "⏱️ Temps écoulé !",
            gr.update(interactive=True),
            gr.update(interactive=False),
            final_results
        )

    def _print_operation_stats(self) -> None:
        """Print accuracy and mean OCR time per operation from session_data."""
        operations_stats = {}
        for entry in self.session_data:
            stats = operations_stats.setdefault(
                entry['operation'], {'correct': 0, 'total': 0, 'times': []}
            )
            stats['total'] += 1
            stats['times'].append(entry['ocr_processing_time'])
            if entry['is_correct']:
                stats['correct'] += 1

        print("📈 Détail par opération:")
        for op, stats in operations_stats.items():
            op_accuracy = (stats['correct'] / stats['total'] * 100) if stats['total'] > 0 else 0
            op_avg_time = sum(stats['times']) / len(stats['times']) if stats['times'] else 0
            print(f" {op}: {op_accuracy:.1f}% précision, {op_avg_time:.3f}s/image ({stats['total']} images)")

    def _render_final_results(self, *, table_rows_html: str, total_questions: int,
                              correct_answers: int, accuracy: float,
                              total_ocr_time: float, avg_ocr_time: float,
                              images_saved: int, model_name: str, hardware: str,
                              config_display: str) -> str:
        """Render the end-of-session HTML: summary, detail table, metrics."""
        table_html = f"""
        <div style="overflow-x: auto; margin: 20px 0;">
            <table style="width: 100%; border-collapse: collapse; border: 2px solid #4a90e2;">
                <thead>
                    <tr style="background: #4a90e2; color: white;">
                        <th style="padding: 8px;">Question</th>
                        <th style="padding: 8px;">A</th>
                        <th style="padding: 8px;">Op</th>
                        <th style="padding: 8px;">B</th>
                        <th style="padding: 8px;">Réponse</th>
                        <th style="padding: 8px;">Votre dessin</th>
                        <th style="padding: 8px;">OCR</th>
                        <th style="padding: 8px;">Statut</th>
                        <th style="padding: 8px;">Temps OCR</th>
                    </tr>
                </thead>
                <tbody>
                    {table_rows_html}
                </tbody>
            </table>
        </div>
        """

        # The metrics panel is shown only while the session is exportable
        if self.get_export_status()["can_export"]:
            export_section = f"""
            <div style="margin-top: 20px; padding: 15px; background-color: #e8f5e8; border-radius: 8px;">
                <h3 style="color: #2e7d32;">📊 Métriques de la série</h3>
                <p style="color: #2e7d32;">
                    ✅ {total_questions} réponses • 📊 {accuracy:.1f}% de précision<br>
                    🖼️ {images_saved} images sauvegardées<br>
                    ⏱️ OCR: {total_ocr_time:.2f}s total, {avg_ocr_time:.3f}s/image<br>
                    🤖 Modèle: {model_name}<br>
                    💻 Hardware: {hardware}<br>
                    ⚙️ Configuration: {config_display}
                </p>
            </div>
            """
        else:
            export_section = ""

        return f"""
        <div style="margin: 20px 0;">
            <div style="background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px; border-radius: 10px; margin: 20px 0;">
                <h2 style="text-align: center;">🎉 Session terminée !</h2>
                <div style="display: flex; justify-content: space-around; flex-wrap: wrap;">
                    <div style="text-align: center; margin: 10px;">
                        <div style="font-size: 2em; font-weight: bold;">{total_questions}</div>
                        <div>Questions</div>
                    </div>
                    <div style="text-align: center; margin: 10px;">
                        <div style="font-size: 2em; font-weight: bold; color: #90EE90;">{correct_answers}</div>
                        <div>Correctes</div>
                    </div>
                    <div style="text-align: center; margin: 10px;">
                        <div style="font-size: 2em; font-weight: bold; color: #FFB6C1;">{total_questions - correct_answers}</div>
                        <div>Incorrectes</div>
                    </div>
                    <div style="text-align: center; margin: 10px;">
                        <div style="font-size: 2em; font-weight: bold;">{accuracy:.1f}%</div>
                        <div>Précision</div>
                    </div>
                    <div style="text-align: center; margin: 10px;">
                        <div style="font-size: 1.5em; font-weight: bold; color: #87CEEB;">{avg_ocr_time:.3f}s</div>
                        <div>Temps/image</div>
                    </div>
                </div>
            </div>
            <h2 style="color: #4a90e2;">📊 Détail des Réponses avec Métriques OCR</h2>
            {table_html}
            {export_section}
        </div>
        """
def export_to_optimized_dataset(session_data: list[dict], dataset_name: str | None = None) -> str:
    """Append a session's entries to the Hub dataset and push it.

    Only entries carrying a ``handwriting_image`` are exported. The existing
    ``train`` split is loaded and the new entries are appended to it. A new
    dataset is started only when the repository is reported as missing or
    empty; any other failure while loading or merging aborts the export so
    that existing data on the Hub is never overwritten by the new entries
    alone.

    Args:
        session_data: Dataset entries produced at the end of a session.
        dataset_name: Target repository; defaults to ``DATASET_NAME``.

    Returns:
        A status message: a Markdown success summary, or a line starting
        with ``❌`` describing why nothing was exported.
    """
    if dataset_name is None:
        dataset_name = DATASET_NAME
    if not DATASET_AVAILABLE:
        return "❌ Modules dataset non disponibles"
    hf_token = os.getenv("HF_TOKEN") or os.getenv("tk_calcul_ocr")
    if not hf_token:
        return "❌ Token HuggingFace manquant"

    try:
        print("\n🚀 === EXPORT DATASET OPTIMISÉ AVEC MÉTRIQUES ===")
        print(f"📊 Dataset: {dataset_name}")

        # Keep only the entries that carry an image
        clean_entries = [
            entry for entry in (session_data or [])
            if entry.get('handwriting_image') is not None
        ]
        if not clean_entries:
            return "❌ Aucune entrée avec image à exporter"

        # Session-level metrics are duplicated on every entry; read the first
        first_entry = clean_entries[0]
        total_ocr_time = first_entry.get('session_total_ocr_time', 0)
        avg_ocr_time = first_entry.get('session_avg_ocr_time', 0)
        model_name = first_entry.get('ocr_model_name', 'Unknown')
        session_accuracy = first_entry.get('session_accuracy', 0)
        print("📈 Métriques session:")
        print(f" - {len(clean_entries)} images")
        print(f" - {session_accuracy:.1f}% précision")
        print(f" - {total_ocr_time:.2f}s total OCR")
        print(f" - {avg_ocr_time:.3f}s/image")
        print(f" - Modèle: {model_name}")

        # Load the existing dataset and combine. Only "not found / empty"
        # (FileNotFoundError and its subclasses) means we may start a new
        # dataset. Any other error (network, schema, ...) propagates to the
        # outer handler: pushing just the new entries in that situation
        # would replace the data already on the Hub.
        try:
            existing_dataset = load_dataset(dataset_name, split="train")
        except FileNotFoundError as e:
            print(f"📊 Nouveau dataset: {e}")
            combined_data = clean_entries
        else:
            existing_data = existing_dataset.to_list()
            print(f"📊 {len(existing_data)} entrées existantes")
            combined_data = existing_data + clean_entries
            print(f"📊 Dataset combiné: {len(combined_data)} total")
        clean_dataset = Dataset.from_list(combined_data)

        # Convert the image column to the datasets Image feature
        try:
            clean_dataset = clean_dataset.cast_column("handwriting_image", DatasetImage())
            print("✅ Colonne image convertie")
        except Exception as e:
            print(f"⚠️ Conversion image: {e}")

        # Per-operation counts for the commit message
        operations_count = {}
        for entry in clean_entries:
            op = entry.get('operation', 'unknown')
            operations_count[op] = operations_count.get(op, 0) + 1
        operations_summary = ", ".join(f"{op}: {count}" for op, count in operations_count.items())

        # Commit message enriched with the session metrics
        commit_message = f"""Add {len(clean_entries)} samples with OCR metrics
Model: {model_name}
Accuracy: {session_accuracy:.1f}%
Avg OCR time: {avg_ocr_time:.3f}s/image
Operations: {operations_summary}
Hardware: {first_entry.get('hardware', 'Unknown')}
"""

        # Push to the Hugging Face Hub
        print(f"📤 Push vers {dataset_name}...")
        clean_dataset.push_to_hub(
            dataset_name,
            private=False,
            token=hf_token,
            commit_message=commit_message
        )
        cleanup_memory()

        return f"""### ✅ Session ajoutée au dataset optimisé !
📊 **Dataset:** {dataset_name}
🖼️ **Images:** {len(clean_entries)}
🎯 **Précision:** {session_accuracy:.1f}%
⏱️ **Performance:** {avg_ocr_time:.3f}s/image (total: {total_ocr_time:.1f}s)
🤖 **Modèle:** {model_name}
🔢 **Opérations:** {operations_summary}
📈 **Total dataset:** {len(clean_dataset)}
🔗 <a href="https://huggingface.co/datasets/{dataset_name}" target="_blank">{dataset_name}</a>
"""
    except Exception as e:
        print(f"❌ ERREUR: {e}")
        return f"❌ Erreur: {str(e)}"
# Compatibility shim so the existing interface keeps working
def export_to_clean_dataset(session_data: list[dict], dataset_name: str = None) -> str:
    """Backward-compatible alias that forwards to ``export_to_optimized_dataset``."""
    status_message = export_to_optimized_dataset(session_data, dataset_name)
    return status_message