import torch
from torch_geometric.loader import DataLoader
import torch_geometric
import numpy as np
import json
import os
from dotenv import load_dotenv
from src.model import GIN
from src.preprocess import get_graph_datasets
from src.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
from src.seed import set_seed
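
# Training entry point: load the graph datasets, train a GIN classifier, and keep
# the checkpoint whose epoch sits at the center of the best rolling-window mean of
# validation ROC-AUC, with early stopping once that center epoch stops improving.
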
def train(config):
    SEED = config["seed"]
    set_seed(SEED)
    best_model_path = "./checkpoints/model.pt"
    # make sure the checkpoint directory exists before any torch.save call
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    # get dataloaders
    print("Loading Datasets...")
    torch_geometric.seed_everything(SEED)
    token = os.getenv("TOKEN")
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
    # initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        add_or_mean=config["add_or_mean"],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
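
    # Model-selection scheme used below: keep a rolling window of the last
    # `window_size` validation ROC-AUC values, track the window with the best
    # mean, and save the checkpoint from the epoch at the *center* of that
    # window. Recent state_dicts are cached on CPU so the centered checkpoint
    # is still available once its window turns out to be the best.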
    # training loop
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []
    window_size = config["window_size"]
    epoch_checkpoints = {}
    print("Starting Training...")
    for epoch in range(0, config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
        aucs.append(val_auc_avg)

        # log
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val ROC-AUC: {val_auc_avg:.4f}")

        # store model parameters for this epoch in cache (on CPU to save GPU memory)
        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}

        # keep cache size limited
        if len(epoch_checkpoints) > window_size + 2:
            oldest = min(epoch_checkpoints.keys())
            del epoch_checkpoints[oldest]

        # once we have enough epochs, compute rolling mean
        if len(aucs) >= window_size:
            current_window = aucs[-window_size:]
            current_mean_auc = np.mean(current_window)
            middle_epoch = epoch - window_size // 2

            # check if current mean beats the best so far
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch

                # save only the best middle model
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"🟢 New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) — model saved!")

            # early stopping based on best mean epoch
            if epoch - best_mean_epoch >= config["patience"]:
                print(f"⛔ Early stopping at epoch {epoch}. "
                      f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
                break
print("best_smoothed_val_auc" + str(best_mean_auc) + ", best_middle_epoch" + str(best_mean_epoch))

if __name__ == "__main__":
    with open("./config/config.json", "r") as f:
        config = json.load(f)
    load_dotenv()
    train(config)
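
# For reference, ./config/config.json is expected to provide the keys read in
# train(). The example below is only an illustrative sketch (the values are made
# up, not the project's actual settings):
#
# {
#     "seed": 42,
#     "batch_size": 32,
#     "lr": 0.001,
#     "dropout": 0.5,
#     "hidden_dim": 64,
#     "num_layers": 5,
#     "add_or_mean": "add",
#     "window_size": 5,
#     "max_epochs": 300,
#     "patience": 50
# }
#
# A .env file defining TOKEN (read via os.getenv("TOKEN") and passed to
# get_graph_datasets) is also assumed.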