antoniaebner's picture
add code
9af3c0c
raw
history blame
3.68 kB
"""
This files includes a XGBoost model for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Literal
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from .utils import TASKS
# ---------------------------------------------------------------------------------------
@dataclass
class SNNConfig:
    """Architecture hyperparameters for ``Tox21SNNClassifier``.

    Fields are declared (and therefore accepted positionally) in the order:
    hidden_dim, n_layers, dropout, layer_form, in_features, out_features.
    """

    # Base hidden width: every hidden layer for "rect"; for "conic" the
    # hidden widths are interpolated geometrically from this towards
    # out_features.
    hidden_dim: int
    # Number of hidden blocks (Linear + SELU + AlphaDropout); the network
    # contains n_layers + 1 linear layers in total.
    n_layers: int
    # Drop probability passed to nn.AlphaDropout.
    dropout: float
    # Layer-width schedule: "conic" (geometric taper) or "rect" (constant).
    layer_form: Literal["conic", "rect"]
    # Size of each input feature vector.
    in_features: int
    # Number of network outputs.
    out_features: int
class Tox21SNNClassifier(nn.Module):
    """Self-normalizing neural network (SNN) classifier for the Tox21 tasks.

    The network is an ``nn.Sequential`` of ``config.n_layers + 1`` linear
    layers. Every linear layer except the last is followed by a SELU
    activation and alpha dropout. ``forward`` returns raw logits and
    ``predict`` turns them into probabilities with a sigmoid.
    """

    def __init__(self, config: "SNNConfig"):
        """Build the network described by ``config``.

        Args:
            config (SNNConfig): layer count, hidden width, layer shape
                ("conic" or "rect"), dropout rate and input/output sizes.

        Raises:
            ValueError: if ``config.n_layers`` is smaller than 1.
        """
        super().__init__()
        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        # SELU and AlphaDropout hold no learnable parameters, so a single
        # instance of each is shared by every hidden block.
        activation = nn.SELU()
        dropout = nn.AlphaDropout(p=config.dropout)

        n_hidden = self._layer_sizes(config)

        layers = []
        for idx in range(config.n_layers + 1):
            is_last = idx == config.n_layers
            # NOTE: for the last layer both sizes are n_hidden[n_layers]
            # (== out_features), so the network ends with an
            # out_features -> out_features projection. The layout is kept
            # unchanged so previously saved state dicts still load.
            fc = nn.Linear(
                in_features=n_hidden[idx],
                out_features=n_hidden[idx] if is_last else n_hidden[idx + 1],
            )
            layers.extend([fc] if is_last else [fc, activation, dropout])
        self.model = nn.Sequential(*layers)
        self.reset_parameters()

    @staticmethod
    def _layer_sizes(config) -> list:
        """Return the ``n_layers + 1`` layer widths described by ``config``.

        The first entry is ``in_features`` and the last is ``out_features``.
        For ``layer_form == "conic"`` the hidden widths start at
        ``hidden_dim`` and change geometrically towards ``out_features``.
        Any other ``layer_form`` gives a constant ``hidden_dim`` ("rect").

        Raises:
            ValueError: if ``config.n_layers`` is smaller than 1 (the
                first/last overwrites below would collide and drop
                ``in_features``).
        """
        if config.n_layers < 1:
            raise ValueError(
                f"n_layers must be at least 1, got {config.n_layers}"
            )
        if config.layer_form == "conic":
            # Per-layer factor such that hidden_dim * ratio**n_layers == out_features.
            ratio = np.power(
                config.out_features / config.hidden_dim, 1 / config.n_layers
            )
            # Exponents -1 .. n_layers-1; entries 0 and n_layers are
            # overwritten with the input/output sizes below.
            sizes = (
                config.hidden_dim * np.power(ratio, range(-1, config.n_layers))
            ).astype(int)
        else:
            sizes = [config.hidden_dim] * (config.n_layers + 1)
        sizes = [int(size) for size in sizes]
        sizes[0] = config.in_features
        sizes[config.n_layers] = config.out_features
        return sizes

    def reset_parameters(self):
        """Re-initialize every parameter of ``self.model``.

        Biases (1-D parameters) are set to zero. Weight matrices use LeCun
        normal initialization (std = 1 / sqrt(fan_in)), the scheme SELU
        networks rely on for self-normalization.
        """
        for param in self.model.parameters():
            if param.dim() == 1:
                nn.init.zeros_(param)
            else:
                # kaiming_normal_ with fan_in mode and the linear gain (1.0)
                # is exactly LeCun normal initialization.
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-sigmoid) logits for a batch of feature vectors."""
        return self.model(x)

    def load_model(self, path: str):
        """Load weights from a checkpoint and switch the module to eval mode.

        Args:
            path (str): checkpoint file holding a dict whose ``"model"``
                entry is this module's state dict.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # machines; load_state_dict copies the tensors onto whatever device
        # this module currently lives on.
        checkpoint = torch.load(path, map_location="cpu", weights_only=True)
        self.load_state_dict(checkpoint["model"])
        self.eval()

    @torch.no_grad()
    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict positive-class probabilities for every network output.

        Args:
            features (torch.Tensor): 2-D tensor of molecule features, one row
                per molecule.

        Returns:
            np.ndarray: sigmoid of the network outputs, one row per input row.
        """
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D torch.tensor. Current shape: {features.shape}"
        # Alpha dropout must be disabled for inference; the caller's
        # train/eval mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            logits = self.model(features)
        finally:
            self.train(was_training)
        return torch.sigmoid(logits).detach().cpu().numpy()