Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import torch | |
| import random | |
| import sqlite3 | |
| import math | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from concurrent.futures import ProcessPoolExecutor, as_completed | |
| import numpy as np | |
| import multiprocessing | |
| import joblib # Importado para salvar os modelos | |
| # REMOVIDAS importações Flask e CORS, pois este script é apenas para treinamento local. | |
| # from flask import Flask, request, jsonify | |
| # from flask_cors import CORS | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.neural_network import MLPRegressor | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import Ridge | |
# --- Initial configuration ---
DB_NAME = "training_data_large.db"  # SQLite file holding the generated training events
TABLE_NAME = "events"  # table with columns (id, text, risk_score); created in the __main__ block
MODEL_NAME = "markusbayer/CySecBERT"  # model id loaded with SentenceTransformer by the embedding workers
RANDOM_SEED = 42
NUM_INITIAL_TRAIN_EVENTS_PER_CLASS = 2000  # events generated per class (risk and safe) on first population
RISK_THRESHOLD = 50.0  # NOTE(review): not referenced in this file; presumably a risk/safe cut-off used elsewhere — verify
# Force the 'spawn' start method so pool workers start from a fresh interpreter.
# NOTE(review): with force=True, set_start_method does not raise the "context has already been
# set" RuntimeError, so this except branch appears to be unreachable.
# NOTE(review): under 'spawn' each worker re-imports this module and re-runs the seeding below,
# so every worker process starts from the same global random state.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass
NUM_PROCESSES = os.cpu_count() or 4  # falls back to 4 when os.cpu_count() returns None
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
model_base = None  # Initialized during training (set by init_sbert_worker inside each worker process)
# Fitted models; assigned by train_and_save_all_models.
mlp_regressor, scaler = None, None
tfidf_vectorizer, tfidf_regressor = None, None
# --- Expanded vocabulary (shared for data generation) ---
# These lists are used ONLY in trainer.py to generate the training data.
# They will be duplicated (or a reduced version kept) in app.py if the
# "Gerar Evento Aleatório" (generate random event) feature is kept.
# Actor phrases for risky events; generate_event_text_for_training picks one at random.
ADVERSARIAL_RISK_ACTORS = [
    "Unsandboxed process", "Leaked API key", "Misconfigured service account", "Shadow IT application",
    "Dormant user account", "Ransomware payload", "Phishing attempt", "Insider threat",
    "Zero-day exploit", "Malicious actor", "Compromised credential", "Vulnerable third-party library",
    "Compromised Kubernetes pod", "Malicious Docker container", "AWS IAM role escalation",
    "Azure AD privilege escalation", "GCP service account abuse", "Container escape attempt",
    "Serverless function injection", "Cloud storage bucket enumeration", "API gateway bypass",
    "Microservice lateral movement", "Container registry poisoning", "Cloud metadata exploitation",
    "CI/CD pipeline compromise", "Git repository poisoning", "Build artifact tampering",
    "Deployment script injection", "Infrastructure as Code attack", "Secret scanning bypass",
    "Dependency confusion attack", "Supply chain compromise", "Code signing certificate theft",
    "Pipeline privilege escalation", "Artifact repository poisoning", "Build environment escape",
    "Compromised IoT device", "Edge computing exploit", "Industrial control system breach",
    "SCADA system compromise", "Smart city infrastructure attack", "Medical device exploitation",
    "Automotive system breach", "Home automation compromise", "Sensor data manipulation",
    "Edge gateway exploitation", "Industrial protocol abuse", "IoT botnet recruitment",
    "Mobile app sandbox escape", "iOS jailbreak exploitation", "Android rootkit installation",
    "Mobile banking trojan", "Enterprise device compromise", "BYOD policy violation",
    "Mobile device management bypass", "App store poisoning", "Mobile certificate pinning bypass",
    "Endpoint detection evasion", "Mobile phishing campaign", "Device fingerprinting abuse",
    "Network segmentation bypass", "Firewall rule manipulation", "VPN tunnel exploitation",
    "DNS hijacking attempt", "BGP route hijacking", "Network protocol abuse",
    "Wireless network compromise", "Bluetooth attack vector", "NFC exploitation",
    "Network monitoring evasion", "Traffic analysis bypass", "Protocol fuzzing attack"
]
# Verb phrases for risky events; each is inserted between an actor and a target, so most end
# with a preposition ("... to", "... in").
# NOTE(review): "escalated privileges in Kubernetes cluster" and "injected malicious code into
# serverless function" have no trailing preposition, so the generated sentence runs straight
# into the target noun — confirm whether that is intended.
ADVERSARIAL_RISK_ACTIONS = [
    "attempted lateral movement via", "initiated a DNS tunneling request to",
    "executed a living-off-the-land binary on", "was flagged for unusual API call patterns against",
    "triggered a data access anomaly in", "exfiltrated data from", "modified critical system files in",
    "gained unauthorized access to", "deployed malicious code on", "brute-forced login for",
    "injected SQL into", "exploited a vulnerability in",
    "attempted container escape from", "escalated privileges in Kubernetes cluster",
    "abused IAM role permissions for", "enumerated cloud storage buckets through",
    "bypassed API gateway authentication to", "injected malicious code into serverless function",
    "compromised container registry access for", "exploited cloud metadata service to",
    "performed lateral movement across microservices in", "poisoned container image in",
    "abused cloud resource tagging for", "exploited cloud logging service to",
    "compromised CI/CD pipeline to", "injected malicious code into build process for",
    "poisoned dependency repository to", "tampered with build artifacts in",
    "escalated privileges in deployment pipeline for", "bypassed security scanning in",
    "abused infrastructure automation to", "compromised secret management system for",
    "injected malicious code into deployment scripts for", "exploited build environment to",
    "abused artifact signing process for", "compromised code repository access to",
    "compromised IoT device firmware to", "exploited edge computing vulnerability in",
    "breached industrial control system through", "manipulated sensor data from",
    "exploited SCADA system vulnerability to", "compromised smart city infrastructure via",
    "abused industrial protocol to", "exploited edge gateway vulnerability in",
    "recruited device into botnet through", "compromised medical device firmware to",
    "exploited automotive system vulnerability in", "breached home automation system via",
    "escaped mobile app sandbox to", "exploited iOS jailbreak vulnerability in",
    "installed rootkit on Android device to", "compromised enterprise mobile device through",
    "bypassed mobile device management to", "poisoned mobile app store listing for",
    "exploited mobile certificate pinning in", "compromised mobile banking app through",
    "abused device fingerprinting to", "exploited mobile phishing vulnerability in",
    "breached BYOD policy through", "compromised mobile endpoint security via",
    "bypassed network segmentation to", "manipulated firewall rules for",
    "exploited VPN tunnel vulnerability in", "hijacked DNS resolution for",
    "abused BGP routing protocol to", "compromised wireless network through",
    "exploited Bluetooth vulnerability in", "abused NFC communication to",
    "evaded network monitoring through", "bypassed traffic analysis via",
    "exploited network protocol vulnerability in", "compromised network infrastructure through"
]
# Target phrases for risky events (the object of the chosen action).
ADVERSARIAL_RISK_TARGETS = [
    "a code repository", "the CI/CD pipeline", "a cloud storage bucket", "the internal DNS server",
    "the virtual machine hypervisor", "sensitive customer data", "financial databases",
    "intellectual property servers", "critical infrastructure controls", "user authentication service",
    "production web server", "database backup storage",
    "Kubernetes cluster control plane", "Docker container registry", "AWS S3 bucket with sensitive data",
    "Azure Active Directory tenant", "GCP Cloud Storage bucket", "container orchestration system",
    "serverless function environment", "cloud API gateway", "microservice mesh network",
    "container security scanning service", "cloud logging and monitoring system", "infrastructure as code repository",
    "Git repository with production secrets", "Jenkins build pipeline", "Docker image registry",
    "artifact repository with signed packages", "infrastructure provisioning system", "secret management vault",
    "code signing certificate store", "dependency management system", "deployment automation platform",
    "build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
    "industrial control system network", "SCADA system database", "IoT device management platform",
    "edge computing gateway", "smart city infrastructure", "medical device network",
    "automotive system bus", "home automation hub", "sensor data collection system",
    "industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
    "enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
    "mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
    "BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
    "mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
    "network segmentation firewall", "VPN concentrator", "DNS authoritative server",
    "BGP route reflector", "wireless access point controller", "network monitoring system",
    "traffic analysis platform", "network security scanning tool", "protocol analysis system",
    "network infrastructure management", "security information system", "network forensics platform"
]
# Outcome phrases for risky events (rendered after "Status:").
# Some entries contain words such as "successful", which also appear in LOW_RISK_KEYWORDS;
# presumably intentional given the "adversarial" naming — verify.
ADVERSARIAL_RISK_OUTCOMES = [
    "the action was obfuscated", "a low-and-slow data transfer was detected",
    "the process terminated abnormally after execution", "security controls were temporarily disabled",
    "alert thresholds were bypassed", "data integrity was compromised", "system uptime was impacted",
    "a backdoor was established", "a privilege escalation was achieved", "system resources were depleted",
    "data encryption initiated",
    "container escape was successful", "Kubernetes RBAC was bypassed", "cloud IAM policies were circumvented",
    "container registry was compromised", "serverless function was weaponized", "cloud logging was manipulated",
    "microservice communication was intercepted", "container security scanning was evaded",
    "cloud resource tagging was abused", "container orchestration was compromised",
    "cloud metadata service was exploited", "container networking was hijacked",
    "build pipeline was compromised", "dependency repository was poisoned", "artifact signing was bypassed",
    "infrastructure automation was weaponized", "secret management was breached", "code repository was compromised",
    "deployment process was hijacked", "build environment was escaped", "CI/CD security was bypassed",
    "infrastructure monitoring was disabled", "artifact integrity was compromised", "deployment approval was bypassed",
    "IoT device was recruited into botnet", "industrial control system was compromised", "edge gateway was breached",
    "sensor data was manipulated", "SCADA system was taken offline", "smart city infrastructure was disrupted",
    "medical device was compromised", "automotive system was hijacked", "home automation was breached",
    "industrial protocol was abused", "edge security was bypassed", "IoT device firmware was modified",
    "mobile device was rooted/jailbroken", "enterprise mobile security was bypassed", "mobile app was compromised",
    "mobile device management was evaded", "mobile banking was breached", "mobile certificate pinning was bypassed",
    "BYOD policy was violated", "mobile endpoint detection was evaded", "mobile app store was poisoned",
    "mobile device fingerprinting was spoofed", "mobile phishing was successful", "mobile security scanning was bypassed",
    "network segmentation was bypassed", "firewall rules were manipulated", "VPN tunnel was compromised",
    "DNS resolution was hijacked", "BGP routing was manipulated", "wireless network was compromised",
    "Bluetooth security was bypassed", "NFC communication was intercepted", "network monitoring was evaded",
    "traffic analysis was bypassed", "network protocol was abused", "network infrastructure was compromised"
]
# Actor phrases for safe (benign) events.
ADVERSARIAL_SAFE_ACTORS = [
    "Compliance scanning tool", "Automated patching service", "Certificate authority service",
    "Security researcher (authorized)", "IT operations team", "Development pipeline script",
    "Cloud cost optimization service", "Automated deployment system", "Monitoring agent",
    "Authorized administrator",
    "Kubernetes security scanner", "Container image vulnerability scanner", "Cloud security posture management tool",
    "AWS Config compliance checker", "Azure Security Center agent", "GCP Security Command Center scanner",
    "Container runtime security monitor", "Cloud workload protection platform", "Container registry security scanner",
    "Serverless security monitoring tool", "Cloud access security broker", "Container orchestration security tool",
    "Git security scanning tool", "Dependency vulnerability scanner", "Infrastructure security validator",
    "CI/CD security pipeline", "Secret scanning service", "Code quality analyzer",
    "Build security scanner", "Deployment security validator", "Infrastructure compliance checker",
    "Artifact security scanner", "Pipeline security monitor", "DevSecOps automation tool",
    "IoT device security scanner", "Edge security monitoring agent", "Industrial control system validator",
    "SCADA security assessment tool", "Smart city security monitor", "Medical device security scanner",
    "Automotive security testing tool", "Home automation security validator", "Sensor security monitor",
    "Edge gateway security scanner", "Industrial protocol security tool", "IoT device management platform",
    "Mobile device management system", "Mobile app security scanner", "Enterprise mobile security platform",
    "Mobile threat defense system", "Mobile app store security scanner", "Mobile certificate authority",
    "Mobile endpoint detection system", "Mobile security testing platform", "Mobile device compliance checker",
    "Mobile app vulnerability scanner", "Mobile security monitoring tool", "Mobile device fingerprinting service",
    "Network security scanner", "Firewall management system", "VPN security monitor",
    "DNS security service", "BGP security monitoring tool", "Wireless security scanner",
    "Bluetooth security validator", "NFC security monitor", "Network traffic analyzer",
    "Protocol security scanner", "Network infrastructure monitor", "Security information management system"
]
# Verb phrases for safe events; each is inserted between an actor and a target.
ADVERSARIAL_SAFE_ACTIONS = [
    "successfully completed a vulnerability scan on", "performed a scheduled security certificate rotation for",
    "validated firewall integrity for", "completed a successful penetration test on",
    "updated system packages for", "deployed new features to", "analyzed resource usage for",
    "performed routine backup of", "upgraded database schema for", "configured new firewall rules for",
    "successfully scanned container images for vulnerabilities in", "validated Kubernetes RBAC policies for",
    "performed cloud security posture assessment on", "completed container runtime security scan on",
    "validated cloud IAM policies for", "performed serverless security assessment on",
    "completed microservice security audit for", "validated container registry security for",
    "performed cloud workload protection scan on", "completed infrastructure as code security review for",
    "validated cloud logging and monitoring for", "performed container orchestration security audit on",
    "successfully scanned code repository for secrets in", "validated CI/CD pipeline security for",
    "performed dependency vulnerability scan on", "completed infrastructure security validation for",
    "validated secret management configuration for", "performed build security scan on",
    "completed deployment security validation for", "validated artifact integrity for",
    "performed code quality security analysis on", "completed pipeline security audit for",
    "validated infrastructure compliance for", "performed DevSecOps security assessment on",
    "successfully scanned IoT devices for vulnerabilities in", "validated edge security configuration for",
    "performed industrial control system security audit on", "completed SCADA security assessment for",
    "validated smart city infrastructure security for", "performed medical device security scan on",
    "completed automotive system security validation for", "validated home automation security for",
    "performed sensor security assessment on", "completed edge gateway security audit for",
    "validated industrial protocol security for", "performed IoT device management security scan on",
    "successfully scanned mobile devices for threats in", "validated mobile app security for",
    "performed enterprise mobile security assessment on", "completed mobile device management validation for",
    "validated mobile banking security for", "performed mobile certificate security audit on",
    "completed BYOD policy compliance check for", "validated mobile endpoint security for",
    "performed mobile app store security scan on", "completed mobile device fingerprinting validation for",
    "validated mobile phishing protection for", "performed mobile security testing on",
    "successfully scanned network infrastructure for vulnerabilities in", "validated firewall configuration for",
    "performed VPN security assessment on", "completed DNS security validation for",
    "validated BGP routing security for", "performed wireless security scan on",
    "completed Bluetooth security assessment for", "validated NFC security for",
    "performed network traffic analysis on", "completed protocol security validation for",
    "validated network monitoring security for", "performed network forensics analysis on"
]
# Target phrases for safe events.
# NOTE: many entries (e.g. the whole mobile/network tail, "secret management vault",
# "medical device network") also appear verbatim in ADVERSARIAL_RISK_TARGETS. The target alone
# therefore does not determine the label; presumably intentional given the "adversarial"
# naming — verify.
ADVERSARIAL_SAFE_TARGETS = [
    "the production web server", "the user identity database", "all public-facing endpoints",
    "the disaster recovery environment", "the staging environment", "development workstations",
    "network perimeter devices", "DNS records", "application load balancer", "storage array",
    "Kubernetes cluster security posture", "Docker container security configuration", "AWS cloud infrastructure",
    "Azure cloud resources", "GCP cloud services", "container orchestration security",
    "serverless function security", "cloud API security", "microservice security architecture",
    "container registry security", "cloud logging security", "infrastructure as code repository",
    "Git repository security", "CI/CD pipeline security", "Docker image security",
    "artifact repository security", "infrastructure provisioning security", "secret management vault",
    "code signing certificate store", "dependency management system", "deployment automation platform",
    "build environment with elevated privileges", "CI/CD security scanning tools", "infrastructure monitoring system",
    "industrial control system security", "SCADA system security", "IoT device security",
    "edge computing security", "smart city infrastructure security", "medical device network",
    "automotive system security", "home automation security", "sensor security",
    "industrial protocol gateway", "edge security monitoring system", "IoT device firmware repository",
    "enterprise mobile device fleet", "mobile app store backend", "mobile device management system",
    "mobile banking infrastructure", "mobile certificate authority", "mobile security scanning service",
    "BYOD policy enforcement system", "mobile endpoint detection system", "mobile app security testing platform",
    "mobile device fingerprinting database", "mobile phishing detection system", "mobile app code signing service",
    "network segmentation firewall", "VPN concentrator", "DNS authoritative server",
    "BGP route reflector", "wireless access point controller", "network monitoring system",
    "traffic analysis platform", "network security scanning tool", "protocol analysis system",
    "network infrastructure management", "security information system", "network forensics platform"
]
# Outcome phrases for safe events (rendered after "Final status:").
ADVERSARIAL_SAFE_OUTCOMES = [
    "all tests passed, security posture confirmed", "the configuration was hardened as per policy",
    "the certificate was successfully renewed without downtime", "vulnerabilities were patched and verified",
    "system performance improved", "new functionality rolled out successfully",
    "resources optimized for cost efficiency", "backup completed without errors",
    "schema migration successful", "network policy updated",
    "container security scan completed successfully", "Kubernetes RBAC policies validated", "cloud security posture improved",
    "container runtime security verified", "cloud IAM policies hardened", "serverless security validated",
    "microservice security architecture confirmed", "container registry security verified", "cloud workload protection enabled",
    "infrastructure as code security validated", "cloud logging security confirmed", "container orchestration security verified",
    "code repository security scan passed", "CI/CD pipeline security validated", "container image security verified",
    "artifact repository security confirmed", "infrastructure security validated", "secret management security verified",
    "code signing security confirmed", "dependency security validated", "deployment security verified",
    "build security validated", "pipeline security confirmed", "infrastructure compliance verified",
    "IoT device security scan completed", "edge security configuration validated", "industrial control system security verified",
    "SCADA security assessment passed", "smart city infrastructure security confirmed", "medical device security validated",
    "automotive system security verified", "home automation security confirmed", "sensor security validated",
    "edge gateway security verified", "industrial protocol security confirmed", "IoT device management security validated",
    "mobile device security scan completed", "mobile app security validated", "enterprise mobile security confirmed",
    "mobile device management security verified", "mobile banking security validated", "mobile certificate security confirmed",
    "BYOD policy compliance verified", "mobile endpoint security validated", "mobile app store security confirmed",
    "mobile device fingerprinting security verified", "mobile phishing protection validated", "mobile security testing completed",
    "network security scan completed", "firewall security validated", "VPN security confirmed",
    "DNS security verified", "BGP routing security validated", "wireless security confirmed",
    "Bluetooth security validated", "NFC security confirmed", "network monitoring security verified",
    "traffic analysis security validated", "protocol security confirmed", "network infrastructure security verified"
]
# Keyword -> positive score adjustment for risk-indicating terms (all keys are lowercase).
# NOTE(review): not referenced by any function in this file; presumably consumed by a
# keyword-based scorer elsewhere (e.g. app.py) — verify before changing keys or weights.
HIGH_RISK_KEYWORDS = {
    'failed': 15, 'unauthorized': 20, 'invalid': 15, 'blocked': 25, 'mfa_failed': 30, 'brute_force': 40, 'attack': 40,
    'threat': 30, 'compromise': 30, 'malicious': 35, 'lockout': 25, 'critical': 20, 'urgent': 20, 'severe': 25,
    'breach': 40, 'exfiltration': 40, 'injection': 35, 'malware': 35, 'vulnerability': 25, 'exploit': 30,
    'lateral movement': 40, 'dns tunneling': 35, 'obfuscated': 25, 'anomaly': 20, 'misconfigured': 30,
    'ransomware': 50, 'phishing': 45, 'insider threat': 40, 'zero-day': 50, 'unauthorized access': 35, 'data integrity': 30,
    'compromised credential': 40, 'vulnerable library': 30, 'sql injection': 35, 'privilege escalation': 45
}
# Keyword -> negative score adjustment for benign-indicating terms (all keys are lowercase).
# NOTE(review): not referenced by any function in this file; presumably consumed elsewhere — verify.
# NOTE(review): if the consumer matches by substring, 'authorized' also matches inside
# 'unauthorized' (a HIGH_RISK_KEYWORDS key) — confirm the matching strategy.
LOW_RISK_KEYWORDS = {
    'success': -20, 'successful': -20, 'normal': -15, 'routine': -15, 'authorized': -10, 'benign': -15, 'secure': -10,
    'safe': -15, 'approved': -10, 'expected': -5, 'completed': -10,
    'scan completed': -25, 'validated': -15, 'patched': -20, 'renewed': -15, 'posture confirmed': -30,
    'performance improved': -10, 'functionality rolled out': -10, 'resources optimized': -15,
    'backup completed': -20, 'schema migration successful': -15, 'network policy updated': -10
}
# --- TRAINING data generation functions (solid baseline) ---
def generate_event_text_for_training(is_risk: bool, rng: "random.Random | None" = None) -> tuple[str, float]:
    """Build one synthetic audit-log event and its target risk score.

    Risky events combine a random actor/action/target/outcome from the
    ``ADVERSARIAL_RISK_*`` lists inside a template that mentions a "routine
    maintenance window". Safe events use the ``ADVERSARIAL_SAFE_*`` lists inside a
    template that starts with "CRITICAL alert resolved". Presumably this is done so
    that template words alone do not give the label away.

    The base score is drawn from [75, 95] (risk) or [5, 25] (safe), jittered by a
    uniform value in [-5, 5], and clipped to [0, 100].

    Args:
        is_risk: Generate a risky (True) or a safe (False) event.
        rng: Random source to draw from. ``None`` (the default) uses the
            process-global ``random`` module, exactly as before. Pass a seeded
            ``random.Random`` to get an independent, reproducible stream.

    Returns:
        ``(text, score)`` with ``score`` as a Python float.
    """
    if rng is None:
        rng = random
    if is_risk:
        actor = rng.choice(ADVERSARIAL_RISK_ACTORS)
        action = rng.choice(ADVERSARIAL_RISK_ACTIONS)
        target = rng.choice(ADVERSARIAL_RISK_TARGETS)
        outcome = rng.choice(ADVERSARIAL_RISK_OUTCOMES)
        text = f"Audit log: {actor} {action} {target} during a routine maintenance window. Status: {outcome}."
        base_score = rng.uniform(75, 95)
    else:
        actor = rng.choice(ADVERSARIAL_SAFE_ACTORS)
        action = rng.choice(ADVERSARIAL_SAFE_ACTIONS)
        target = rng.choice(ADVERSARIAL_SAFE_TARGETS)
        outcome = rng.choice(ADVERSARIAL_SAFE_OUTCOMES)
        text = f"CRITICAL alert resolved: {actor} {action} {target}. Final status: {outcome}."
        base_score = rng.uniform(5, 25)
    score = np.clip(base_score + rng.uniform(-5, 5), 0, 100)
    return text, float(score)


