Spaces:
Sleeping
Sleeping
devjas1
(FEAT)[Enhanced Results Widget]: Integrate advanced probability breakdown, QC, and provenance export
fe030dd
| """ | |
| Core Training Engine for the POLYMEROS project. | |
| This module contains the primary logic for model training and validation, | |
| encapsulated in a reusable `TrainingEngine` class. It is designed to be | |
| called by different interfaces, such as the command-line script | |
| (train_model.py) and the web UI's TrainingManager. | |
| This approach ensures that the core training process is consistent, | |
| maintainable, and follows the DRY (Don't Repeat Yourself) principle. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from torch.utils.data import TensorDataset, DataLoader | |
| from sklearn.metrics import confusion_matrix, accuracy_score | |
| from .training_types import ( | |
| TrainingConfig, | |
| TrainingProgress, | |
| get_cv_splitter, | |
| augment_spectral_data, | |
| ) | |
| from models.registry import build as build_model | |
class TrainingEngine:
    """Encapsulates the core model training and validation logic.

    One engine is bound to one ``TrainingConfig``. Calling :meth:`run` performs
    a full cross-validation: for every fold a fresh model is built, trained for
    ``config.epochs`` epochs with Adam and cross-entropy loss, and evaluated on
    the held-out split.
    """

    def __init__(self, config: "TrainingConfig"):
        """
        Initializes the TrainingEngine with a given configuration.

        Args:
            config (TrainingConfig): The configuration object for the training run.
        """
        self.config = config
        self.device = self._get_device()

    def _get_device(self) -> torch.device:
        """Selects the compute device.

        Returns:
            torch.device: CUDA if ``config.device`` is ``"auto"`` and CUDA is
            available, CPU if ``"auto"`` and it is not; otherwise the device
            named by ``config.device``.
        """
        if self.config.device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(self.config.device)

    def _make_loader(self, features, labels, *, shuffle: bool) -> DataLoader:
        """Wraps features/labels in a DataLoader batched by ``config.batch_size``.

        Features are converted to ``float32`` tensors and labels to ``long``
        tensors, as required by ``nn.CrossEntropyLoss``.
        """
        dataset = TensorDataset(
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )
        return DataLoader(
            dataset, batch_size=self.config.batch_size, shuffle=shuffle
        )

    def _train_epoch(self, model, loader, optimizer, criterion) -> float:
        """Runs one optimisation pass over ``loader``; returns the mean batch loss."""
        model.train()
        running_loss = 0.0
        for inputs, labels in loader:
            # unsqueeze(1) inserts a size-1 axis at position 1 before the
            # forward pass (presumably a channel axis for 1-D spectra -- confirm
            # against the models in models.registry).
            inputs = inputs.unsqueeze(1).to(self.device)
            labels = labels.to(self.device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        return running_loss / len(loader)

    def _evaluate(self, model, loader) -> tuple:
        """Predicts every sample in ``loader``.

        Returns:
            tuple: ``(all_true, all_pred)`` -- two lists of integer class
            labels in loader order.
        """
        model.eval()
        all_true, all_pred = [], []
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = model(inputs.unsqueeze(1).to(self.device))
                all_true.extend(labels.cpu().numpy())
                all_pred.extend(outputs.argmax(dim=1).cpu().numpy())
        return all_true, all_pred

    def run(
        self, X: np.ndarray, y: np.ndarray, progress_callback: callable = None
    ) -> dict:
        """
        Executes the full cross-validation training and evaluation loop.

        Args:
            X (np.ndarray): Feature data, indexed by sample along axis 0.
            y (np.ndarray): Label data (cast to integer class indices).
            progress_callback (callable, optional): A function called with a
                dict describing each progress event. The ``"type"`` key is one
                of ``"fold_start"``, ``"epoch_end"`` or ``"fold_end"``.
                Defaults to None.

        Returns:
            dict: A dictionary with the keys
                ``"fold_accuracies"`` (validation accuracy per fold),
                ``"confusion_matrices"`` (one nested list per fold; all folds
                share the same label order and therefore the same shape),
                ``"mean_accuracy"`` and ``"std_accuracy"`` (over the folds),
                ``"model_state_dict"`` (state dict of the model trained in the
                last fold, or None if the splitter produced no folds).
        """

        def emit(event: dict) -> None:
            if progress_callback:
                progress_callback(event)

        cv_splitter = get_cv_splitter(self.config.cv_strategy, self.config.num_folds)
        fold_accuracies = []
        fold_predictions = []  # (all_true, all_pred) for every fold
        final_model_state = None

        for fold, (train_idx, val_idx) in enumerate(cv_splitter.split(X, y), 1):
            emit(
                {
                    "type": "fold_start",
                    "fold": fold,
                    "total_folds": self.config.num_folds,
                }
            )

            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            # Augment the training split only, so the validation data stays
            # untouched.
            if self.config.enable_augmentation:
                X_train, y_train = augment_spectral_data(
                    X_train, y_train, noise_level=self.config.noise_level
                )

            train_loader = self._make_loader(X_train, y_train, shuffle=True)
            # Validation is batched as well: in eval mode the predictions do
            # not depend on the batch size, and a batch size of 1 (the
            # DataLoader default) would cost one forward pass per sample.
            val_loader = self._make_loader(X_val, y_val, shuffle=False)

            # A fresh model and optimizer per fold, so folds stay independent.
            model = build_model(self.config.model_name, self.config.target_len).to(
                self.device
            )
            optimizer = torch.optim.Adam(
                model.parameters(), lr=self.config.learning_rate
            )
            criterion = nn.CrossEntropyLoss()

            for epoch in range(self.config.epochs):
                epoch_loss = self._train_epoch(
                    model, train_loader, optimizer, criterion
                )
                emit(
                    {
                        "type": "epoch_end",
                        "fold": fold,
                        "epoch": epoch + 1,
                        "total_epochs": self.config.epochs,
                        "loss": epoch_loss,
                    }
                )

            # Validation
            all_true, all_pred = self._evaluate(model, val_loader)
            acc = accuracy_score(all_true, all_pred)
            fold_accuracies.append(acc)
            fold_predictions.append((all_true, all_pred))
            final_model_state = model.state_dict()

            emit({"type": "fold_end", "fold": fold, "accuracy": acc})

        # Build every confusion matrix over the same label set. Without an
        # explicit `labels=`, sklearn only uses the classes seen in that
        # particular fold, so a fold missing a class would yield a smaller
        # matrix that cannot be compared or summed with the others.
        class_labels = sorted(
            {
                int(label)
                for true, pred in fold_predictions
                for label in (*true, *pred)
            }
        )
        all_conf_matrices = [
            confusion_matrix(true, pred, labels=class_labels).tolist()
            for true, pred in fold_predictions
        ]

        return {
            "fold_accuracies": fold_accuracies,
            "confusion_matrices": all_conf_matrices,
            "mean_accuracy": np.mean(fold_accuracies),
            "std_accuracy": np.std(fold_accuracies),
            "model_state_dict": final_model_state,
        }