import streamlit as st from langchain_mistralai.chat_models import ChatMistralAI from langchain_community.utilities import SQLDatabase from langgraph.graph import StateGraph from typing import TypedDict, Annotated, Literal, List import os from dotenv import load_dotenv from sqlalchemy import create_engine, text import pandas as pd from typing import Dict # Vérifier les drivers ODBC installés os.system("python check_odbc.py") # 🔹 Charger les variables d'environnement load_dotenv() MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") server = os.getenv("DB_HOST") database = os.getenv("DB_NAME") user = os.getenv("DB_USER") password = os.getenv("DB_PASSWORD") # Vérifier que les variables d'environnement sont chargées if not all([MISTRAL_API_KEY, server, database, user, password]): raise ValueError("Veuillez vérifier les variables d'environnement dans les paramètres Hugging Face.") # 🔹 Configuration du modèle Mistral model_name = "mistral-large-latest" llm = ChatMistralAI( model=model_name, api_key=MISTRAL_API_KEY, temperature=0, stream=True, verbose=True ) # 🔹 Configuration de la base de données SQL Server db_url = f"mssql+pyodbc://{user}:{password}@{server}/{database}?driver=ODBC+Driver+17+for+SQL+Server" try: engine = create_engine(db_url) db = SQLDatabase.from_uri(db_url) print("Connexion à la base de données réussie !") except Exception as e: print(f"Erreur lors de la connexion à la base de données : {e}") engine = create_engine(db_url) db = SQLDatabase.from_uri(db_url) # 🔹 Définition du graphe LangGraph class QueryState(TypedDict): query: Annotated[str, "Requête en langage naturel"] sql_query: Annotated[str, "Requête SQL générée"] result: Annotated[List[dict], "Résultat SQL"] user_intent: Literal["calculation", "visualization", "other"] visualization: Annotated[str, "Type de visualisation recommandée"] graph = StateGraph(QueryState) # 🔹 Étape 1 : Analyser l'intention de l'utilisateur def analyze_intent(state: QueryState) -> QueryState: prompt = f""" Tu es un assistant SQL expert. Analyse l'intention de l'utilisateur à partir de sa question. L'intention peut être : - "calculation" : Si la question nécessite un calcul ou une agrégation (ex : somme, moyenne, etc.). - "visualization" : Si la question demande une visualisation (ex : graphique, tableau, etc.). - "other" : Pour toute autre intention. **Question de l'utilisateur :** {state["query"]} **Intention détectée :** """ intent = llm.invoke(prompt).content.strip().lower() return {**state, "user_intent": intent} # 🔹 Étape 2 : Générer la requête SQL avec cache et affichage en temps réel def generate_sql(state: QueryState) -> QueryState: if isinstance(state, dict) and state.get("sql_query"): # Si la requête SQL existe déjà, ne pas la regénérer return state prompt = f""" Tu es un assistant SQL expert. L'utilisateur te pose une question en langage naturel sur une base de données. **Tables disponibles :** - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers) **Consignes :** - Génère uniquement une requête SQL **valide**. - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL. - Assure-toi que la requête ne contient **aucune erreur de syntaxe**. - **Ne mets pas de point-virgule à la fin de la requête**. - Utilise des noms de tables et de colonnes sans caractères spéciaux, et utilise des crochets `[]` si nécessaire pour SQL Server. **Question de l'utilisateur :** {state["query"]} **Requête SQL :** """ response = llm.stream(prompt) sql_query = "".join(chunk.content for chunk in response).strip() return {**state, "sql_query": sql_query} # 🔹 Étape 3 : Valider et corriger la requête SQL avec cache et affichage en temps réel def validate_and_fix_sql(state: QueryState) -> QueryState: if "sql_query" in state and state["sql_query"]: return state prompt = f""" Tu es un assistant SQL expert. Valide et corrige la requête SQL suivante si nécessaire. **Requête SQL :** {state["sql_query"]} **Schéma de la base de données :** - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers) **Consignes :** - Si la requête est valide, retourne-la telle quelle. - Si la requête contient des erreurs, corrige-la et retourne la version corrigée. - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL. - Utilise des crochets `[]` pour délimiter les noms de tables ou de colonnes si nécessaire, car la base de données est SQL Server. - Remplace `LIMIT` par `TOP` pour SQL Server. - Assure-toi que la requête ne contient **aucune erreur de syntaxe**. **Requête SQL corrigée :** """ response = llm.stream(prompt) corrected_query = "".join(chunk.content for chunk in response).strip() return {**state, "sql_query": corrected_query} # 🔹 Étape 4 : Exécuter la requête SQL def execute_query(state: Dict) -> Dict: try: with engine.connect() as conn: result = conn.execute(text(state["sql_query"])).fetchall() result_dict = [dict(row._mapping) for row in result] if result_dict: df = pd.DataFrame(result_dict) st.dataframe(df) # Afficher les résultats sous forme de tableau else: st.write("⚠️ Aucun résultat trouvé.") return {**state, "result": result_dict} except Exception as e: return {**state, "result": [{"error": str(e)}]} # 🔹 Configuration du graphe LangGraph graph.add_node("analyze_intent", analyze_intent) graph.add_node("generate_sql", generate_sql) graph.add_node("validate_and_fix_sql", validate_and_fix_sql) graph.add_node("execute_query", execute_query) graph.set_entry_point("analyze_intent") graph.add_edge("analyze_intent", "generate_sql") graph.add_edge("generate_sql", "validate_and_fix_sql") graph.add_edge("validate_and_fix_sql", "execute_query") agent = graph.compile() # 🔹 Initialisation de l'historique des interactions if "chat_history" not in st.session_state: st.session_state.chat_history = [] # Fonction pour afficher les messages avec le bon style (droite pour utilisateur, gauche pour assistant) def display_chat_history(): for message in st.session_state.chat_history: if message["role"] == "user": st.markdown( f'
' f'{message["content"]}
', unsafe_allow_html=True ) else: st.markdown( f'
' f'{message["content"]}
', unsafe_allow_html=True ) # 🔹 Interface Streamlit st.title("🧠 Assistant SQL") # Affichage de l'historique display_chat_history() # Barre latérale pour ajouter des descriptions with st.sidebar: st.title("ℹ️ À propos") st.write("Posez vos questions SQL en langage naturel") # Champ de saisie EN BAS (comme ChatGPT) query = st.chat_input("Comment puis-je vous aider ?") if query: # Ajouter la question de l'utilisateur à l'historique st.session_state.chat_history.append({"role": "user", "content": query}) # Exécution de l'agent SQL initial_state = {"query": query, "sql_query": "", "result": [], "user_intent": "other"} output = agent.invoke(initial_state) # Ajouter les résultats à l'historique sous forme de Markdown if "result" in output and output["result"]: if isinstance(output["result"], list) and output["result"]: # Vérifie que les résultats sont une liste non vide result_markdown = " " for row in output["result"]: if isinstance(row, dict): # Si les résultats sont des dictionnaires result_markdown += "\n".join(f"{value}" for value in row.values()) + "\n" # Afficher seulement les valeurs sans parenthèses else: # Si les résultats sont des chaînes ou d'autres types result_markdown += f"- {row},\n" st.session_state.chat_history.append({"role": "assistant", "content": result_markdown}) else: st.session_state.chat_history.append({"role": "assistant", "content": "⚠️ Aucun résultat trouvé."}) # Rafraîchir la page pour afficher immédiatement les nouvelles réponses st.rerun()