Ferdlance's picture
Update app.py
c1a6d5f verified
raw
history blame
18 kB
#!/usr/bin/env python3
import os
import sys
import time
import random
import logging
import requests
import json
import re
import subprocess
import shutil
from datetime import datetime
from pathlib import Path
# Streamlit et visualisation
import streamlit as st
import pandas as pd
import plotly.express as px
# Parsing HTML
import html2text
# Importation du module de configuration (supposé exister)
from config import app_config as config
# --- CONFIGURATION DE LA PAGE ET LOGGING ---
st.set_page_config(
page_title="DevSecOps Data Bot",
layout="wide",
initial_sidebar_state="expanded"
)
# Initialisation de l'état de la session (géré par le fichier config)
config.init_session_state()
def setup_logging():
"""Configure un logger pour tracer l'exécution dans un fichier et la console."""
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
log_file = log_dir / f"data_collector_{datetime.now().strftime('%Y%m%d')}.log"
# Évite d'ajouter des handlers multiples si la fonction est appelée plusieurs fois
logger = logging.getLogger(__name__)
if not logger.handlers:
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
return logger
logger = setup_logging()
h_parser = html2text.HTML2Text()
h_parser.ignore_links = True
# --- GESTION DU SERVEUR LLM LOCAL ---
def check_server_status():
"""Vérifie si le serveur LLM est actif."""
try:
response = requests.get(config.LLM_SERVER_URL.replace("/completion", "/health"), timeout=3)
if response.status_code == 200 and response.json().get('status') == 'ok':
st.session_state.server_status = "Actif"
return True
except requests.exceptions.RequestException:
pass
st.session_state.server_status = "Inactif"
return False
def start_llm_server():
"""Démarre le serveur llama.cpp en subprocess."""
if check_server_status():
st.toast("✅ Le serveur LLM est déjà actif.", icon="✅")
return
model_path = Path(config.MODEL_PATH)
server_binary = Path(config.LLAMA_SERVER_PATH)
if not model_path.exists():
st.error(f"Le modèle GGUF est introuvable à : {config.MODEL_PATH}")
return
if not server_binary.exists():
st.error(f"Le binaire du serveur est introuvable : {config.LLAMA_SERVER_PATH}")
return
# Commande pour démarrer le serveur
command = [
str(server_binary),
"-m", str(model_path),
"--port", str(config.LLM_PORT),
"--host", "0.0.0.0",
"-c", "4096",
"-ngl", "999", # Nombre de couches GPU, ajuster si nécessaire
"--threads", "8"
]
log_file = Path("logs/llama_server.log")
pid_file = Path("server/server.pid")
pid_file.parent.mkdir(exist_ok=True)
try:
with open(log_file, 'w') as log:
process = subprocess.Popen(command, stdout=log, stderr=subprocess.STDOUT)
with open(pid_file, 'w') as f:
f.write(str(process.pid))
st.info("Tentative de démarrage du serveur LLM...")
time.sleep(10) # Laisse le temps au serveur de démarrer
if check_server_status():
st.success("Serveur LLM démarré avec succès !")
else:
st.error("Le serveur n'a pas pu démarrer. Vérifiez les logs dans `logs/llama_server.log`.")
except Exception as e:
st.error(f"Erreur lors du démarrage du serveur : {e}")
def stop_llm_server():
"""Arrête le serveur LLM en tuant le processus via son PID."""
pid_file = Path("server/server.pid")
if not pid_file.exists():
st.warning("Aucun fichier PID trouvé. Le serveur est probablement déjà arrêté.")
check_server_status()
return
try:
with open(pid_file, 'r') as f:
pid = int(f.read().strip())
# Tente de tuer le processus
os.kill(pid, 9) # SIGKILL
st.info(f"Signal d'arrêt envoyé au processus {pid}.")
os.remove(pid_file)
except (ProcessLookupError, FileNotFoundError):
st.warning("Le processus n'existait pas ou le fichier PID a déjà été supprimé.")
if pid_file.exists():
os.remove(pid_file)
except Exception as e:
st.error(f"Erreur lors de l'arrêt du serveur : {e}")
time.sleep(3)
if not check_server_status():
st.success("Serveur LLM arrêté avec succès.")
else:
st.warning("Le serveur semble toujours actif. Une vérification manuelle peut être nécessaire.")
# --- LOGIQUE D'ENRICHISSEMENT IA ---
class IAEnricher:
"""Classe pour interagir avec le LLM et enrichir les données."""
def __init__(self):
self.server_url = config.LLM_SERVER_URL
self.available = check_server_status()
def _query_llm(self, prompt, n_predict=512):
if not self.available:
return None
payload = {
"prompt": prompt,
"n_predict": n_predict,
"temperature": st.session_state.temperature,
"stop": ["<|im_end|>", "</s>", "\n}\n"]
}
try:
response = requests.post(self.server_url, json=payload, timeout=120)
response.raise_for_status()
return response.json().get('content', '')
except requests.exceptions.RequestException as e:
logger.error(f"Erreur de communication avec le serveur LLM : {e}")
return None
def _extract_json(self, text):
"""Extrait un objet JSON d'une chaîne de texte, de manière plus robuste."""
if not text:
return None
# Trouve le premier '{' et le dernier '}' pour délimiter le JSON potentiel
start = text.find('{')
end = text.rfind('}')
if start != -1 and end != -1 and end > start:
json_str = text[start:end+1]
try:
return json.loads(json_str)
except json.JSONDecodeError:
logger.warning(f"Impossible de décoder le JSON extrait : {json_str[:200]}...")
return None
def analyze_content_relevance(self, content):
"""Utilise l'IA pour analyser la pertinence d'un contenu."""
if not self.available or not st.session_state.enable_enrichment:
return {"relevant": True, "attack_signatures": [], "security_tags": [], "it_relevance_score": 50}
prompt = config.PROMPTS["analyze_relevance"].format(content=content[:1500])
response_text = self._query_llm(prompt, n_predict=256)
analysis = self._extract_json(response_text)
if analysis:
return analysis
# Valeur par défaut si l'IA échoue
return {"relevant": True, "attack_signatures": [], "security_tags": [], "it_relevance_score": 50}
# --- FONCTIONS DE COLLECTE DE DONNÉES ---
def check_api_keys():
"""Vérifie la présence des clés API et met à jour un flag global."""
keys_needed = ['GITHUB_API_TOKEN', 'NVD_API_KEY', 'STACK_EXCHANGE_API_KEY']
missing_keys = [key for key in keys_needed if not os.getenv(key)]
if missing_keys:
logger.warning(f"Clés API manquantes : {', '.join(missing_keys)}. Le bot fonctionnera en mode dégradé.")
config.USE_API_KEYS = False
else:
logger.info("Toutes les clés API nécessaires sont configurées.")
config.USE_API_KEYS = True
def make_request(url, headers=None, params=None):
"""Effectue une requête HTTP avec gestion des pauses et des erreurs."""
# Logique de pause pour éviter le rate-limiting
pause_time = random.uniform(2, 5) if not config.USE_API_KEYS else random.uniform(0.5, 1.5)
time.sleep(pause_time)
try:
response = requests.get(url, headers=headers, params=params, timeout=30)
if response.status_code == 429: # Rate limited
retry_after = int(response.headers.get('Retry-After', 15))
logger.warning(f"Limite de débit atteinte. Pause de {retry_after} secondes...")
time.sleep(retry_after)
return make_request(url, headers, params)
response.raise_for_status() # Lève une exception pour les codes 4xx/5xx
return response
except requests.exceptions.RequestException as e:
logger.error(f"Erreur de requête pour {url}: {e}")
return None
def clean_html(html_content):
"""Nettoie le contenu HTML pour extraire le texte brut."""
if not html_content:
return ""
return h_parser.handle(html_content)
def save_data(data):
"""Ajoute les données collectées à l'état de la session."""
st.session_state.qa_data.append(data)
st.session_state.total_qa_pairs = len(st.session_state.qa_data)
logger.info(f"Donnée sauvegardée : {data['source']} (Total: {st.session_state.total_qa_pairs})")
# Mise à jour du log dans l'UI
log_placeholder = st.session_state.get('log_placeholder')
if log_placeholder:
log_placeholder.text(f"Dernière collecte : {data['source']}")
def collect_github_data(query, limit):
"""Collecte les problèmes de sécurité depuis des dépôts GitHub."""
logger.info(f"GitHub: Recherche de '{query}'...")
base_url = "https://api.github.com"
headers = {"Accept": "application/vnd.github.v3+json"}
if config.USE_API_KEYS:
headers["Authorization"] = f"token {os.getenv('GITHUB_API_TOKEN')}"
search_url = f"{base_url}/search/repositories"
params = {"q": query, "sort": "stars", "per_page": limit}
response = make_request(search_url, headers=headers, params=params)
if not response: return
for repo in response.json().get("items", []):
issues_url = repo["issues_url"].replace("{/number}", "")
issues_params = {"state": "all", "labels": "security,vulnerability", "per_page": 5}
issues_response = make_request(issues_url, headers=headers, params=issues_params)
if issues_response:
for issue in issues_response.json():
if "pull_request" not in issue and issue.get("body"):
analysis = ia_enricher.analyze_content_relevance(issue['title'] + " " + issue['body'])
if analysis['relevant'] and analysis['it_relevance_score'] >= st.session_state.min_relevance:
save_data({
"question": issue["title"],
"answer": clean_html(issue["body"]),
"category": "devsecops",
"source": f"github_{repo['full_name']}",
"tags": [t['name'] for t in issue.get('labels', [])] + analysis['security_tags']
})
def collect_nvd_data(limit):
"""Collecte les dernières vulnérabilités CVE depuis le NVD."""
logger.info("NVD: Collecte des dernières vulnérabilités...")
base_url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
headers = {}
if config.USE_API_KEYS:
headers["apiKey"] = os.getenv('NVD_API_KEY')
params = {"resultsPerPage": limit}
response = make_request(base_url, headers=headers, params=params)
if not response: return
for vuln in response.json().get("vulnerabilities", []):
cve = vuln.get("cve", {})
cve_id = cve.get("id", "N/A")
description = next((d['value'] for d in cve.get('descriptions', []) if d['lang'] == 'en'), "")
if description:
save_data({
"question": f"Qu'est-ce que la vulnérabilité {cve_id} ?",
"answer": description,
"category": "security",
"source": f"nvd_{cve_id}",
"tags": ["cve", "vulnerability"]
})
# --- FONCTION PRINCIPALE ET INTERFACE STREAMLIT ---
def run_data_collection(sources, queries, limits):
"""Orchestre la collecte de données depuis les sources sélectionnées."""
st.session_state.bot_status = "En cours d'exécution"
# Nettoyage de l'état précédent avant de démarrer
st.session_state.qa_data = []
st.session_state.total_qa_pairs = 0
check_api_keys()
enabled_sources = [s for s, enabled in sources.items() if enabled]
progress_bar = st.progress(0, text="Démarrage de la collecte...")
for i, source_name in enumerate(enabled_sources):
progress_text = f"Collecte depuis {source_name}... ({i+1}/{len(enabled_sources)})"
progress_bar.progress((i + 1) / len(enabled_sources), text=progress_text)
try:
if source_name == "GitHub":
for query in queries["GitHub"].split('\n'):
if query.strip():
collect_github_data(query.strip(), limits["GitHub"])
elif source_name == "NVD":
collect_nvd_data(limits["NVD"])
# Ajouter d'autres sources ici (Kaggle, etc.) de la même manière
except Exception as e:
logger.error(f"Erreur fatale lors de la collecte depuis {source_name}: {e}")
progress_bar.empty()
st.session_state.bot_status = "Arrêté"
st.toast("Collecte des données terminée !", icon="🎉")
# Forcer le rafraîchissement de la page pour mettre à jour l'onglet statistiques
time.sleep(2)
st.rerun()
def main():
"""Fonction principale de l'application Streamlit."""
st.title("🤖 DevSecOps Data Bot")
st.markdown("Ce bot collecte et enrichit des données DevSecOps depuis diverses sources.")
global ia_enricher
ia_enricher = IAEnricher()
tabs = st.tabs(["▶️ Bot", "📊 Statistiques & Données", "⚙️ Configuration"])
with tabs[0]:
st.header("Tableau de bord")
col1, col2, col3 = st.columns(3)
col1.metric("Statut du bot", st.session_state.bot_status)
col2.metric("Paires Q/R collectées", st.session_state.total_qa_pairs)
col3.metric("Statut du serveur LLM", st.session_state.server_status)
# Placeholder pour les logs en direct
st.session_state['log_placeholder'] = st.empty()
with st.form("collection_form"):
st.subheader("1. Choisir les sources de données")
sources = {
"GitHub": st.checkbox("GitHub (Problèmes de sécurité)", value=True),
"NVD": st.checkbox("NVD (Vulnérabilités CVE)", value=True),
}
st.subheader("2. Paramètres de la collecte")
queries = {}
limits = {}
with st.expander("Configuration pour GitHub"):
queries["GitHub"] = st.text_area("Requêtes GitHub (une par ligne)", "language:python security\ntopic:devsecops vulnerability")
limits["GitHub"] = st.number_input("Nombre de dépôts par requête", 1, 50, 5)
with st.expander("Configuration pour NVD"):
limits["NVD"] = st.number_input("Nombre de CVE à récupérer", 10, 200, 50)
submitted = st.form_submit_button("🚀 Lancer la collecte", type="primary", use_container_width=True)
if submitted:
if st.session_state.bot_status == "En cours d'exécution":
st.warning("Une collecte est déjà en cours.")
else:
run_data_collection(sources, queries, limits)
with tabs[1]:
st.header("Analyse des Données Collectées")
if st.session_state.qa_data:
df = pd.DataFrame(st.session_state.qa_data)
st.subheader("Aperçu des données")
st.dataframe(df)
st.subheader("Répartition par source")
source_counts = df['source'].apply(lambda x: x.split('_')[0]).value_counts()
fig_source = px.bar(source_counts, x=source_counts.index, y=source_counts.values,
labels={'x': 'Source', 'y': 'Nombre'}, title="Nombre de paires Q/R par source")
st.plotly_chart(fig_source, use_container_width=True)
# Bouton de téléchargement
json_data = json.dumps(st.session_state.qa_data, indent=2, ensure_ascii=False)
st.download_button(
label="📥 Télécharger les données (JSON)",
data=json_data,
file_name=f"devsecops_data_{datetime.now().strftime('%Y%m%d')}.json",
mime="application/json",
use_container_width=True
)
else:
st.info("Aucune donnée à afficher. Lancez une collecte depuis l'onglet 'Bot'.")
with tabs[2]:
st.header("Configuration Avancée")
st.subheader("Gestion du serveur LLM local")
st.warning("⚠️ Attention : La gestion du serveur est expérimentale sur les conteneurs.")
llm_col1, llm_col2 = st.columns(2)
if llm_col1.button("Démarrer le serveur LLM", use_container_width=True):
start_llm_server()
st.rerun()
if llm_col2.button("Arrêter le serveur LLM", type="secondary", use_container_width=True):
stop_llm_server()
st.rerun()
st.subheader("Paramètres d'enrichissement IA")
st.session_state.enable_enrichment = st.toggle("Activer l'enrichissement par IA", value=True)
st.session_state.min_relevance = st.slider("Score de pertinence minimum", 0, 100, 50)
st.session_state.temperature = st.slider("Température de l'IA (créativité)", 0.0, 1.5, 0.5)
if __name__ == "__main__":
main()