def generate_data_chunk_for_db_worker(args: tuple) -> list[tuple[str, float]]:
    """Process-pool worker: generate ``chunk_size`` events of a single class.

    Args:
        args: ``(chunk_size, is_risk)`` or ``(chunk_size, is_risk, seed)``.
            - With a seed, the chunk is drawn from its own ``random.Random(seed)``.
            - Without one, the process-global ``random`` state is used (the original
              behavior).

    Returns:
        A list of ``(text, score)`` tuples, in generation order.
    """
    chunk_size, is_risk, *extra = args
    seed = extra[0] if extra else None
    rng = random.Random(seed) if seed is not None else None
    return [generate_event_text_for_training(is_risk, rng) for _ in range(chunk_size)]


def populate_database_initial():
    """Generate the initial training events in parallel and insert them into the DB.

    Generates ``NUM_INITIAL_TRAIN_EVENTS_PER_CLASS`` risky events and the same
    number of safe events. Work is split into up to ``NUM_PROCESSES`` chunks per
    class.

    Every chunk gets its own seed derived from ``RANDOM_SEED``. Each pool worker
    starts with the same global ``random`` state: under 'spawn' the module is
    re-imported and ``random.seed(RANDOM_SEED)`` runs again. Without per-chunk
    seeds, workers generating the same class would therefore emit identical events.

    Results are collected in submission order, so the inserted rows are
    reproducible for a given ``NUM_PROCESSES``.

    Assumes the table ``TABLE_NAME`` already exists in ``DB_NAME``.
    """
    total_events = NUM_INITIAL_TRAIN_EVENTS_PER_CLASS * 2
    print(f"Populando o banco de dados com {total_events} eventos de treinamento iniciais em paralelo usando {NUM_PROCESSES} processos...")

    per_process, remainder = divmod(NUM_INITIAL_TRAIN_EVENTS_PER_CLASS, NUM_PROCESSES)
    tasks = []
    for i in range(NUM_PROCESSES):
        chunk_size = per_process + (1 if i < remainder else 0)
        if chunk_size == 0:
            continue
        for is_risk in (True, False):
            # Distinct, deterministic seed per chunk; offset by 1 so no chunk reuses RANDOM_SEED itself.
            tasks.append((chunk_size, is_risk, RANDOM_SEED + 1 + len(tasks)))

    all_data = []
    with ProcessPoolExecutor(max_workers=NUM_PROCESSES) as executor:
        # executor.map yields results in submission order (as_completed would not).
        for chunk in executor.map(generate_data_chunk_for_db_worker, tasks):
            all_data.extend(chunk)

    print(f"Inserindo {len(all_data)} eventos no banco de dados '{DB_NAME}'...")
    conn = sqlite3.connect(DB_NAME)
    try:
        with conn:  # commits on success, rolls back if executemany raises
            conn.executemany(f"INSERT INTO {TABLE_NAME} (text, risk_score) VALUES (?, ?)", all_data)
    finally:
        conn.close()
    print("Banco de dados populado inicialmente com sucesso.")
