import torch
from torch_geometric.loader import DataLoader
import torch_geometric
import numpy as np
import json
import os
from dotenv import load_dotenv

from src.model import GIN
from src.preprocess import get_graph_datasets
from src.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
from src.seed import set_seed


def train(config):
    """Train a GIN graph classifier and checkpoint the best model.

    Model selection is smoothed: the validation ROC-AUC is averaged over a
    rolling window of ``config["window_size"]`` epochs, and the checkpoint
    saved is the one from the *center* epoch of the best-scoring window.
    Early stopping triggers when the best window center has not moved for
    ``config["patience"]`` epochs.

    Expected ``config`` keys: seed, batch_size, dropout, hidden_dim,
    num_layers, add_or_mean, lr, window_size, max_epochs, patience.
    """
    SEED = config["seed"]
    set_seed(SEED)

    best_model_path = "./checkpoints/model.pt"
    # Fix: create the checkpoint directory up front — torch.save raises
    # FileNotFoundError if it does not exist.
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    # --- data ---------------------------------------------------------------
    print("Loading Datasets...")
    torch_geometric.seed_everything(SEED)
    token = os.getenv("TOKEN")  # populated by load_dotenv() in __main__
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

    # --- model / optimizer --------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(num_features=9,
                num_classes=12,
                dropout=config["dropout"],
                hidden_dim=config["hidden_dim"],
                num_layers=config["num_layers"],
                add_or_mean=config["add_or_mean"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # --- training loop ------------------------------------------------------
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []                   # per-epoch average validation ROC-AUC history
    window_size = config["window_size"]
    epoch_checkpoints = {}      # epoch -> CPU state_dict (bounded cache)

    print("Starting Training...")
    for epoch in range(config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        # Per-class AUCs are unused here; only the average drives selection.
        _, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
        aucs.append(val_auc_avg)

        # Log every 10th epoch to keep the output readable.
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val ROC-AUC: {val_auc_avg:.4f}")

        # Cache this epoch's weights on CPU so the window's center epoch can
        # be saved later without keeping every snapshot on the GPU.
        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}
        # Keep the cache bounded: only the most recent window_size + 2 epochs
        # can ever be a window center.
        if len(epoch_checkpoints) > window_size + 2:
            del epoch_checkpoints[min(epoch_checkpoints)]

        # Once enough epochs have accumulated, evaluate the rolling mean.
        if len(aucs) >= window_size:
            current_mean_auc = np.mean(aucs[-window_size:])
            middle_epoch = epoch - window_size // 2  # center of current window

            # New best smoothed AUC: remember it and save the center model.
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"🟢 New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) — model saved!")

            # Early stopping keyed to the best window center, not the raw best
            # epoch, so the smoothing and the stopping criterion agree.
            if epoch - best_mean_epoch >= config["patience"]:
                print(f"⛔ Early stopping at epoch {epoch}. "
                      f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
                break

    # Fix: the original concatenated label and value with no separator
    # ("best_smoothed_val_auc0.83..."), making the summary unreadable.
    print(f"best_smoothed_val_auc={best_mean_auc}, best_middle_epoch={best_mean_epoch}")


if __name__ == "__main__":
    with open("./config/config.json", "r") as f:
        config = json.load(f)
    load_dotenv()  # load .env so os.getenv("TOKEN") works inside train()
    train(config)