Spaces:
Sleeping
Sleeping
File size: 9,352 Bytes
ee297c4 ca34394 ee297c4 ca34394 7d3336d ee297c4 91ebb02 ee297c4 dfcb99e ee297c4 ca34394 dfcb99e ca34394 ee297c4 f76d1f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import streamlit as st
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_community.utilities import SQLDatabase
from langgraph.graph import StateGraph
from typing import TypedDict, Annotated, Literal, List
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
import pandas as pd
from typing import Dict
import mysql.connector
import os
from dotenv import load_dotenv
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
import mysql.connector
# 🔹 Charger les variables d'environnement
load_dotenv()
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
server = os.getenv("DB_HOST")
database = os.getenv("DB_NAME")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
# Vérifier que les variables d'environnement sont chargées
if not all([MISTRAL_API_KEY, server, database, user, password]):
raise ValueError("Veuillez vérifier les variables d'environnement dans les paramètres Hugging Face.")
# 🔹 Configuration du modèle Mistral
model_name = "mistral-large-latest"
llm = ChatMistralAI(
model=model_name,
api_key=MISTRAL_API_KEY,
temperature=0,
stream=True,
verbose=True
)
# 🔹 Configuration de la base de données MySQL
# Utiliser mysql-connector-python avec SQLAlchemy
db_url = f"mysql+mysqlconnector://{user}:{password}@{server}/{database}"
try:
engine = create_engine(db_url)
db = SQLDatabase.from_uri(db_url)
print("Connexion à la base de données réussie !")
except Exception as e:
print(f"Erreur lors de la connexion à la base de données : {e}")
# Si tu veux tester une connexion directe sans SQLAlchemy
try:
conn = mysql.connector.connect(
host=server,
user=user,
password=password,
database=database
)
print("Connexion réussie à la base de données MySQL !")
except mysql.connector.Error as err:
print(f"Erreur de connexion : {err}")
# 🔹 Définition du graphe LangGraph
class QueryState(TypedDict):
query: Annotated[str, "Requête en langage naturel"]
sql_query: Annotated[str, "Requête SQL générée"]
result: Annotated[List[dict], "Résultat SQL"]
user_intent: Literal["calculation", "visualization", "other"]
visualization: Annotated[str, "Type de visualisation recommandée"]
graph = StateGraph(QueryState)
# 🔹 Étape 1 : Analyser l'intention de l'utilisateur
def analyze_intent(state: QueryState) -> QueryState:
prompt = f"""
Tu es un assistant SQL expert. Analyse l'intention de l'utilisateur à partir de sa question.
L'intention peut être :
- "calculation" : Si la question nécessite un calcul ou une agrégation (ex : somme, moyenne, etc.).
- "visualization" : Si la question demande une visualisation (ex : graphique, tableau, etc.).
- "other" : Pour toute autre intention.
**Question de l'utilisateur :**
{state["query"]}
**Intention détectée :**
"""
intent = llm.invoke(prompt).content.strip().lower()
return {**state, "user_intent": intent}
# 🔹 Étape 2 : Générer la requête SQL avec cache et affichage en temps réel
def generate_sql(state: QueryState) -> QueryState:
if isinstance(state, dict) and state.get("sql_query"): # Si la requête SQL existe déjà, ne pas la regénérer
return state
prompt = f"""
Tu es un assistant SQL expert. L'utilisateur te pose une question en langage naturel sur une base de données.
**Tables disponibles :**
- 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)
**Consignes :**
- Génère uniquement une requête SQL **valide**.
- N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
- Assure-toi que la requête ne contient **aucune erreur de syntaxe**.
- **Ne mets pas de point-virgule à la fin de la requête**.
- Utilise des noms de tables et de colonnes sans caractères spéciaux, et utilise des crochets `[]` si nécessaire pour SQL Server.
**Question de l'utilisateur :**
{state["query"]}
**Requête SQL :**
"""
response = llm.stream(prompt)
sql_query = "".join(chunk.content for chunk in response).strip()
return {**state, "sql_query": sql_query}
# 🔹 Étape 3 : Valider et corriger la requête SQL avec cache et affichage en temps réel
def validate_and_fix_sql(state: QueryState) -> QueryState:
if "sql_query" in state and state["sql_query"]:
return state
prompt = f"""
Tu es un assistant SQL expert. Valide et corrige la requête SQL suivante si nécessaire.
**Requête SQL :**
{state["sql_query"]}
**Schéma de la base de données :**
- 'user' (colonnes : firstname, name, role, created_date, last_login, first_login, instagram_url, nb_instagram_followers)
**Consignes :**
- Si la requête est valide, retourne-la telle quelle.
- Si la requête contient des erreurs, corrige-la et retourne la version corrigée.
- N'ajoute **aucun texte explicatif**, retourne seulement la requête SQL.
- Utilise des crochets `[]` pour délimiter les noms de tables ou de colonnes si nécessaire, car la base de données est SQL Server.
- Remplace `LIMIT` par `TOP` pour SQL Server.
- Assure-toi que la requête ne contient **aucune erreur de syntaxe**.
**Requête SQL corrigée :**
"""
response = llm.stream(prompt)
corrected_query = "".join(chunk.content for chunk in response).strip()
return {**state, "sql_query": corrected_query}
# 🔹 Étape 4 : Exécuter la requête SQL
def execute_query(state: Dict) -> Dict:
try:
with engine.connect() as conn:
result = conn.execute(text(state["sql_query"])).fetchall()
result_dict = [dict(row._mapping) for row in result]
if result_dict:
df = pd.DataFrame(result_dict)
st.dataframe(df) # Afficher les résultats sous forme de tableau
else:
st.write("⚠️ Aucun résultat trouvé.")
return {**state, "result": result_dict}
except Exception as e:
return {**state, "result": [{"error": str(e)}]}
# 🔹 Configuration du graphe LangGraph
graph.add_node("analyze_intent", analyze_intent)
graph.add_node("generate_sql", generate_sql)
graph.add_node("validate_and_fix_sql", validate_and_fix_sql)
graph.add_node("execute_query", execute_query)
graph.set_entry_point("analyze_intent")
graph.add_edge("analyze_intent", "generate_sql")
graph.add_edge("generate_sql", "validate_and_fix_sql")
graph.add_edge("validate_and_fix_sql", "execute_query")
agent = graph.compile()
# 🔹 Initialisation de l'historique des interactions
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Fonction pour afficher les messages avec le bon style (droite pour utilisateur, gauche pour assistant)
def display_chat_history():
for message in st.session_state.chat_history:
if message["role"] == "user":
st.markdown(
f'<div style="text-align: right; background-color: #D3D3D3; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%; margin-left: auto;">'
f'{message["content"]}</div>',
unsafe_allow_html=True
)
else:
st.markdown(
f'<div style="text-align: left; background-color: #F1F1F1; padding: 10px; border-radius: 10px; margin-bottom: 10px; width: 50%;">'
f'{message["content"]}</div>',
unsafe_allow_html=True
)
# 🔹 Interface Streamlit
st.title("🧠 Assistant SQL")
# Affichage de l'historique
display_chat_history()
# Barre latérale pour ajouter des descriptions
with st.sidebar:
st.title("ℹ️ À propos")
st.write("Posez vos questions SQL en langage naturel")
# Champ de saisie EN BAS (comme ChatGPT)
query = st.chat_input("Comment puis-je vous aider ?")
if query:
# Ajouter la question de l'utilisateur à l'historique
st.session_state.chat_history.append({"role": "user", "content": query})
# Exécution de l'agent SQL
initial_state = {"query": query, "sql_query": "", "result": [], "user_intent": "other"}
output = agent.invoke(initial_state)
# Ajouter les résultats à l'historique sous forme de Markdown
if "result" in output and output["result"]:
if isinstance(output["result"], list) and output["result"]: # Vérifie que les résultats sont une liste non vide
result_markdown = " "
for row in output["result"]:
if isinstance(row, dict): # Si les résultats sont des dictionnaires
result_markdown += "\n".join(f"{value}" for value in row.values()) + "\n" # Afficher seulement les valeurs sans parenthèses
else: # Si les résultats sont des chaînes ou d'autres types
result_markdown += f"- {row},\n"
st.session_state.chat_history.append({"role": "assistant", "content": result_markdown})
else:
st.session_state.chat_history.append({"role": "assistant", "content": "⚠️ Aucun résultat trouvé."})
# Rafraîchir la page pour afficher immédiatement les nouvelles réponses
st.rerun() |