# CySecBERT-IT-Event-Triage-Classification / train_and_save_models.py
# chaos4455's picture
# Update train_and_save_models.py
# f7afb2c verified
import os
import json
import torch
import random
import sqlite3
import math
import time
from datetime import datetime
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
import multiprocessing
import joblib # Importado para salvar os modelos
# REMOVIDAS importações Flask e CORS, pois este script é apenas para treinamento local.
# from flask import Flask, request, jsonify
# from flask_cors import CORS
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import Ridge
# --- Initial configuration ---
# SQLite file and table that hold the synthetic training events.
DB_NAME = "training_data_large.db"
TABLE_NAME = "events"
# Model identifier passed to SentenceTransformer when embeddings are extracted.
MODEL_NAME = "markusbayer/CySecBERT"
RANDOM_SEED = 42
# Number of events generated per class (risk / safe) when the database is first populated.
NUM_INITIAL_TRAIN_EVENTS_PER_CLASS = 2000
# NOTE(review): not referenced anywhere in the visible part of this file.
RISK_THRESHOLD = 50.0
# Force the 'spawn' start method for the process pools created below.
# A RuntimeError raised while setting it is deliberately ignored.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass
# One worker per CPU; falls back to 4 when os.cpu_count() returns None.
NUM_PROCESSES = os.cpu_count() or 4
# Seed every RNG in use (stdlib random, NumPy, torch CPU and, if present, all CUDA devices).
# NOTE: with the 'spawn' start method the main module is re-imported in each worker
# process, so these calls run again there and every worker starts from the same RNG state.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
model_base = None # Will be initialized during training (set by init_sbert_worker)
# Trained artifacts; assigned by train_and_save_all_models().
mlp_regressor, scaler = None, None
tfidf_vectorizer, tfidf_regressor = None, None
# --- Expanded vocabulary (shared for data generation) ---
# These lists are used ONLY in trainer.py to generate the training data.
# They will be duplicated (or a reduced version of them) in app.py if "Generate Random Event" is kept.
# Actor phrases for risky events; generate_event_text_for_training() picks one
# at random when is_risk is True.
ADVERSARIAL_RISK_ACTORS = [
    "Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
    "Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
    "Zero-day exploit", "Malicious actor", "Compromised credential", "Vulnerable third-party library",
    "Compromised Kubernetes pod", "Malicious Docker container", "AWS IAM role escalation",
    "Azure AD privilege escalation", "GCP service account abuse", "Container escape attempt",
    "Serverless function injection", "Cloud storage bucket enumeration", "API gateway bypass",
    "Microservice lateral movement", "Container registry poisoning", "Cloud metadata exploitation",
    "CI/CD pipeline compromise", "Git repository poisoning", "Build artifact tampering",
    "Deployment script injection", "Infrastructure as Code attack", "Secret scanning bypass",
    "Dependency confusion attack", "Supply chain compromise", "Code signing certificate theft",
    "Pipeline privilege escalation", "Artifact repository poisoning", "Build environment escape",
    "Compromised IoT device", "Edge computing exploit", "Industrial control system breach",
    "SCADA system compromise", "Smart city infrastructure attack", "Medical device exploitation",
    "Automotive system breach", "Home automation compromise", "Sensor data manipulation",
    "Edge gateway exploitation", "Industrial protocol abuse", "IoT botnet recruitment",
    "Mobile app sandbox escape", "iOS jailbreak exploitation", "Android rootkit installation",
    "Mobile banking trojan", "Enterprise device compromise", "BYOD policy violation",
    "Mobile device management bypass", "App store poisoning", "Mobile certificate pinning bypass",
    "Endpoint detection evasion", "Mobile phishing campaign", "Device fingerprinting abuse",
    "Network segmentation bypass", "Firewall rule manipulation", "VPN tunnel exploitation",
    "DNS hijacking attempt", "BGP route hijacking", "Network protocol abuse",
    "Wireless network compromise", "Bluetooth attack vector", "NFC exploitation",
    "Network monitoring evasion", "Traffic analysis bypass", "Protocol fuzzing attack"
]
# Verb phrases for risky events; one is placed between the actor and the target
# by generate_event_text_for_training() when is_risk is True.
ADVERSARIAL_RISK_ACTIONS = [
    "attempted lateral movement via", "initiated a DNS tunneling request to",
    "executed a living-off-the-land binary on", "was flagged for unusual API call patterns against",
    "triggered a data access anomaly in", "exfiltrated data from", "modified critical system files in",
    "gained unauthorized access to", "deployed malicious code on", "brute-forced login for",
    "injected SQL into", "exploited a vulnerability in",
    "attempted container escape from", "escalated privileges in Kubernetes cluster",
    "abused IAM role permissions for", "enumerated cloud storage buckets through",
    "bypassed API gateway authentication to", "injected malicious code into serverless function",
    "compromised container registry access for", "exploited cloud metadata service to",
    "performed lateral movement across microservices in", "poisoned container image in",
    "abused cloud resource tagging for", "exploited cloud logging service to",
    "compromised CI/CD pipeline to", "injected malicious code into build process for",
    "poisoned dependency repository to", "tampered with build artifacts in",
    "escalated privileges in deployment pipeline for", "bypassed security scanning in",
    "abused infrastructure automation to", "compromised secret management system for",
    "injected malicious code into deployment scripts for", "exploited build environment to",
    "abused artifact signing process for", "compromised code repository access to",
    "compromised IoT device firmware to", "exploited edge computing vulnerability in",
    "breached industrial control system through", "manipulated sensor data from",
    "exploited SCADA system vulnerability to", "compromised smart city infrastructure via",
    "abused industrial protocol to", "exploited edge gateway vulnerability in",
    "recruited device into botnet through", "compromised medical device firmware to",
    "exploited automotive system vulnerability in", "breached home automation system via",
    "escaped mobile app sandbox to", "exploited iOS jailbreak vulnerability in",
    "installed rootkit on Android device to", "compromised enterprise mobile device through",
    "bypassed mobile device management to", "poisoned mobile app store listing for",
    "exploited mobile certificate pinning in", "compromised mobile banking app through",
    "abused device fingerprinting to", "exploited mobile phishing vulnerability in",
    "breached BYOD policy through", "compromised mobile endpoint security via",
    "bypassed network segmentation to", "manipulated firewall rules for",
    "exploited VPN tunnel vulnerability in", "hijacked DNS resolution for",
    "abused BGP routing protocol to", "compromised wireless network through",
    "exploited Bluetooth vulnerability in", "abused NFC communication to",
    "evaded network monitoring through", "bypassed traffic analysis via",
    "exploited network protocol vulnerability in", "compromised network infrastructure through"
]
# Target phrases for risky events; sampled by generate_event_text_for_training()
# when is_risk is True.
ADVERSARIAL_RISK_TARGETS = [
    "a code repository", "the CI/CD pipeline", "a cloud storage bucket", "the internal DNS server",
    "the virtual machine hypervisor", "sensitive customer data", "financial databases",
    "intellectual property servers", "critical infrastructure controls", "user authentication service",
    "production web server", "database backup storage",
    "Kubernetes cluster control plane", "Docker container registry", "AWS S3 bucket with sensitive data",
    "Azure Active Directory tenant", "GCP Cloud Storage bucket", "container orchestration system",
    "serverless function environment", "cloud API gateway", "microservice mesh network",
    "container security scanning service", "cloud logging and monitoring system", "infrastructure as code repository",
    "Git repository with production secrets", "Jenkins build pipeline", "Docker image registry",
    "artifact repository with signed packages", "infrastructure provisioning system", "secret management vault",
    "code signing certificate store", "dependency management system", "deployment automation platform",
    "build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
    "industrial control system network", "SCADA system database", "IoT device management platform",
    "edge computing gateway", "smart city infrastructure", "medical device network",
    "automotive system bus", "home automation hub", "sensor data collection system",
    "industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
    "enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
    "mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
    "BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
    "mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
    "network segmentation firewall", "VPN concentrator", "DNS authoritative server",
    "BGP route reflector", "wireless access point controller", "network monitoring system",
    "traffic analysis platform", "network security scanning tool", "protocol analysis system",
    "network infrastructure management", "security information system", "network forensics platform"
]
# Outcome phrases for risky events; inserted after "Status:" by
# generate_event_text_for_training() when is_risk is True.
ADVERSARIAL_RISK_OUTCOMES = [
    "the action was obfuscated", "a low-and-slow data transfer was detected",
    "the process terminated abnormally after execution", "security controls were temporarily disabled",
    "alert thresholds were bypassed", "data integrity was compromised", "system uptime was impacted",
    "a backdoor was established", "a privilege escalation was achieved", "system resources were depleted",
    "data encryption initiated",
    "container escape was successful", "Kubernetes RBAC was bypassed", "cloud IAM policies were circumvented",
    "container registry was compromised", "serverless function was weaponized", "cloud logging was manipulated",
    "microservice communication was intercepted", "container security scanning was evaded",
    "cloud resource tagging was abused", "container orchestration was compromised",
    "cloud metadata service was exploited", "container networking was hijacked",
    "build pipeline was compromised", "dependency repository was poisoned", "artifact signing was bypassed",
    "infrastructure automation was weaponized", "secret management was breached", "code repository was compromised",
    "deployment process was hijacked", "build environment was escaped", "CI/CD security was bypassed",
    "infrastructure monitoring was disabled", "artifact integrity was compromised", "deployment approval was bypassed",
    "IoT device was recruited into botnet", "industrial control system was compromised", "edge gateway was breached",
    "sensor data was manipulated", "SCADA system was taken offline", "smart city infrastructure was disrupted",
    "medical device was compromised", "automotive system was hijacked", "home automation was breached",
    "industrial protocol was abused", "edge security was bypassed", "IoT device firmware was modified",
    "mobile device was rooted/jailbroken", "enterprise mobile security was bypassed", "mobile app was compromised",
    "mobile device management was evaded", "mobile banking was breached", "mobile certificate pinning was bypassed",
    "BYOD policy was violated", "mobile endpoint detection was evaded", "mobile app store was poisoned",
    "mobile device fingerprinting was spoofed", "mobile phishing was successful", "mobile security scanning was bypassed",
    "network segmentation was bypassed", "firewall rules were manipulated", "VPN tunnel was compromised",
    "DNS resolution was hijacked", "BGP routing was manipulated", "wireless network was compromised",
    "Bluetooth security was bypassed", "NFC communication was intercepted", "network monitoring was evaded",
    "traffic analysis was bypassed", "network protocol was abused", "network infrastructure was compromised"
]
# Actor phrases for benign events; generate_event_text_for_training() picks one
# at random when is_risk is False.
ADVERSARIAL_SAFE_ACTORS = [
    "Compliance scanning tool", "Automated patching service", "Certificate authority service",
    "Security researcher (authorized)", "IT operations team", "Development pipeline script",
    "Cloud cost optimization service", "Automated deployment system", "Monitoring agent",
    "Authorized administrator",
    "Kubernetes security scanner", "Container image vulnerability scanner", "Cloud security posture management tool",
    "AWS Config compliance checker", "Azure Security Center agent", "GCP Security Command Center scanner",
    "Container runtime security monitor", "Cloud workload protection platform", "Container registry security scanner",
    "Serverless security monitoring tool", "Cloud access security broker", "Container orchestration security tool",
    "Git security scanning tool", "Dependency vulnerability scanner", "Infrastructure security validator",
    "CI/CD security pipeline", "Secret scanning service", "Code quality analyzer",
    "Build security scanner", "Deployment security validator", "Infrastructure compliance checker",
    "Artifact security scanner", "Pipeline security monitor", "DevSecOps automation tool",
    "IoT device security scanner", "Edge security monitoring agent", "Industrial control system validator",
    "SCADA security assessment tool", "Smart city security monitor", "Medical device security scanner",
    "Automotive security testing tool", "Home automation security validator", "Sensor security monitor",
    "Edge gateway security scanner", "Industrial protocol security tool", "IoT device management platform",
    "Mobile device management system", "Mobile app security scanner", "Enterprise mobile security platform",
    "Mobile threat defense system", "Mobile app store security scanner", "Mobile certificate authority",
    "Mobile endpoint detection system", "Mobile security testing platform", "Mobile device compliance checker",
    "Mobile app vulnerability scanner", "Mobile security monitoring tool", "Mobile device fingerprinting service",
    "Network security scanner", "Firewall management system", "VPN security monitor",
    "DNS security service", "BGP security monitoring tool", "Wireless security scanner",
    "Bluetooth security validator", "NFC security monitor", "Network traffic analyzer",
    "Protocol security scanner", "Network infrastructure monitor", "Security information management system"
]
# Verb phrases for benign events; one is placed between the actor and the target
# by generate_event_text_for_training() when is_risk is False.
ADVERSARIAL_SAFE_ACTIONS = [
    "successfully completed a vulnerability scan on", "performed a scheduled security certificate rotation for",
    "validated firewall integrity for", "completed a successful penetration test on",
    "updated system packages for", "deployed new features to", "analyzed resource usage for",
    "performed routine backup of", "upgraded database schema for", "configured new firewall rules for",
    "successfully scanned container images for vulnerabilities in", "validated Kubernetes RBAC policies for",
    "performed cloud security posture assessment on", "completed container runtime security scan on",
    "validated cloud IAM policies for", "performed serverless security assessment on",
    "completed microservice security audit for", "validated container registry security for",
    "performed cloud workload protection scan on", "completed infrastructure as code security review for",
    "validated cloud logging and monitoring for", "performed container orchestration security audit on",
    "successfully scanned code repository for secrets in", "validated CI/CD pipeline security for",
    "performed dependency vulnerability scan on", "completed infrastructure security validation for",
    "validated secret management configuration for", "performed build security scan on",
    "completed deployment security validation for", "validated artifact integrity for",
    "performed code quality security analysis on", "completed pipeline security audit for",
    "validated infrastructure compliance for", "performed DevSecOps security assessment on",
    "successfully scanned IoT devices for vulnerabilities in", "validated edge security configuration for",
    "performed industrial control system security audit on", "completed SCADA security assessment for",
    "validated smart city infrastructure security for", "performed medical device security scan on",
    "completed automotive system security validation for", "validated home automation security for",
    "performed sensor security assessment on", "completed edge gateway security audit for",
    "validated industrial protocol security for", "performed IoT device management security scan on",
    "successfully scanned mobile devices for threats in", "validated mobile app security for",
    "performed enterprise mobile security assessment on", "completed mobile device management validation for",
    "validated mobile banking security for", "performed mobile certificate security audit on",
    "completed BYOD policy compliance check for", "validated mobile endpoint security for",
    "performed mobile app store security scan on", "completed mobile device fingerprinting validation for",
    "validated mobile phishing protection for", "performed mobile security testing on",
    "successfully scanned network infrastructure for vulnerabilities in", "validated firewall configuration for",
    "performed VPN security assessment on", "completed DNS security validation for",
    "validated BGP routing security for", "performed wireless security scan on",
    "completed Bluetooth security assessment for", "validated NFC security for",
    "performed network traffic analysis on", "completed protocol security validation for",
    "validated network monitoring security for", "performed network forensics analysis on"
]
# Target phrases for benign events; sampled by generate_event_text_for_training()
# when is_risk is False. Several entries also appear in ADVERSARIAL_RISK_TARGETS.
ADVERSARIAL_SAFE_TARGETS = [
    "the production web server", "the user identity database", "all public-facing endpoints",
    "the disaster recovery environment", "the staging environment", "development workstations",
    "network perimeter devices", "DNS records", "application load balancer", "storage array",
    "Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
    "Azure cloud resources", "GCP cloud services", "container orchestration security",
    "serverless function security", "cloud API security", "microservice security architecture",
    "container registry security", "cloud logging security", "infrastructure as code repository",
    "Git repository security", "CI/CD pipeline security", "Docker image security",
    "artifact repository security", "infrastructure provisioning security", "secret management vault",
    "code signing certificate store", "dependency management system", "deployment automation platform",
    "build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
    "industrial control system security", "SCADA system security", "IoT device security",
    "edge computing security", "smart city infrastructure security", "medical device network",
    "automotive system security", "home automation security", "sensor security",
    "industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
    "enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
    "mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
    "BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
    "mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
    "network segmentation firewall", "VPN concentrator", "DNS authoritative server",
    "BGP route reflector", "wireless access point controller", "network monitoring system",
    "traffic analysis platform", "network security scanning tool", "protocol analysis system",
    "network infrastructure management", "security information system", "network forensics platform"
]
# Outcome phrases for benign events; inserted after "Final status:" by
# generate_event_text_for_training() when is_risk is False.
ADVERSARIAL_SAFE_OUTCOMES = [
    "all tests passed, security posture confirmed", "the configuration was hardened as per policy",
    "the certificate was successfully renewed without downtime", "vulnerabilities were patched and verified",
    "system performance improved", "new functionality rolled out successfully",
    "resources optimized for cost efficiency", "backup completed without errors",
    "schema migration successful", "network policy updated",
    "container security scan completed successfully", "Kubernetes RBAC policies validated", "cloud security posture improved",
    "container runtime security verified", "cloud IAM policies hardened", "serverless security validated",
    "microservice security architecture confirmed", "container registry security verified", "cloud workload protection enabled",
    "infrastructure as code security validated", "cloud logging security confirmed", "container orchestration security verified",
    "code repository security scan passed", "CI/CD pipeline security validated", "container image security verified",
    "artifact repository security confirmed", "infrastructure security validated", "secret management security verified",
    "code signing security confirmed", "dependency security validated", "deployment security verified",
    "build security validated", "pipeline security confirmed", "infrastructure compliance verified",
    "IoT device security scan completed", "edge security configuration validated", "industrial control system security verified",
    "SCADA security assessment passed", "smart city infrastructure security confirmed", "medical device security validated",
    "automotive system security verified", "home automation security confirmed", "sensor security validated",
    "edge gateway security verified", "industrial protocol security confirmed", "IoT device management security validated",
    "mobile device security scan completed", "mobile app security validated", "enterprise mobile security confirmed",
    "mobile device management security verified", "mobile banking security validated", "mobile certificate security confirmed",
    "BYOD policy compliance verified", "mobile endpoint security validated", "mobile app store security confirmed",
    "mobile device fingerprinting security verified", "mobile phishing protection validated", "mobile security testing completed",
    "network security scan completed", "firewall security validated", "VPN security confirmed",
    "DNS security verified", "BGP routing security validated", "wireless security confirmed",
    "Bluetooth security validated", "NFC security confirmed", "network monitoring security verified",
    "traffic analysis security validated", "protocol security confirmed", "network infrastructure security verified"
]
# Keyword/phrase -> positive integer weight.
# NOTE(review): not referenced by any function visible in this file; presumably
# consumed by a companion scoring component - confirm before changing or removing.
HIGH_RISK_KEYWORDS = {
    'failed': 15, 'unauthorized': 20, 'invalid': 15, 'blocked': 25, 'mfa_failed': 30, 'brute_force': 40, 'attack': 40,
    'threat': 30, 'compromise': 30, 'malicious': 35, 'lockout': 25, 'critical': 20, 'urgent': 20, 'severe': 25,
    'breach': 40, 'exfiltration': 40, 'injection': 35, 'malware': 35, 'vulnerability': 25, 'exploit': 30,
    'lateral movement': 40, 'dns tunneling': 35, 'obfuscated': 25, 'anomaly': 20, 'misconfigured': 30,
    'ransomware': 50, 'phishing': 45, 'insider threat': 40, 'zero-day': 50, 'unauthorized access': 35, 'data integrity': 30,
    'compromised credential': 40, 'vulnerable library': 30, 'sql injection': 35, 'privilege escalation': 45
}
# Keyword/phrase -> negative integer weight (the counterpart of HIGH_RISK_KEYWORDS).
# NOTE(review): not referenced by any function visible in this file; presumably
# consumed by a companion scoring component - confirm before changing or removing.
LOW_RISK_KEYWORDS = {
    'success': -20, 'successful': -20, 'normal': -15, 'routine': -15, 'authorized': -10, 'benign': -15, 'secure': -10,
    'safe': -15, 'approved': -10, 'expected': -5, 'completed': -10,
    'scan completed': -25, 'validated': -15, 'patched': -20, 'renewed': -15, 'posture confirmed': -30,
    'performance improved': -10, 'functionality rolled out': -10, 'resources optimized': -15,
    'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
}
# --- TRAINING data generation functions (solid baseline) ---
def generate_event_text_for_training(is_risk: bool) -> tuple[str, float]:
    """Build one synthetic event sentence together with its risk score.

    Args:
        is_risk: True to draw from the ADVERSARIAL_RISK_* vocabularies,
            False to draw from the ADVERSARIAL_SAFE_* vocabularies.

    Returns:
        ``(text, score)``. The score is a base value drawn from [75, 95]
        for risky events or [5, 25] for safe ones, plus noise drawn from
        [-5, 5], clipped to [0, 100].
    """
    if is_risk:
        pools = (ADVERSARIAL_RISK_ACTORS, ADVERSARIAL_RISK_ACTIONS,
                 ADVERSARIAL_RISK_TARGETS, ADVERSARIAL_RISK_OUTCOMES)
    else:
        pools = (ADVERSARIAL_SAFE_ACTORS, ADVERSARIAL_SAFE_ACTIONS,
                 ADVERSARIAL_SAFE_TARGETS, ADVERSARIAL_SAFE_OUTCOMES)

    # One draw per pool, in actor / action / target / outcome order.
    who, did, what, result = (random.choice(pool) for pool in pools)

    # The wording is deliberately misleading: risky events mention a "routine
    # maintenance window" and safe events start with "CRITICAL".
    if is_risk:
        sentence = f"Audit log: {who} {did} {what} during a routine maintenance window. Status: {result}."
        low, high = 75, 95
    else:
        sentence = f"CRITICAL alert resolved: {who} {did} {what}. Final status: {result}."
        low, high = 5, 25

    base = random.uniform(low, high)
    jitter = random.uniform(-5, 5)
    return sentence, float(np.clip(base + jitter, 0, 100))
