Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from sklearn.metrics import roc_auc_score | |
def masked_bce_loss(logits, labels, mask):
    """Binary cross-entropy averaged over the valid (unmasked) label entries.

    Args:
        logits: [batch_size, num_classes] raw (pre-sigmoid) model outputs.
        labels: [batch_size, num_classes] 0/1 targets; entries where ``mask``
            is False hold an arbitrary filler value (possibly NaN) and are
            ignored. Integer labels are accepted and cast to the logits dtype.
        mask: [batch_size, num_classes] True where the label is valid.

    Returns:
        Scalar tensor: mean BCE over valid entries, or 0 when no entry is valid.
    """
    mask = mask.bool()
    labels = labels.to(logits.dtype)
    # Replace filler targets before the loss so a NaN filler cannot leak into
    # the forward value or the gradient (NaN * 0 is still NaN).
    labels = torch.where(mask, labels, torch.zeros_like(labels))
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_raw = criterion(logits, labels)
    masked = torch.where(mask, loss_raw, torch.zeros_like(loss_raw))
    # Clamp so an all-invalid batch yields 0 instead of 0/0 = NaN.
    num_valid = mask.sum().clamp(min=1)
    return masked.sum() / num_valid
def train_model(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Each batch's masked BCE loss is weighted by ``num_graphs`` for that batch,
    and the accumulated sum is divided by ``len(loader.dataset)``.
    """
    model.train()
    running = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        # One logit row per graph: [num_graphs, num_classes]
        logits = model(data.x, data.edge_index, data.batch)
        batch_loss = masked_bce_loss(logits, data.y, data.mask)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item() * data.num_graphs
    n_graphs = len(loader.dataset)
    return running / n_graphs
def evaluate(model, loader, device):
    """Return the average masked BCE loss over ``loader`` without updating weights.

    Each batch's loss is weighted by ``batch.num_graphs`` and the sum is
    divided by ``len(loader.dataset)``, mirroring ``train_model``.
    """
    model.eval()
    total_loss = 0
    # Evaluation needs no autograd graph; building one per batch wastes
    # memory and time.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = masked_bce_loss(out, batch.y, batch.mask)
            total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)
def compute_roc_auc(model, loader, device):
    """Return the mean per-label ROC-AUC over the entries marked valid by ``mask``.

    Labels with no valid entries, or whose valid entries contain only one
    class (``roc_auc_score`` raises ValueError), are skipped. Returns nan if
    no label has a computable AUC or the loader yields no batches.
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []
    # no_grad is required here: with a trainable model, sigmoid outputs
    # require grad and Tensor.numpy() below would raise.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            # Store predictions (sigmoid -> probabilities)
            y_pred.append(torch.sigmoid(out).cpu())
            y_true.append(batch.y.cpu())
            y_mask.append(batch.mask.cpu())
    if not y_pred:
        # torch.cat([]) would raise; nothing to score.
        return float("nan")
    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()
    auc_list = []
    for i in range(y_true.shape[1]):  # per label
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:  # at least one valid label
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                # Only one class present (all 0 or all 1): AUC undefined.
                continue
            auc_list.append(auc)
    return np.mean(auc_list) if len(auc_list) > 0 else float("nan")
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Return per-label ROC-AUCs and their mean over the entries marked valid.

    Returns:
        (auc_array, mean_auc): ``auc_array`` is a float32 array with one AUC
        per label column, nan where the label has no valid entries or only
        one class among them; ``mean_auc`` is the nan-ignoring mean, or nan
        (without a RuntimeWarning) when every entry of ``auc_array`` is nan.
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            # Store predictions (sigmoid -> probabilities)
            y_pred.append(torch.sigmoid(out).cpu())
            y_true.append(batch.y.cpu())
            y_mask.append(batch.mask.cpu())
    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()
    # Compute AUC per class
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # in case only one class present
        else:
            auc = np.nan
        auc_list.append(auc)
    auc_array = np.array(auc_list, dtype=np.float32)
    # np.nanmean warns "Mean of empty slice" on an all-nan array; return nan
    # explicitly in that case instead.
    if np.any(~np.isnan(auc_array)):
        mean_auc = np.nanmean(auc_array)
    else:
        mean_auc = np.float32(np.nan)
    return auc_array, mean_auc