# --- Embedding and training functions ---
def init_sbert_worker():
    """Initializer for the embedding worker processes.

    If this process has not loaded a model yet, load ``MODEL_NAME`` into the
    module-global ``model_base``. Then limit torch to one intra-op thread,
    presumably so parallel workers do not oversubscribe the CPU.

    Returns immediately when a model is already loaded.
    """
    global model_base
    if model_base is not None:
        return  # already loaded in this process
    print(f"Processo worker {os.getpid()} carregando o modelo {MODEL_NAME}...")
    model_base = SentenceTransformer(MODEL_NAME)
    torch.set_num_threads(1)
def extract_embeddings_batch_worker(texts: list[str]) -> list[list[float]]:
    """Encode a batch of texts with this worker's SentenceTransformer.

    Calls ``encode`` with ``convert_to_numpy=True`` and converts the resulting
    array to nested Python lists so it can be returned across processes.

    Raises:
        RuntimeError: if ``init_sbert_worker`` has not loaded the model in this process.
    """
    encoder = model_base
    if encoder is None:
        raise RuntimeError("SentenceTransformer não foi inicializado no worker.")
    vectors = encoder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return vectors.tolist()
def _extract_embeddings_in_order(texts: list[str]) -> np.ndarray:
    """Encode ``texts`` in worker processes and return the embeddings in input order.

    The texts are split into at most ``NUM_PROCESSES`` contiguous chunks. Each
    worker runs ``init_sbert_worker`` once and encodes its chunk with
    ``extract_embeddings_batch_worker``. Row ``i`` of the result corresponds to
    ``texts[i]``.
    """
    if not texts:
        return np.array([])
    chunk_size = math.ceil(len(texts) / NUM_PROCESSES)
    chunks = [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]
    rows = []
    with ProcessPoolExecutor(max_workers=NUM_PROCESSES, initializer=init_sbert_worker) as executor:
        # executor.map yields chunk results in submission order. as_completed would
        # yield them in completion order and break the text/label pairing.
        for chunk_rows in executor.map(extract_embeddings_batch_worker, chunks):
            rows.extend(chunk_rows)
    return np.array(rows)


