# Author: Sonja Topf β€” commit "README, utils to src folder" (1ce331f)
import torch
from torch_geometric.loader import DataLoader
import torch_geometric
import numpy as np
import json
import os
from dotenv import load_dotenv
from src.model import GIN
from src.preprocess import get_graph_datasets
from src.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
from src.seed import set_seed
def train(config):
    """Train a GIN model with rolling-mean (smoothed) ROC-AUC model selection.

    A rolling window of the last ``window_size`` validation AUCs is tracked;
    whenever the window mean improves, the checkpoint from the window's
    *center* epoch is saved. Training early-stops once the best center epoch
    is ``patience`` or more epochs old.

    Args:
        config: dict with keys "seed", "batch_size", "dropout", "hidden_dim",
            "num_layers", "add_or_mean", "lr", "window_size", "max_epochs",
            "patience".
    """
    seed = config["seed"]
    set_seed(seed)
    torch_geometric.seed_everything(seed)

    best_model_path = "./checkpoints/model.pt"
    # Make sure the checkpoint directory exists before torch.save is called.
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    # --- data -------------------------------------------------------------
    print("Loading Datasets...")
    # NOTE(review): TOKEN is presumably an access token for the dataset
    # source consumed by get_graph_datasets β€” confirm against src.preprocess.
    token = os.getenv("TOKEN")
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

    # --- model / optimizer ------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        add_or_mean=config["add_or_mean"],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # --- training loop ----------------------------------------------------
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []
    window_size = config["window_size"]
    # Cache of recent per-epoch state_dicts (kept on CPU to save GPU memory)
    # so the center-of-window checkpoint is still available when the rolling
    # mean improves a few epochs later.
    epoch_checkpoints = {}

    print("Starting Training...")
    for epoch in range(config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
        aucs.append(val_auc_avg)

        # Log every 10th epoch.
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val ROC-AUC: {val_auc_avg:.4f}")

        # Snapshot this epoch's parameters on CPU.
        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}
        # Bound the cache: only a window's worth (plus slack) is ever needed.
        if len(epoch_checkpoints) > window_size + 2:
            del epoch_checkpoints[min(epoch_checkpoints)]

        # Once a full window is available, compute the rolling mean AUC.
        if len(aucs) >= window_size:
            current_mean_auc = np.mean(aucs[-window_size:])
            middle_epoch = epoch - window_size // 2
            # Check if the current mean beats the best so far.
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch
                # Save only the model from the window's center epoch.
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"🟒 New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) β€” model saved!")

        # Early stopping based on the best window-center epoch.
        if epoch - best_mean_epoch >= config["patience"]:
            print(f"β›” Early stopping at epoch {epoch}. "
                  f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
            break

    print(f"best_smoothed_val_auc={best_mean_auc}, best_middle_epoch={best_mean_epoch}")
if __name__ == "__main__":
    # Load hyperparameters from the JSON config file.
    with open("./config/config.json", "r") as f:
        config = json.load(f)
    # Load environment variables (e.g. TOKEN) before training reads them.
    load_dotenv()
    train(config)