import torch
from torch_geometric.loader import DataLoader
import torch_geometric
import numpy as np
import json
import os
from dotenv import load_dotenv

from src.model import GIN
from src.preprocess import get_graph_datasets
from src.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
from src.seed import set_seed


def train(config):
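    """Train a GIN on the graph datasets and checkpoint the model whose epoch
    sits at the center of the best rolling-mean validation-AUC window; stop
    early once that window has not improved for `patience` epochs."""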
    SEED = config["seed"]
    set_seed(SEED)
    best_model_path = "./checkpoints/model.pt"
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)  # torch.save fails if the directory is missing

    # get dataloaders
    print("Loading Datasets...")
    torch_geometric.seed_everything(SEED)
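    # TOKEN comes from the environment; load_dotenv() in __main__ populates it from a .env file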
    token = os.getenv("TOKEN")
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=config["batch_size"])

    # initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        add_or_mean=config["add_or_mean"],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # training loop
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []
    window_size = config["window_size"]
    epoch_checkpoints = {}
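    # Selection scheme: cache recent epoch checkpoints and, whenever the rolling
    # mean of validation AUC over the last `window_size` epochs improves, save the
    # checkpoint from the center epoch of that window as the best model.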
    print("Starting Training...")

    for epoch in range(config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
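        # per-class AUCs are returned for inspection; only the average drives selection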
        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
        aucs.append(val_auc_avg)

        # log
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                f"Train Loss: {train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Val ROC-AUC: {val_auc_avg:.4f}")
        
        # store model parameters for this epoch in cache (on CPU to save GPU memory)
        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}

        # keep the checkpoint cache bounded: evict the oldest entry once it grows
        # past the window (plus a small margin so the center epoch survives)
        if len(epoch_checkpoints) > window_size + 2:
            oldest = min(epoch_checkpoints.keys())
            del epoch_checkpoints[oldest]

        # once we have enough epochs, compute the rolling mean over the last window
        if len(aucs) >= window_size:
            current_window = aucs[-window_size:]
            current_mean_auc = np.mean(current_window)
            middle_epoch = epoch - window_size // 2

            # check if the current mean beats the best so far
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch

                # save only the checkpoint at the center of the best window
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"🟢 New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) — model saved!")

        # early stopping: no new best window center for `patience` epochs
        if epoch - best_mean_epoch >= config["patience"]:
            print(f"⛔ Early stopping at epoch {epoch}. "
                  f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
            break

    print("best_smoothed_val_auc" + str(best_mean_auc) + ", best_middle_epoch" + str(best_mean_epoch))

if __name__ == "__main__":
    with open("./config/config.json", "r") as f:
        config = json.load(f)
    load_dotenv()
    train(config)