def train_and_save_all_models():
    """Train both risk-scoring heads from the training table and save them with joblib.

    Head 1: SentenceTransformer embeddings (computed in worker processes), then
    StandardScaler, then MLPRegressor.
    Head 2: TF-IDF over 1-3 grams, then Ridge.

    The fitted objects are stored in the module globals ``mlp_regressor``,
    ``scaler``, ``tfidf_vectorizer`` and ``tfidf_regressor``. They are also written
    to ``*.joblib`` files in the current working directory.

    Raises:
        RuntimeError: if the table is empty, no embeddings were produced, or the
            number of embeddings differs from the number of labels.
    """
    global mlp_regressor, scaler, tfidf_vectorizer, tfidf_regressor  # model_base belongs to the workers, not the main process
    print("Iniciando o treinamento de todos os modelos a partir do banco de dados...")
    conn = sqlite3.connect(DB_NAME)
    try:
        train_data = conn.execute(f"SELECT text, risk_score FROM {TABLE_NAME}").fetchall()
    finally:
        conn.close()
    if not train_data:
        print("ERRO: Banco de dados de treinamento vazio. Não é possível treinar modelos.")
        raise RuntimeError("Banco de dados de treinamento vazio.")
    random.shuffle(train_data)
    train_texts = [row[0] for row in train_data]
    y_train = np.array([row[1] for row in train_data])

    # --- Head 1: deep embedding (MLPRegressor) ---
    print("1. Treinando modelo de Embedding Profundo (MLPRegressor)...")
    X_train_embeddings = _extract_embeddings_in_order(train_texts)
    if X_train_embeddings.shape[0] == 0:
        print("ERRO: Nenhum embedding extraído. Verifique os dados de treinamento ou o modelo.")
        raise RuntimeError("Nenhum embedding extraído para treinamento.")
    if X_train_embeddings.shape[0] != y_train.shape[0]:
        # Guard against any future change that breaks the one-embedding-per-text pairing.
        raise RuntimeError(
            f"Número de embeddings ({X_train_embeddings.shape[0]}) difere do número de rótulos ({y_train.shape[0]})."
        )
    scaler = StandardScaler()
    X_train_embeddings_scaled = scaler.fit_transform(X_train_embeddings)
    mlp_regressor = MLPRegressor(
        hidden_layer_sizes=(768, 384, 192),
        activation='relu',
        solver='adam',
        max_iter=500,
        random_state=RANDOM_SEED,
        early_stopping=True,
        n_iter_no_change=30,
        alpha=0.005,
        learning_rate_init=0.0005,
        batch_size=256,
        verbose=False
    )
    print("Treinando a rede neural (MLP)... Isso pode levar alguns minutos.")
    mlp_regressor.fit(X_train_embeddings_scaled, y_train)
    print(" ... modelo de Embedding Profundo treinado.")

    # --- Head 2: classic vector model (TF-IDF) ---
    print("2. Treinando modelo Vetorial Clássico (TF-IDF + Ridge)...")
    tfidf_vectorizer = TfidfVectorizer(ngram_range=(1, 3), min_df=5, max_df=0.7, max_features=10000)
    X_train_tfidf = tfidf_vectorizer.fit_transform(train_texts)
    tfidf_regressor = Ridge(alpha=1.0, random_state=RANDOM_SEED)
    tfidf_regressor.fit(X_train_tfidf, y_train)
    print(" ... modelo TF-IDF treinado.")
    print("Todos os modelos de risco foram treinados com sucesso!")

    # Save the trained models (relative paths, i.e. the current working directory).
    print("Salvando modelos treinados...")
    joblib.dump(mlp_regressor, "mlp_regressor.joblib")
    joblib.dump(scaler, "scaler.joblib")
    joblib.dump(tfidf_vectorizer, "tfidf_vectorizer.joblib")
    joblib.dump(tfidf_regressor, "tfidf_regressor.joblib")
    print("Modelos salvos com sucesso.")
