Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from huggingface_hub import login | |
| import os | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import time | |
| from langdetect import detect | |
# Authenticate against the Hugging Face Hub when a token is provided.
# The token is optional: without it the app still starts, and load_model()
# returns an error message for checkpoints that cannot be downloaded
# anonymously instead of the whole app crashing with a KeyError at import.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# Language codes each checkpoint is allowed to handle. The code returned by
# langdetect.detect() on the user's text must appear in the loaded model's list,
# otherwise the analysis/generation request is refused.
# Insertion order is also the order of the model dropdown (see `models` below).
model_languages = {
    "meta-llama/Llama-2-13b-hf": ["en"],
    "meta-llama/Llama-2-7b-hf": ["en"],
    "meta-llama/Llama-2-70b-hf": ["en"],
    "meta-llama/Meta-Llama-3-8B": ["en"],
    "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "mistralai/Mistral-7B-v0.1": ["en"],
    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
    "mistralai/Mistral-7B-v0.3": ["en"],
    "google/gemma-2-2b": ["en"],
    "google/gemma-2-9b": ["en"],
    "google/gemma-2-27b": ["en"],
    "croissantllm/CroissantLLMBase": ["en", "fr"],
}

# Hub repo ids offered in the model dropdown. Derived from model_languages so a
# model can never be offered without a language entry: a missing entry would
# make the `model_languages.get(..., [])` check reject every request.
models = list(model_languages)
# Shared application state: the currently loaded causal LM and its tokenizer.
# Both are assigned by load_model() and cleared back to None by reset().
model, tokenizer = None, None
def _model_load_kwargs(model_name):
    """Return the ``from_pretrained`` keyword arguments for *model_name*.

    Every checkpoint is loaded in float16 with ``device_map="auto"``. Names
    containing "llama" or "mistral" also request FlashAttention 2, and Mixtral
    is additionally quantised to 8 bits. The 8-bit request is expressed through
    ``BitsAndBytesConfig`` because passing ``load_in_8bit`` straight to
    ``from_pretrained`` is deprecated in recent transformers releases.
    """
    name = model_name.lower()
    kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    # Checked first: "mistralai/Mixtral..." also contains the substring "mistral".
    if "mixtral" in name:
        from transformers import BitsAndBytesConfig

        kwargs["attn_implementation"] = "flash_attention_2"
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif "llama" in name or "mistral" in name:
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def _release_model():
    """Drop the global model/tokenizer and return cached CUDA memory."""
    global model, tokenizer
    import gc

    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_model(model_name, progress=gr.Progress()):
    """Load *model_name* and its tokenizer into the module-level globals.

    The previously loaded model is released first, so two large checkpoints are
    never resident at the same time. The globals are assigned only after both
    the tokenizer and the model have loaded successfully. A failed load
    therefore never pairs a new tokenizer with an old model; after a failure
    both globals are None.

    Args:
        model_name: Hugging Face Hub repo id chosen in the dropdown (may be
            None/empty when nothing is selected).
        progress: Gradio progress tracker injected by the UI.

    Returns:
        A status message (in French) displayed by the UI.
    """
    global model, tokenizer
    if not model_name:
        return "Veuillez d'abord choisir un modèle."

    _release_model()
    try:
        progress(0, desc="Chargement du tokenizer")
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token

        progress(0.5, desc="Chargement du modèle")
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name, **_model_load_kwargs(model_name)
        )
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {str(e)}"

    model, tokenizer = new_model, new_tokenizer
    progress(1.0, desc="Modèle chargé")
    return f"Modèle {model_name} chargé avec succès."
def ensure_token_display(token):
    """Return the display text for a next-token candidate.

    Callers pass the already-decoded text of a token (``tokenizer.decode([id])``),
    so strings are returned unchanged. This includes purely numeric text such as
    ``"2024"`` or ``"-5"``, which is token *text* and must not be re-decoded as
    if it were a token id. An ``int`` is treated as a raw token id and decoded
    with the global tokenizer.

    Args:
        token: decoded token text (str) or a raw token id (int).

    Returns:
        The token text as a str.
    """
    if isinstance(token, int):
        return tokenizer.decode([token])
    return token
