File size: 9,352 Bytes
ee297c4
 
 
 
 
 
 
 
 
 
ca34394
ee297c4
ca34394
 
 
 
 
 
7d3336d
ee297c4
 
 
91ebb02
 
 
 
ee297c4
dfcb99e
 
 
 
ee297c4
 
 
 
 
 
 
 
 
 
 
ca34394
 
 
 
dfcb99e
 
 
 
 
 
 
ca34394
 
 
 
 
 
 
 
 
 
 
ee297c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f76d1f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import streamlit as st
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_community.utilities import SQLDatabase
from langgraph.graph import StateGraph
from typing import TypedDict, Annotated, Literal, List
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
import pandas as pd
from typing import Dict
import mysql.connector

import os
from dotenv import load_dotenv
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
import mysql.connector

# 🔹 Charger les variables d'environnement
load_dotenv()
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
server = os.getenv("DB_HOST")
database = os.getenv("DB_NAME")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")

# Vérifier que les variables d'environnement sont chargées
if not all([MISTRAL_API_KEY, server, database, user, password]):
    raise ValueError("Veuillez vérifier les variables d'environnement dans les paramètres Hugging Face.")
    
# 🔹 Configuration du modèle Mistral
model_name = "mistral-large-latest"

llm = ChatMistralAI(
    model=model_name,
    api_key=MISTRAL_API_KEY,
    temperature=0,
    stream=True,
    verbose=True
)

# 🔹 Configuration de la base de données MySQL
# Utiliser mysql-connector-python avec SQLAlchemy
db_url = f"mysql+mysqlconnector://{user}:{password}@{server}/{database}"

try:
    engine = create_engine(db_url)
    db = SQLDatabase.from_uri(db_url)
    print("Connexion à la base de données réussie !")
except Exception as e:
    print(f"Erreur lors de la connexion à la base de données : {e}")

# Si tu veux tester une connexion directe sans SQLAlchemy
try:
    conn = mysql.connector.connect(
        host=server,
        user=user,
        password=password,
        database=database
    )
    print("Connexion réussie à la base de données MySQL !")
except mysql.connector.Error as err:
    print(f"Erreur de connexion : {err}")

# 🔹 Définition du graphe LangGraph
class QueryState(TypedDict):
    query: Annotated[str, "Requête en langage naturel"]
    sql_query: Annotated[str, "Requête SQL générée"]
    result: Annotated[List[dict], "Résultat SQL"]
    user_intent: Literal["calculation", "visualization", "other"]
    visualization: Annotated[str, "Type de visualisation recommandée"]

graph = StateGraph(QueryState)

# 🔹 Étape 1 : Analyser l'intention de l'utilisateur
def analyze_intent(state: QueryState) -> QueryState:
    prompt = f"""
    Tu es un assistant SQL expert. Analyse l'intention de l'utilisateur à partir de sa question.
    L'intention peut être :
    - "calculation" : Si la question nécessite un calcul ou une agrégation (ex : somme, moyenne, etc.).
    - "visualization" : Si la question demande une visualisation (ex : graphique, tableau, etc.).
    - "other" : Pour toute autre intention.

    **Question de l'utilisateur :**
    {state["query"]}

    **Intention détectée :**
    """
    intent = llm.invoke(prompt).content.strip().lower()
    return {**state, "user_intent": intent}

# 🔹 Étape 2 : Générer la requête SQL avec cache et affichage en temps réel
def generate_sql(state: QueryState) -> QueryState:
    if isinstance(state, dict) and state.get("sql_query"):  # Si la requête SQL existe déjà, ne pas la regénérer
        return state

    prompt = f"""
    Tu es un assistant SQL expert. L'utilisateur te pose une question en langage naturel sur une base de données.

    **Tables disponibles :**
    - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)

    **Consignes :**
    - Génère uniquement une requête SQL **valide**.
    - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
    - Assure-toi que la requête ne contient **aucune erreur de syntaxe**.
    - **Ne mets pas de point-virgule à la fin de la requête**.
    - Utilise des noms de tables et de colonnes sans caractères spéciaux, et utilise des crochets `[]` si nécessaire pour SQL Server.

    **Question de l'utilisateur :**
    {state["query"]}

    **Requête SQL :**
    """
    response = llm.stream(prompt)
    sql_query = "".join(chunk.content for chunk in response).strip()

    return {**state, "sql_query": sql_query}

