tox21_gin_classifier/utils/train_evaluate.py
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_auc_score

def masked_bce_loss(logits, labels, mask):
    """
    BCE loss computed only over entries whose label is actually present.

    logits: [batch_size, num_classes] (raw model outputs)
    labels: [batch_size, num_classes] (0/1, with filler values at missing entries)
    mask:   [batch_size, num_classes] (True if the label is valid)
    """
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_raw = criterion(logits, labels)
    # Zero out missing labels and average over the valid entries only.
    # Assumes every batch contains at least one valid label (mask.sum() > 0).
    loss = (loss_raw * mask.float()).sum() / mask.float().sum()
    return loss

def train_model(model, loader, optimizer, device):
    """Run one training epoch and return the mean masked BCE loss per graph."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs, num_classes]
        loss = masked_bce_loss(out, batch.y, batch.mask)
        loss.backward()
        optimizer.step()
        # Weight by the number of graphs in the batch so the epoch average is per graph.
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the mean masked BCE loss per graph over the loader's dataset."""
    model.eval()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = masked_bce_loss(out, batch.y, batch.mask)
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Return the mean ROC-AUC over labels that are valid and contain both classes."""
    model.eval()
    y_true, y_pred, y_mask = [], [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        # Store predictions (sigmoid -> probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())
    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()
    auc_list = []
    for i in range(y_true.shape[1]):  # per label
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:  # at least one valid label
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
                auc_list.append(auc)
            except ValueError:
                # Only one class present (all 0 or all 1) -> AUC undefined, skip this label.
                pass
    return np.mean(auc_list) if len(auc_list) > 0 else float("nan")

@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Return the per-class AUC array (NaN where undefined) and its NaN-ignoring mean."""
    model.eval()
    y_true, y_pred, y_mask = [], [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        # Store predictions (sigmoid -> probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())
    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()
    # Compute AUC per class
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # only one class present for this label
        else:
            auc = np.nan
        auc_list.append(auc)
    # Convert to a numpy array for easier manipulation
    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs
    # Return both per-class and mean AUC
    return auc_array, mean_auc
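

# --- Minimal self-check: a sketch only, not part of the training pipeline. ---
# It assumes nothing beyond the imports above; the shapes (8 graphs, 12 tasks)
# and the random "mask" are hypothetical stand-ins for a Tox21 batch, where
# True marks a measured label, mirroring the convention in masked_bce_loss.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_graphs, num_classes = 8, 12
    logits = torch.randn(num_graphs, num_classes)
    labels = torch.randint(0, 2, (num_graphs, num_classes)).float()
    mask = torch.rand(num_graphs, num_classes) > 0.3  # ~70% of labels "measured"

    loss = masked_bce_loss(logits, labels, mask)
    print(f"masked BCE over {int(mask.sum())} valid labels: {loss.item():.4f}")

    # Per-class AUC on the same synthetic data, mirroring compute_roc_auc_avg_and_per_class.
    probs = torch.sigmoid(logits).numpy()
    y = labels.numpy()
    m = mask.numpy().astype(bool)
    aucs = []
    for i in range(num_classes):
        try:
            aucs.append(roc_auc_score(y[m[:, i], i], probs[m[:, i], i]))
        except ValueError:
            aucs.append(np.nan)  # AUC undefined: one class only (or no valid labels)
    print("mean AUC over defined tasks:", np.nanmean(aucs))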