def generate_data_chunk_for_db_worker(args: tuple[int, bool]) -> list[tuple[str, float]]:
    """Generate a batch of synthetic events of a single class.

    Args:
        args: ``(chunk_size, is_risk)`` packed in one tuple so the function
            can be submitted to an executor with a single argument.

    Returns:
        ``chunk_size`` ``(text, risk_score)`` pairs produced by
        generate_event_text_for_training(); empty when ``chunk_size`` <= 0.
    """
    count, risky = args
    events = []
    for _ in range(count):
        events.append(generate_event_text_for_training(risky))
    return events
def _generate_seeded_chunk(task: tuple[int, bool, int]) -> list[tuple[str, float]]:
    """Reseed this worker's ``random`` module for one task, then generate its chunk."""
    chunk_size, is_risk, seed = task
    random.seed(seed)
    return generate_data_chunk_for_db_worker((chunk_size, is_risk))


def populate_database_initial():
    """Generate the initial synthetic training set in parallel and insert it into SQLite.

    Produces NUM_INITIAL_TRAIN_EVENTS_PER_CLASS risky and the same number of
    safe events, split as evenly as possible over NUM_PROCESSES tasks per
    class, and inserts them into TABLE_NAME of DB_NAME (the table must
    already exist).

    Each task carries its own seed. With the 'spawn' start method every
    worker re-imports this module and re-runs the module-level
    ``random.seed(RANDOM_SEED)``, so without per-task seeds all workers
    would start from the same RNG state and emit duplicate chunks.
    """
    total_events = NUM_INITIAL_TRAIN_EVENTS_PER_CLASS * 2
    print(f"Populando o banco de dados com {total_events} eventos de treinamento iniciais em paralelo usando {NUM_PROCESSES} processos...")

    chunk_per_process, remainder = divmod(NUM_INITIAL_TRAIN_EVENTS_PER_CLASS, NUM_PROCESSES)
    tasks = []
    for i in range(NUM_PROCESSES):
        # The first `remainder` slots take one extra event so the totals match exactly.
        size = chunk_per_process + (1 if i < remainder else 0)
        if size > 0:
            tasks.append((size, True))
            tasks.append((size, False))

    # Distinct, deterministic seed per task: results do not depend on which
    # worker picks up which task.
    seeded_tasks = [
        (size, is_risk, RANDOM_SEED + 1 + index)
        for index, (size, is_risk) in enumerate(tasks)
    ]

    all_data = []
    with ProcessPoolExecutor(max_workers=NUM_PROCESSES) as executor:
        # map() yields results in submission order, keeping the inserted row
        # order reproducible between runs.
        for chunk in executor.map(_generate_seeded_chunk, seeded_tasks):
            all_data.extend(chunk)

    print(f"Inserindo {len(all_data)} eventos no banco de dados '{DB_NAME}'...")
    conn = sqlite3.connect(DB_NAME)
    try:
        conn.executemany(f"INSERT INTO {TABLE_NAME} (text, risk_score) VALUES (?, ?)", all_data)
        conn.commit()
    finally:
        # Release the database handle even if the insert fails.
        conn.close()
    print("Banco de dados populado inicialmente com sucesso.")
