Aidahaouas commited on
Commit
f76d1f7
·
verified ·
1 Parent(s): 97aa05e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_mistralai.chat_models import ChatMistralAI
3
+ from langchain_community.utilities import SQLDatabase
4
+ from langgraph.graph import StateGraph
5
+ from typing import TypedDict, Annotated, Literal, List
6
+ import os
7
+ from dotenv import load_dotenv
8
+ from sqlalchemy import create_engine, text
9
+ import pandas as pd
10
+ from typing import Dict
11
+
12
+ # 🔹 Charger les variables d'environnement
13
+ load_dotenv()
14
+ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
15
+ server = os.getenv("server")
16
+ database = os.getenv("database")
17
+ user_name = os.getenv("user_name")
18
+ password = os.getenv("password")
19
+
20
+ # 🔹 Configuration du modèle Mistral
21
+ model_name = "mistral-large-latest"
22
+
23
+ llm = ChatMistralAI(
24
+ model=model_name,
25
+ api_key=MISTRAL_API_KEY,
26
+ temperature=0,
27
+ stream=True,
28
+ verbose=True
29
+ )
30
+
31
+ # 🔹 Configuration de la base de données SQL Server
32
+ db_url = f"mssql+pyodbc://{user_name}:{password}@{server}/{database}?driver=ODBC+Driver+17+for+SQL+Server"
33
+ engine = create_engine(db_url)
34
+ db = SQLDatabase.from_uri(db_url)
35
+
36
+ # 🔹 Définition du graphe LangGraph
37
+ class QueryState(TypedDict):
38
+ query: Annotated[str, "Requête en langage naturel"]
39
+ sql_query: Annotated[str, "Requête SQL générée"]
40
+ result: Annotated[List[dict], "Résultat SQL"]
41
+ user_intent: Literal["calculation", "visualization", "other"]
42
+ visualization: Annotated[str, "Type de visualisation recommandée"]
43
+
44
+ graph = StateGraph(QueryState)
45
+
46
+ # 🔹 Étape 1 : Analyser l'intention de l'utilisateur
47
+ def analyze_intent(state: QueryState) -> QueryState:
48
+ prompt = f"""
49
+ Tu es un assistant SQL expert. Analyse l'intention de l'utilisateur à partir de sa question.
50
+ L'intention peut être :
51
+ - "calculation" : Si la question nécessite un calcul ou une agrégation (ex : somme, moyenne, etc.).
52
+ - "visualization" : Si la question demande une visualisation (ex : graphique, tableau, etc.).
53
+ - "other" : Pour toute autre intention.
54
+
55
+ **Question de l'utilisateur :**
56
+ {state["query"]}
57
+
58
+ **Intention détectée :**
59
+ """
60
+ intent = llm.invoke(prompt).content.strip().lower()
61
+ return {**state, "user_intent": intent}
62
+
63
+ # 🔹 Étape 2 : Générer la requête SQL avec cache et affichage en temps réel
64
+ def generate_sql(state: QueryState) -> QueryState:
65
+ if isinstance(state, dict) and state.get("sql_query"): # Si la requête SQL existe déjà, ne pas la regénérer
66
+ return state
67
+
68
+ prompt = f"""
69
+ Tu es un assistant SQL expert. L'utilisateur te pose une question en langage naturel sur une base de données.
70
+
71
+ **Tables disponibles :**
72
+ - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)
73
+
74
+ **Consignes :**
75
+ - Génère uniquement une requête SQL **valide**.
76
+ - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
77
+ - Assure-toi que la requête ne contient **aucune erreur de syntaxe**.
78
+ - **Ne mets pas de point-virgule à la fin de la requête**.
79
+ - Utilise des noms de tables et de colonnes sans caractères spéciaux, et utilise des crochets `[]` si nécessaire pour SQL Server.
80
+
81
+ **Question de l'utilisateur :**
82
+ {state["query"]}
83
+
84
+ **Requête SQL :**
85
+ """
86
+ response = llm.stream(prompt)
87
+ sql_query = "".join(chunk.content for chunk in response).strip()
88
+
89
+ return {**state, "sql_query": sql_query}
90
+
91
+ # 🔹 Étape 3 : Valider et corriger la requête SQL avec cache et affichage en temps réel
92
+ def validate_and_fix_sql(state: QueryState) -> QueryState:
93
+ if "sql_query" in state and state["sql_query"]:
94
+ return state
95
+
96
+ prompt = f"""
97
+ Tu es un assistant SQL expert. Valide et corrige la requête SQL suivante si nécessaire.
98
+
99
+ **Requête SQL :**
100
+ {state["sql_query"]}
101
+
102
+ **Schéma de la base de données :**
103
+ - 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)
104
+
105
+ **Consignes :**
106
+ - Si la requête est valide, retourne-la telle quelle.
107
+ - Si la requête contient des erreurs, corrige-la et retourne la version corrigée.
108
+ - N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
109
+ - Utilise des crochets `[]` pour délimiter les noms de tables ou de colonnes si nécessaire, car la base de données est SQL Server.
110
+ - Remplace `LIMIT` par `TOP` pour SQL Server.
111
+ - Assure-toi que la requête ne contient **aucune erreur de syntaxe**.
112
+
113
+ **Requête SQL corrigée :**
114
+ """
115
+ response = llm.stream(prompt)
116
+ corrected_query = "".join(chunk.content for chunk in response).strip()
117
+
118
+ return {**state, "sql_query": corrected_query}
119
+
120
+ # 🔹 Étape 4 : Exécuter la requête SQL
121
+ def execute_query(state: Dict) -> Dict:
122
+ try:
123
+ with engine.connect() as conn:
124
+ result = conn.execute(text(state["sql_query"])).fetchall()
125
+ result_dict = [dict(row._mapping) for row in result]
126
+
127
+ if result_dict:
128
+ df = pd.DataFrame(result_dict)
129
+ st.dataframe(df) # Afficher les résultats sous forme de tableau
130
+ else:
131
+ st.write("⚠️ Aucun résultat trouvé.")
132
+
133
+ return {**state, "result": result_dict}
134
+ except Exception as e:
135
+ return {**state, "result": [{"error": str(e)}]}
136
+
137
+ # 🔹 Configuration du graphe LangGraph
138
+ graph.add_node("analyze_intent", analyze_intent)
139
+ graph.add_node("generate_sql", generate_sql)
140
+ graph.add_node("validate_and_fix_sql", validate_and_fix_sql)
141
+ graph.add_node("execute_query", execute_query)
142
+
143
+ graph.set_entry_point("analyze_intent")
144
+ graph.add_edge("analyze_intent", "generate_sql")
145
+ graph.add_edge("generate_sql", "validate_and_fix_sql")
146
+ graph.add_edge("validate_and_fix_sql", "execute_query")
147
+
148
+ agent = graph.compile()
149
+
150
+ # 🔹 Initialisation de l'historique des interactions
151
+ if "chat_history" not in st.session_state:
152
+ st.session_state.chat_history = []
153
+
154
+ # Fonction pour afficher les messages avec le bon style (droite pour utilisateur, gauche pour assistant)
155
+ def display_chat_history():
156
+ for message in st.session_state.chat_history:
157
+ if message["role"] == "user":
158
+ st.markdown(
159
+ f'<div style="text-align: right; background-color: #D3D3D3; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%; margin-left: auto;">'
160
+ f'{message["content"]}</div>',
161
+ unsafe_allow_html=True
162
+ )
163
+ else:
164
+ st.markdown(
165
+ f'<div style="text-align: left; background-color: #F1F1F1; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%;">'
166
+ f'{message["content"]}</div>',
167
+ unsafe_allow_html=True
168
+ )
169
+
170
+ # 🔹 Interface Streamlit
171
+ st.title("🧠 Assistant SQL")
172
+
173
+ # Affichage de l'historique
174
+ display_chat_history()
175
+
176
+ # Barre latérale pour ajouter des descriptions
177
+ with st.sidebar:
178
+ st.title("ℹ️ À propos")
179
+ st.write("Posez vos questions SQL en langage naturel")
180
+
181
+ # Champ de saisie EN BAS (comme ChatGPT)
182
+ query = st.chat_input("Comment puis-je vous aider ?")
183
+
184
+ if query:
185
+ # Ajouter la question de l'utilisateur à l'historique
186
+ st.session_state.chat_history.append({"role": "user", "content": query})
187
+
188
+ # Exécution de l'agent SQL
189
+ initial_state = {"query": query, "sql_query": "", "result": [], "user_intent": "other"}
190
+ output = agent.invoke(initial_state)
191
+
192
+ # Ajouter les résultats à l'historique sous forme de Markdown
193
+ if "result" in output and output["result"]:
194
+ if isinstance(output["result"], list) and output["result"]: # Vérifie que les résultats sont une liste non vide
195
+ result_markdown = " "
196
+ for row in output["result"]:
197
+ if isinstance(row, dict): # Si les résultats sont des dictionnaires
198
+ result_markdown += "\n".join(f"{value}" for value in row.values()) + "\n" # Afficher seulement les valeurs sans parenthèses
199
+ else: # Si les résultats sont des chaînes ou d'autres types
200
+ result_markdown += f"- {row},\n"
201
+ st.session_state.chat_history.append({"role": "assistant", "content": result_markdown})
202
+ else:
203
+ st.session_state.chat_history.append({"role": "assistant", "content": "⚠️ Aucun résultat trouvé."})
204
+ # Rafraîchir la page pour afficher immédiatement les nouvelles réponses
205
+ st.rerun()