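"""Train a GIN graph classifier with rolling-mean (smoothed) validation ROC-AUC
checkpointing and early stopping. Configuration is read from config/config.json."""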
import torch
from torch_geometric.loader import DataLoader
import torch_geometric
import numpy as np
import json
import os
from dotenv import load_dotenv

from src.model import GIN
from src.preprocess import get_graph_datasets
from src.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
from src.seed import set_seed

def train(config):
    SEED = config["seed"]
    set_seed(SEED)
    best_model_path = "./checkpoints/model.pt"
    # make sure the checkpoint directory exists before the first torch.save
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    # get dataloaders
    print("Loading Datasets...")
    torch_geometric.seed_everything(SEED)
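    # the dataset access token is expected in the environment (loaded from .env via python-dotenv in __main__)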
    token = os.getenv("TOKEN")
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

    # initialize
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=config["dropout"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        add_or_mean=config["add_or_mean"],
    ).to(device)
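    # num_features=9 / num_classes=12 are fixed to the dataset's node-feature size and the
    # 12 target classes; hidden_dim, num_layers, dropout and the add/mean readout come from config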
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # training loop
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []
    window_size = config["window_size"]
    epoch_checkpoints = {}
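    # Checkpoint-smoothing scheme: per-epoch validation AUCs are averaged over a sliding
    # window of `window_size` epochs, and the model saved to disk is the one from the
    # *center* epoch of the best-scoring window. Recent state dicts are therefore cached
    # (on CPU) until we know whether their window turns out to be the best.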
| print("Starting Training...") | |
    for epoch in range(config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
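        # only the averaged AUC drives checkpointing and early stopping;
        # val_auc_per_class is kept for per-class inspection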
        aucs.append(val_auc_avg)

        # log
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val ROC-AUC: {val_auc_avg:.4f}")

        # store model parameters for this epoch in cache (on CPU to save GPU memory);
        # clone so the cached tensors never alias the live parameters when training on CPU
        epoch_checkpoints[epoch] = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        # keep cache size limited
        if len(epoch_checkpoints) > window_size + 2:
            oldest = min(epoch_checkpoints.keys())
            del epoch_checkpoints[oldest]

        # once we have enough epochs, compute rolling mean
        if len(aucs) >= window_size:
            current_window = aucs[-window_size:]
            current_mean_auc = np.mean(current_window)
            middle_epoch = epoch - window_size // 2
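            # the window covers epochs [epoch - window_size + 1, epoch], so its center
            # (for odd window_size) is epoch - window_size // 2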

            # check if current mean beats the best so far
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch
                # save only the best middle model
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) - model saved!")

        # early stopping based on best mean epoch
        if epoch - best_mean_epoch >= config["patience"]:
            print(f"Early stopping at epoch {epoch}. "
                  f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
            break

    print(f"best_smoothed_val_auc: {best_mean_auc:.4f}, best_middle_epoch: {best_mean_epoch}")

if __name__ == "__main__":
    with open("./config/config.json", "r") as f:
        config = json.load(f)
    load_dotenv()
    train(config)
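
# Illustrative config/config.json covering the keys this script reads; the values below are
# example assumptions, not the settings used for the deployed model:
#
# {
#   "seed": 42,
#   "batch_size": 32,
#   "lr": 0.001,
#   "dropout": 0.2,
#   "hidden_dim": 64,
#   "num_layers": 3,
#   "add_or_mean": "mean",
#   "window_size": 5,
#   "max_epochs": 300,
#   "patience": 30
# }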