"""
This file includes a self-normalizing neural network (SNN) model for Tox21.

The classifier defined here consumes a 2D tensor of precomputed molecule
features and outputs one probability per output unit. Turning SMILES strings
into features, and mapping predictions back to SMILES/target names, is not done
in this file.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Literal
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn

from .utils import TASKS


# ---------------------------------------------------------------------------------------
@dataclass
class SNNConfig:
    """Hyperparameters describing the architecture of ``Tox21SNNClassifier``."""

    # Width of the first hidden layer (and of every hidden layer for "rect").
    hidden_dim: int
    # Number of (Linear -> SELU -> AlphaDropout) blocks placed before the final Linear.
    n_layers: int
    # Drop probability passed to nn.AlphaDropout.
    dropout: float
    # "conic": widths shrink geometrically from hidden_dim towards out_features;
    # any other value (nominally "rect"): all hidden widths equal hidden_dim.
    layer_form: Literal["conic", "rect"]
    # Size of the input feature vector.
    in_features: int
    # Number of output units (logits) produced by the network.
    out_features: int


class Tox21SNNClassifier(nn.Module):
    """An SNN classifier that maps molecule feature vectors to toxicity scores."""

    def __init__(self, config: SNNConfig):
        """Build the network described by ``config`` and initialize its weights.

        The resulting ``self.model`` is an ``nn.Sequential`` made of
        ``config.n_layers`` blocks of (Linear, SELU, AlphaDropout) followed by a
        final Linear layer of shape ``out_features -> out_features``.

        Args:
            config (SNNConfig): architecture hyperparameters.
        """
        super().__init__()

        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        # SELU and AlphaDropout hold no parameters, so a single instance of each
        # is reused in every block.
        activation = nn.SELU()
        dropout = nn.AlphaDropout(p=config.dropout)

        if config.layer_form == "conic":
            # Geometric progression hidden_dim * r**k for k = -1 .. n_layers - 1,
            # with r chosen so that r**n_layers == out_features / hidden_dim.
            # The first and last entries are overwritten below.
            ratio = np.power(
                config.out_features / config.hidden_dim, 1 / config.n_layers
            )
            widths = (
                (config.hidden_dim * np.power(ratio, range(-1, config.n_layers)))
                .astype(int)
                .tolist()  # plain Python ints instead of np.int64 for nn.Linear
            )
        else:
            widths = [config.hidden_dim] * (config.n_layers + 1)

        widths[0] = config.in_features
        widths[config.n_layers] = config.out_features

        layers = []
        for idx in range(config.n_layers):
            layers.extend(
                [nn.Linear(widths[idx], widths[idx + 1]), activation, dropout]
            )
        # Final layer: no activation/dropout. Its input and output width are both
        # widths[n_layers]; this layout must stay as is to keep saved state dicts
        # loadable.
        layers.append(nn.Linear(widths[config.n_layers], widths[config.n_layers]))

        self.model = nn.Sequential(*layers)
        self.config = config

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all parameters: zero biases, LeCun-normal weights."""
        for param in self.model.parameters():
            if param.dim() == 1:
                # 1D parameters are the Linear biases.
                nn.init.constant_(param, 0)
            else:
                # fan_in mode with a "linear" gain yields std = 1 / sqrt(fan_in),
                # i.e. LeCun-normal initialization.
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x) -> torch.Tensor:
        """Return the raw (pre-sigmoid) network outputs for ``x``."""
        return self.model(x)

    def load_model(self, path: str):
        """Load weights from a checkpoint file and switch the module to eval mode.

        The checkpoint is expected to be a mapping whose ``"model"`` entry holds
        the state dict for this module.

        Args:
            path (str): path to the checkpoint file.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from trusted
        # sources. It is left unchanged here because the checkpoint may contain
        # non-tensor objects next to the "model" entry.
        checkpoint = torch.load(
            path, weights_only=False, map_location=torch.device("cpu")
        )
        self.load_state_dict(checkpoint["model"])
        self.eval()

    @torch.no_grad()
    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict probabilities from molecule features.

        The network is evaluated in eval mode so that AlphaDropout is disabled
        and predictions are deterministic; the previous train/eval mode is
        restored afterwards.

        Args:
            features (torch.Tensor): 2D tensor of molecule features, one row per
                molecule.

        Returns:
            np.ndarray: sigmoid of the network outputs, one row per input row.
        """
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D torch.tensor. Current shape: {features.shape}"

        was_training = self.training
        self.eval()
        try:
            logits = self.model(features)
        finally:
            self.train(was_training)

        return torch.sigmoid(logits).detach().cpu().numpy()