Spaces:
Sleeping
Sleeping
| """ | |
| This file includes a self-normalizing neural network (SNN) model for Tox21. | |
| As an input it takes a list of SMILES and it outputs a nested dictionary with | |
| SMILES and target names as keys. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| from typing import Literal | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from .utils import TASKS | |
| # --------------------------------------------------------------------------------------- | |
@dataclass
class SNNConfig:
    """Hyper-parameters describing the layer layout of the Tox21 SNN.

    Attributes:
        hidden_dim: width of the hidden layers ("rect") or the width the
            conic schedule starts from ("conic").
        n_layers: number of ``Linear -> SELU -> AlphaDropout`` blocks; the
            network holds ``n_layers + 1`` linear layers in total.
        dropout: drop probability handed to ``nn.AlphaDropout``.
        layer_form: "conic" shrinks the layer widths geometrically from
            ``hidden_dim`` towards ``out_features``; any other value keeps
            every hidden layer at ``hidden_dim``.
        in_features: size of the input feature vector.
        out_features: number of outputs (logits) produced by the network.
    """

    hidden_dim: int
    n_layers: int
    dropout: float
    layer_form: Literal["conic", "rect"]
    in_features: int
    out_features: int
class Tox21SNNClassifier(nn.Module):
    """Self-normalizing neural network (SNN) that scores molecule features.

    The network is a stack of ``Linear -> SELU -> AlphaDropout`` blocks
    followed by one final ``Linear`` layer that emits the raw logits.
    """

    def __init__(self, config: "SNNConfig"):
        """Build the layer stack described by ``config`` and initialize it.

        Args:
            config: depth, layer widths, dropout rate and layer form
                ("conic" or "rect") of the network.
        """
        super().__init__()
        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        # Both modules are parameter-free, so one instance can be reused in
        # every block.
        activation = nn.SELU()
        dropout = nn.AlphaDropout(p=config.dropout)

        if config.layer_form == "conic":
            # Geometric schedule: each width is the previous one times a
            # constant factor chosen so the widths move from hidden_dim
            # towards out_features.
            shrink = np.power(
                config.out_features / config.hidden_dim, 1 / config.n_layers
            )
            widths = config.hidden_dim * np.power(
                shrink, np.arange(-1, config.n_layers)
            )
            # Truncate to whole units and hand plain Python ints to nn.Linear.
            n_hidden = [int(w) for w in widths.astype(int)]
        else:
            n_hidden = [config.hidden_dim] * (config.n_layers + 1)

        # The first and last entries are dictated by the data, not the schedule.
        n_hidden[0] = config.in_features
        n_hidden[config.n_layers] = config.out_features

        # NOTE: the last hidden block already projects to out_features, so the
        # final layer is an out_features -> out_features Linear. The shapes are
        # left untouched on purpose; presumably existing checkpoints depend on
        # them (verify before changing).
        layers = []
        for idx in range(config.n_layers + 1):
            is_last = idx == config.n_layers
            fc = nn.Linear(
                in_features=n_hidden[idx],
                out_features=n_hidden[-1] if is_last else n_hidden[idx + 1],
            )
            if is_last:
                layers.append(fc)
            else:
                layers.extend([fc, activation, dropout])

        self.model = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Zero all biases and draw all weight matrices LeCun-normal."""
        for param in self.model.parameters():
            if param.dim() == 1:
                # 1-D parameters are the biases.
                nn.init.constant_(param, 0)
            else:
                # fan_in mode with a linear gain is LeCun-normal initialization.
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-sigmoid) outputs of the network for ``x``."""
        return self.model(x)

    def load_model(self, path: str):
        """Load weights from a checkpoint and switch the module to eval mode.

        Args:
            path: checkpoint file whose ``"model"`` entry holds the state dict.
        """
        # map_location="cpu" lets a checkpoint written on a GPU machine load on
        # a CPU-only host; load_state_dict copies the values onto whatever
        # device this module's parameters already live on.
        checkpoint = torch.load(path, map_location="cpu", weights_only=True)
        self.load_state_dict(checkpoint["model"])
        self.eval()

    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict probabilities from a batch of molecule features.

        The module's train/eval mode is not changed here; ``load_model``
        switches to eval mode, otherwise call ``eval()`` before predicting.

        Args:
            features: 2D tensor of molecule features, one row per molecule.

        Returns:
            np.ndarray: sigmoid of the network outputs, one row per molecule.
        """
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D torch.tensor. Current shape: {features.shape}"
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            probabilities = torch.sigmoid(self.model(features))
        return probabilities.cpu().numpy()