Spaces:
Sleeping
Sleeping
File size: 3,784 Bytes
9af3c0c c70bdf2 9af3c0c c70bdf2 9af3c0c c70bdf2 9af3c0c c70bdf2 9af3c0c 163605e 02c0b48 9af3c0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""
This file contains a self-normalizing neural network (SNN) model for Tox21.
As input it takes a 2D tensor of precomputed molecule features (presumably
derived from SMILES elsewhere) and it outputs predicted probabilities for each
Tox21 target.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Literal
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from .utils import TASKS
# ---------------------------------------------------------------------------------------
@dataclass
class SNNConfig:
    """Architecture hyperparameters for ``Tox21SNNClassifier``.

    The values are validated on construction so that an invalid configuration
    fails fast with a clear message instead of silently producing a
    mis-shaped network.

    Raises:
        ValueError: if ``layer_form`` is not "conic" or "rect", if
            ``n_layers`` is smaller than 1, or if any of ``hidden_dim``,
            ``in_features``, ``out_features`` is smaller than 1.
    """

    # Width of the hidden layers ("rect"), or the starting width of the
    # geometrically shrinking hidden layers ("conic").
    hidden_dim: int
    # Number of Linear + SELU + AlphaDropout blocks; the classifier builds
    # n_layers + 1 linear layers in total.
    n_layers: int
    # Drop probability passed to nn.AlphaDropout.
    dropout: float
    # Layer-width schedule: constant ("rect") or geometric ("conic").
    layer_form: Literal["conic", "rect"]
    # Dimensionality of the input feature vectors.
    in_features: int
    # Number of network outputs.
    out_features: int

    def __post_init__(self):
        """Reject configurations that would build a broken network."""
        if self.layer_form not in ("conic", "rect"):
            raise ValueError(
                f"layer_form must be 'conic' or 'rect', got {self.layer_form!r}"
            )
        # With n_layers == 0 the input width is overwritten by out_features
        # (and the conic schedule divides by zero).
        if self.n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {self.n_layers}")
        for name in ("hidden_dim", "in_features", "out_features"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value}")
class Tox21SNNClassifier(nn.Module):
    """Self-normalizing neural network (SNN) classifier for the Tox21 tasks.

    Maps a batch of precomputed molecule feature vectors to
    ``config.out_features`` logits; ``predict`` converts them to
    probabilities with a sigmoid.
    """

    def __init__(self, config: SNNConfig):
        """Build the network described by ``config``.

        The network has ``config.n_layers + 1`` linear layers; every layer
        except the last is followed by SELU and AlphaDropout.

        Args:
            config (SNNConfig): architecture hyperparameters. ``n_layers``
                should be at least 1 -- with 0 layers the input width would
                be overwritten by ``out_features``.
        """
        super().__init__()
        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        # Parameter-free modules, so one instance can be shared by all blocks.
        activation = nn.SELU()
        dropout = nn.AlphaDropout(p=config.dropout)

        sizes = self._layer_sizes(config)

        layers = []
        for i in range(config.n_layers):
            layers.extend([nn.Linear(sizes[i], sizes[i + 1]), activation, dropout])
        # The last hidden block already projects to out_features, so the final
        # layer is a square out_features -> out_features map. Keep this layout:
        # changing it renames/reshapes the parameters and breaks loading
        # existing state dicts in ``load_model``.
        last = sizes[config.n_layers]
        layers.append(nn.Linear(last, last))

        self.model = nn.Sequential(*layers)
        self.config = config
        self.reset_parameters()

    @staticmethod
    def _layer_sizes(config: SNNConfig) -> list:
        """Return the ``n_layers + 1`` layer widths (first = input, last = output)."""
        if config.layer_form == "conic":
            # Geometric progression from hidden_dim toward out_features:
            # entry k is hidden_dim * ratio ** (k - 1), truncated to int.
            ratio = np.power(
                config.out_features / config.hidden_dim, 1 / config.n_layers
            )
            widths = config.hidden_dim * np.power(ratio, range(-1, config.n_layers))
            sizes = [int(w) for w in widths.astype(int)]
        else:
            # Any layer_form other than "conic" yields constant-width layers.
            sizes = [config.hidden_dim] * (config.n_layers + 1)
        sizes[0] = config.in_features
        sizes[config.n_layers] = config.out_features
        return sizes

    def reset_parameters(self):
        """Zero all biases and LeCun-normal-initialize all weight matrices.

        ``kaiming_normal_`` with ``mode="fan_in"`` and the linear gain (1)
        draws weights with std = 1 / sqrt(fan_in), i.e. LeCun normal.
        """
        for param in self.model.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x) -> torch.Tensor:
        """Return the raw network outputs (logits) for the feature batch ``x``."""
        return self.model(x)

    def load_model(self, path: str):
        """Load weights from the checkpoint at ``path`` and switch to eval mode.

        The checkpoint is expected to be a dict holding the state dict under
        the ``"model"`` key; tensors are mapped to the CPU.

        NOTE(security): ``weights_only=False`` lets ``torch.load`` unpickle
        arbitrary Python objects, which can execute code. Only load trusted
        checkpoint files.
        """
        checkpoint = torch.load(
            path, weights_only=False, map_location=torch.device("cpu")
        )
        self.load_state_dict(checkpoint["model"])
        self.eval()

    @torch.no_grad()
    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict positive-class probabilities for a batch of molecules.

        Dropout is disabled for the duration of the call: the module is put
        in eval mode and its previous mode is restored afterwards, so the
        output is deterministic even if the model was in training mode.

        Args:
            features (torch.Tensor): 2D tensor of molecule features;
                presumably shape (n_molecules, in_features) -- confirm
                against the caller.

        Returns:
            np.ndarray: sigmoid of the network output, one probability per
            output unit.

        Raises:
            ValueError: if ``features`` is not 2D.
        """
        if len(features.shape) != 2:
            raise ValueError(
                f"Function expects 2D torch.tensor. Current shape: {features.shape}"
            )
        was_training = self.training
        self.eval()
        try:
            logits = self.model(features)
        finally:
            self.train(was_training)
        return torch.sigmoid(logits).detach().cpu().numpy()
|