# --- Embedding and training functions ---
def init_sbert_worker():
    """Pool initializer: load the SentenceTransformer once per worker process.

    Stores the loaded model in the module-level ``model_base`` so that
    extract_embeddings_batch_worker() can reuse it for every chunk handled
    by the same process. Does nothing if the model is already loaded.
    """
    global model_base
    if model_base is not None:
        return
    print(f"Processo worker {os.getpid()} carregando o modelo {MODEL_NAME}...")
    model_base = SentenceTransformer(MODEL_NAME)
    # Restrict torch to one CPU thread in this process; presumably to avoid
    # oversubscription, since the pool already runs one worker per CPU.
    torch.set_num_threads(1)
def extract_embeddings_batch_worker(texts: list[str]) -> list[list[float]]:
    """Encode a batch of texts with the model loaded in this worker process.

    Args:
        texts: Sentences to embed.

    Returns:
        One embedding per input text, as nested Python lists.

    Raises:
        RuntimeError: If ``model_base`` has not been set (init_sbert_worker
            did not run in this process).
    """
    encoder = model_base
    if encoder is None:
        raise RuntimeError("SentenceTransformer não foi inicializado no worker.")
    vectors = encoder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return vectors.tolist()
def train_and_save_all_models():
    """Train both risk regressors from the SQLite training set and save them.

    Steps:
        1. Load every ``(text, risk_score)`` row from TABLE_NAME and shuffle.
        2. Embed the texts in parallel worker processes, standardize the
           embeddings and fit an MLPRegressor on them.
        3. Fit a TF-IDF vectorizer plus Ridge regressor on the raw texts.
        4. Dump the four fitted objects with joblib into the working
           directory (mlp_regressor.joblib, scaler.joblib,
           tfidf_vectorizer.joblib, tfidf_regressor.joblib).

    The fitted objects are also stored in the module-level globals
    ``mlp_regressor``, ``scaler``, ``tfidf_vectorizer`` and ``tfidf_regressor``.

    Raises:
        RuntimeError: If the table is empty, no embeddings were produced, or
            the number of embeddings does not match the number of labels.
    """
    # model_base is only used inside the worker processes, not in the main one.
    global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor
    print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
    conn = sqlite3.connect(DB_NAME)
    try:
        train_data = conn.execute(f"SELECT text, risk_score FROM {TABLE_NAME}").fetchall()
    finally:
        conn.close()
    if not train_data:
        print("ERRO: Banco de dados de treinamento vazio. Não é possível treinar modelos.")
        raise RuntimeError("Banco de dados de treinamento vazio.")
    random.shuffle(train_data)
    train_texts = [text for text, _ in train_data]
    y_train = np.array([score for _, score in train_data])

    # --- Head 1: deep embedding (MLPRegressor) ---
    print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
    texts_per_process = math.ceil(len(train_texts) / NUM_PROCESSES)
    text_chunks_for_workers = [
        train_texts[i:i + texts_per_process]
        for i in range(0, len(train_texts), texts_per_process)
    ]
    X_train_embeddings = []
    with ProcessPoolExecutor(max_workers=NUM_PROCESSES, initializer=init_sbert_worker) as executor:
        # Results MUST be gathered in submission order: row i of the embedding
        # matrix has to belong to train_texts[i] / y_train[i]. Collecting with
        # as_completed() would concatenate chunks in completion order and
        # silently pair embeddings with the wrong labels.
        for chunk_embeddings in executor.map(extract_embeddings_batch_worker, text_chunks_for_workers):
            X_train_embeddings.extend(chunk_embeddings)
    X_train_embeddings = np.array(X_train_embeddings)
    if X_train_embeddings.shape[0] == 0:
        print("ERRO: Nenhum embedding extraído. Verifique os dados de treinamento ou o modelo.")
        raise RuntimeError("Nenhum embedding extraído para treinamento.")
    if X_train_embeddings.shape[0] != len(y_train):
        raise RuntimeError(
            f"Número de embeddings ({X_train_embeddings.shape[0]}) difere do número de rótulos ({len(y_train)})."
        )
    scaler = StandardScaler()
    X_train_embeddings_scaled = scaler.fit_transform(X_train_embeddings)
    mlp_regressor = MLPRegressor(
        hidden_layer_sizes=(768, 384, 192),
        activation='relu',
        solver='adam',
        max_iter=500,
        random_state=RANDOM_SEED,
        early_stopping=True,
        n_iter_no_change=30,
        alpha=0.005,
        learning_rate_init=0.0005,
        batch_size=256,
        verbose=False
    )
    print("Treinando a rede neural (MLP)... Isso pode levar alguns minutos.")
    mlp_regressor.fit(X_train_embeddings_scaled, y_train)
    print(" ... modelo de Embedding Profundo treinado.")

    # --- Head 2: classic vector model (TF-IDF) ---
    print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
    tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
    X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)
    tfidf_regressor = Ridge(alpha=1.0, random_state=RANDOM_SEED)
    tfidf_regressor.fit(X_train_tfidf, y_train)
    print(" ... modelo TF-IDF treinado.")
    print("Todos os modelos de risco foram treinados com sucesso!")

    # Save the trained models.
    print("Salvando modelos treinados...")
    joblib.dump(mlp_regressor, "mlp_regressor.joblib")
    joblib.dump(scaler, "scaler.joblib")
    joblib.dump(tfidf_vectorizer, "tfidf_vectorizer.joblib")
    joblib.dump(tfidf_regressor, "tfidf_regressor.joblib")
    print("Modelos salvos com sucesso.")