def _check_language(input_text):
    """Return an error message if *input_text* must be refused, else None.

    The language is detected with ``langdetect.detect`` and must be listed in
    ``model_languages`` for the loaded checkpoint. ``detect`` raises
    ``LangDetectException`` on text it cannot classify (e.g. an empty string or
    digits only). That case is reported as a message instead of escaping to
    Gradio as an unhandled error.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        detected_lang = detect(input_text)
    except LangDetectException:
        return "Impossible de détecter la langue du texte d'entrée."
    if detected_lang not in model_languages.get(model.config._name_or_path, []):
        return f"Langue détectée ({detected_lang}) non supportée par ce modèle."
    return None


def _encode_prompt(input_text, max_length=512):
    """Tokenise *input_text*, keep its last *max_length* tokens, move to the model.

    The end of the prompt is the position whose next token is analysed or
    continued, so the text is cut on the left. ``truncation=True`` would cut on
    the right by default and drop exactly that end. For prompts longer than
    *max_length*, any special tokens the tokenizer prepends are dropped along
    with the start of the text.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    return {name: tensor[:, -max_length:].to(model.device) for name, tensor in encoded.items()}


def analyze_next_token(input_text, temperature, top_p, top_k):
    """Show the most probable next tokens for *input_text* and plot them.

    Args:
        input_text: prompt typed by the user.
        temperature: the last-position logits are divided by it before softmax.
        top_p: accepted for symmetry with ``generate_text``; not applied here.
        top_k: number of candidates to display (slider value, cast to int).

    Returns:
        ``(prob_text, attention_plot, prob_plot)``. On any error the first item
        is a message and both plots are None. ``attention_plot`` is also None
        when the model returns no attention weights (e.g. FlashAttention 2).
    """
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    lang_error = _check_language(input_text)
    if lang_error is not None:
        return lang_error, None, None

    try:
        inputs = _encode_prompt(input_text)
        # FlashAttention 2 does not return attention weights, so only request
        # them when the model was loaded with another attention implementation.
        want_attentions = (
            getattr(model.config, "_attn_implementation", None) != "flash_attention_2"
        )
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=want_attentions)

        last_token_logits = outputs.logits[0, -1, :]
        # Softmax in float32: the model is loaded in float16.
        probabilities = torch.nn.functional.softmax(
            last_token_logits.float() / temperature, dim=-1
        )
        k = min(int(top_k), probabilities.size(-1))
        top_probs, top_indices = torch.topk(probabilities, k)

        prob_data = {}
        for token_id, prob in zip(top_indices.tolist(), top_probs.tolist()):
            label = ensure_token_display(tokenizer.decode([token_id]))
            # Distinct ids can decode to identical text; suffix the id so no
            # candidate silently overwrites another one in the dict.
            if label in prob_data:
                label = f"{label} [{token_id}]"
            prob_data[label] = prob

        prob_text = "Prochains tokens les plus probables :\n\n" + "".join(
            f"{word}: {prob:.2%}\n" for word, prob in prob_data.items()
        )
        prob_plot = plot_probabilities(prob_data)
        attention_plot = plot_attention(
            inputs["input_ids"][0], last_token_logits, outputs.attentions
        )
        return prob_text, attention_plot, prob_plot
    except Exception as e:
        return f"Erreur lors de l'analyse : {str(e)}", None, None


def generate_text(input_text, temperature, top_p, top_k, max_new_tokens=50):
    """Sample a continuation of *input_text* with the loaded model.

    Args:
        input_text: prompt typed by the user.
        temperature, top_p, top_k: sampling parameters forwarded to
            ``model.generate`` (``top_k`` comes from a slider and is cast to int).
        max_new_tokens: number of tokens to generate (default 50).

    Returns:
        The decoded prompt plus continuation, or an error message (in French).
    """
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle."

    lang_error = _check_language(input_text)
    if lang_error is not None:
        return lang_error

    try:
        inputs = _encode_prompt(input_text)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Erreur lors de la génération : {str(e)}"