# 🔹 Étape 3 : Valider et corriger la requête SQL avec cache et affichage en temps réel
def validate_and_fix_sql(state: QueryState) -> QueryState:
    if "sql_query" in state and state["sql_query"]:
        return state

    prompt = f"""
    Tu es un assistant SQL expert. Valide et corrige la requête SQL suivante si nécessaire.

    **Requête SQL :**
    {state["sql_query"]}

    **Schéma de la base de données :**
    - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)

    **Consignes :**
    - Si la requête est valide, retourne-la telle quelle.
    - Si la requête contient des erreurs, corrige-la et retourne la version corrigée.
    - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
    - Utilise des crochets `[]` pour délimiter les noms de tables ou de colonnes si nécessaire, car la base de données est SQL Server.
    - Remplace `LIMIT` par `TOP` pour SQL Server.
    - Assure-toi que la requête ne contient **aucune erreur de syntaxe**.

    **Requête SQL corrigée :**
    """
    response = llm.stream(prompt)
    corrected_query = "".join(chunk.content for chunk in response).strip()

    return {**state, "sql_query": corrected_query}

# 🔹 Étape 4 : Exécuter la requête SQL
def execute_query(state: Dict) -> Dict:
    try:
        with engine.connect() as conn:
            result = conn.execute(text(state["sql_query"])).fetchall()
            result_dict = [dict(row._mapping) for row in result]

            if result_dict:
                df = pd.DataFrame(result_dict)
                st.dataframe(df)  # Afficher les résultats sous forme de tableau
            else:
                st.write("⚠️ Aucun résultat trouvé.")

        return {**state, "result": result_dict}
    except Exception as e:
        return {**state, "result": [{"error": str(e)}]}

# 🔹 Configuration du graphe LangGraph
graph.add_node("analyze_intent", analyze_intent)
graph.add_node("generate_sql", generate_sql)
graph.add_node("validate_and_fix_sql", validate_and_fix_sql)
graph.add_node("execute_query", execute_query)

graph.set_entry_point("analyze_intent")
graph.add_edge("analyze_intent", "generate_sql")
graph.add_edge("generate_sql", "validate_and_fix_sql")
graph.add_edge("validate_and_fix_sql", "execute_query")

agent = graph.compile()

# 🔹 Initialisation de l'historique des interactions
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Fonction pour afficher les messages avec le bon style (droite pour utilisateur, gauche pour assistant)
def display_chat_history():
    for message in st.session_state.chat_history:
        if message["role"] == "user":
            st.markdown(
                f'<div style="text-align: right; background-color: #D3D3D3; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%; margin-left: auto;">'
                f'{message["content"]}</div>',
                unsafe_allow_html=True
            )
        else:
            st.markdown(
                f'<div style="text-align: left; background-color: #F1F1F1; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%;">'
                f'{message["content"]}</div>',
                unsafe_allow_html=True
            )

# 🔹 Interface Streamlit
st.title("🧠 Assistant SQL")

# Affichage de l'historique
display_chat_history()

# Barre latérale pour ajouter des descriptions
with st.sidebar:
    st.title("ℹ️ À propos")
    st.write("Posez vos questions SQL en langage naturel")

# Champ de saisie EN BAS (comme ChatGPT)
query = st.chat_input("Comment puis-je vous aider ?")

if query:
    # Ajouter la question de l'utilisateur à l'historique
    st.session_state.chat_history.append({"role": "user", "content": query})

    # Exécution de l'agent SQL
    initial_state = {"query": query, "sql_query": "", "result": [], "user_intent": "other"}
    output = agent.invoke(initial_state)

    # Ajouter les résultats à l'historique sous forme de Markdown
    if "result" in output and output["result"]:
        if isinstance(output["result"], list) and output["result"]:  # Vérifie que les résultats sont une liste non vide
            result_markdown = " "
            for row in output["result"]:
                if isinstance(row, dict):  # Si les résultats sont des dictionnaires
                    result_markdown += "\n".join(f"{value}" for value in row.values()) + "\n"  # Afficher seulement les valeurs sans parenthèses
                else:  # Si les résultats sont des chaînes ou d'autres types
                    result_markdown += f"- {row},\n"
            st.session_state.chat_history.append({"role": "assistant", "content": result_markdown})
        else:
            st.session_state.chat_history.append({"role": "assistant", "content": "⚠️ Aucun résultat trouvé."})
    # Rafraîchir la page pour afficher immédiatement les nouvelles réponses
    st.rerun()