def _training_table_has_rows(db_path: Path) -> bool:
    """Return True if ``db_path`` is a SQLite DB whose ``TABLE_NAME`` table has at least one row.

    Each of the following counts as "no usable training data":
    - a missing file;
    - a zero-byte file;
    - a missing table;
    - a file SQLite cannot read as a database.
    """
    if not db_path.exists() or db_path.stat().st_size == 0:
        return False
    try:
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute(f"SELECT 1 FROM {TABLE_NAME} LIMIT 1").fetchone() is not None
        finally:
            conn.close()
    except sqlite3.DatabaseError:
        # OperationalError ("no such table") and "file is not a database" both derive from DatabaseError.
        return False


if __name__ == "__main__":
    import sys

    print("Iniciando pré-carregamento e treinamento de modelos para salvar...")
    db_path = Path(DB_NAME)
    # Check for actual rows, not only file size. A DB whose table exists but is empty
    # (e.g. after an interrupted population) would otherwise skip population and make
    # training fail on every run.
    if not _training_table_has_rows(db_path):
        print(f"Banco de dados de treinamento '{DB_NAME}' não encontrado ou vazio. Criando e populando inicialmente...")
        if db_path.exists():
            os.remove(db_path)
        conn = sqlite3.connect(DB_NAME)
        try:
            with conn:
                conn.execute(f'CREATE TABLE IF NOT EXISTS {TABLE_NAME} (id INTEGER PRIMARY KEY, text TEXT NOT NULL, risk_score REAL NOT NULL)')
        finally:
            conn.close()
        populate_database_initial()
    else:
        print(f"Banco de dados de treinamento '{DB_NAME}' encontrado. Pulando a geração de dados inicial.")
    try:
        train_and_save_all_models()
        print("Processo de treinamento e salvamento concluído com sucesso.")
    except RuntimeError as e:
        print(f"ERRO FATAL: Falha ao treinar ou carregar modelos. Erro: {e}")
        # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
        sys.exit(1)