def plot_probabilities(prob_data):
    """Bar-chart the candidate tokens of *prob_data* (label -> probability).

    A standalone ``Figure`` is used instead of ``pyplot`` so each request does
    not leave another figure registered (and never closed) in pyplot's global
    figure manager.
    """
    from matplotlib.figure import Figure

    words = list(prob_data)
    probs = list(prob_data.values())

    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()
    bars = ax.bar(range(len(words)), probs, color='lightgreen')
    ax.set_title("Probabilités des tokens suivants les plus probables")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilité")
    ax.set_xticks(range(len(words)))
    ax.set_xticklabels(words, rotation=45, ha='right')
    # Percentage label above each bar.
    for i, bar in enumerate(bars):
        height = bar.get_height()
        ax.text(i, height, f'{height:.2%}', ha='center', va='bottom', rotation=0)
    fig.tight_layout()
    return fig


def plot_attention(input_ids, last_token_logits, attentions=None):
    """Heatmap of the final position's last-layer attention over the last input tokens.

    Args:
        input_ids: 1-D token ids of the prompt.
        last_token_logits: kept for backward compatibility and unused. Logits
            are a distribution over the vocabulary, not attention over the input,
            so they cannot be plotted as attention scores.
        attentions: per-layer attention tuple returned by the model when called
            with ``output_attentions=True``. Each layer is presumably
            (batch, heads, seq, seq), as documented by transformers.

    Returns:
        A matplotlib ``Figure``, or None when no attention weights are available.
    """
    from matplotlib.figure import Figure

    if not attentions or attentions[-1] is None:
        return None

    # Last layer, first batch item, averaged over heads; row of the final query.
    scores = attentions[-1][0].float().mean(dim=0)[-1]
    input_tokens = [ensure_token_display(tokenizer.decode([tid])) for tid in input_ids.tolist()]
    k = min(len(input_tokens), 10)

    fig = Figure(figsize=(14, 7))
    ax = fig.subplots()
    sns.heatmap(scores[-k:].unsqueeze(0).cpu().numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
    ax.set_xticklabels(input_tokens[-k:], rotation=45, ha="right", fontsize=10)
    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
    ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
    cbar = ax.collections[0].colorbar
    cbar.set_label("Score d'attention", fontsize=12)
    cbar.ax.tick_params(labelsize=10)
    fig.tight_layout()
    return fig
def reset():
    """Unload the model and return the default values for the UI.

    The returned tuple feeds, in order: the input text, temperature, top-p,
    top-k, the next-token text, the two plots and the generated text. Besides
    dropping the references, the garbage collector is run and PyTorch's cached
    CUDA blocks are released, so the GPU memory is actually freed.
    """
    global model, tokenizer
    import gc

    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "", 1.0, 1.0, 50, None, None, None, None
with gr.Blocks() as demo:
    gr.Markdown("# Analyse et génération de texte avec LLM")

    # Model picker and loading status.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    # Sampling parameters, passed to both the analysis and the generation handlers.
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")

    input_text = gr.Textbox(label="Texte d'entrée", lines=3)

    # Next-token analysis outputs.
    analyze_button = gr.Button("Analyser le prochain token")
    next_token_probs = gr.Textbox(label="Probabilités du prochain token")
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")

    # Free-form generation output.
    generate_button = gr.Button("Générer la suite du texte")
    generated_text = gr.Textbox(label="Texte généré")

    reset_button = gr.Button("Réinitialiser")

    # Event wiring.
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    analyze_button.click(analyze_next_token,
                         inputs=[input_text, temperature, top_p, top_k],
                         outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[generated_text])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
    # reset() unloads the model, so also clear the now-stale "loaded" status.
    reset_button.click(lambda: "", outputs=[load_output])
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()