if __name__ == "__main__":
    # Script entry point: make sure a training database exists, then train and
    # save the models. The guard is required because the 'spawn' start method
    # re-imports this module in every worker process.
    print("Iniciando pré-carregamento e treinamento de modelos para salvar...")
    db_path = Path(DB_NAME)
    if not db_path.exists() or db_path.stat().st_size == 0:
        print(f"Banco de dados de treinamento '{DB_NAME}' não encontrado ou vazio. Criando e populando inicialmente...")
        # Drop a zero-byte leftover file before recreating the schema.
        if db_path.exists():
            os.remove(db_path)
        conn = sqlite3.connect(DB_NAME)
        try:
            conn.execute(f'CREATE TABLE IF NOT EXISTS {TABLE_NAME} (id INTEGER PRIMARY KEY, text TEXT NOT NULL, risk_score REAL NOT NULL)')
            conn.commit()
        finally:
            conn.close()
        populate_database_initial()
    else:
        print(f"Banco de dados de treinamento '{DB_NAME}' encontrado. Pulando a geração de dados inicial.")
    try:
        train_and_save_all_models()
        print("Processo de treinamento e salvamento concluído com sucesso.")
    except RuntimeError as e:
        print(f"ERRO FATAL: Falha ao treinar ou carregar modelos. Erro: {e}")
        # SystemExit instead of the site-provided exit(), which is not
        # available under `python -S` or in frozen builds.
        raise SystemExit(1)