File size: 3,784 Bytes
9af3c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c70bdf2
9af3c0c
 
c70bdf2
9af3c0c
 
c70bdf2
9af3c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c70bdf2
9af3c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163605e
 
 
02c0b48
9af3c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
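"""
This file defines a self-normalizing neural network (SNN) classifier for Tox21.
The classifier consumes precomputed molecule feature vectors (not raw SMILES)
and returns per-target toxicity probabilities; turning SMILES into features and
arranging results by SMILES and target name is presumably done by the caller.
"""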

# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Literal

from dataclasses import dataclass

import numpy as np

import torch
import torch.nn as nn

from .utils import TASKS


# ---------------------------------------------------------------------------------------
@dataclass
class SNNConfig:
    """Architecture hyper-parameters for a Tox21 SNN classifier.

    Field order is part of the interface: instances may be built positionally.
    """

    # Width of the first hidden layer (of every hidden layer when layer_form == "rect").
    hidden_dim: int
    # Number of Linear -> SELU -> AlphaDropout blocks, not counting the output Linear.
    n_layers: int
    # Drop probability handed to nn.AlphaDropout.
    dropout: float
    # "conic": hidden widths scale geometrically from hidden_dim toward out_features;
    # "rect": every hidden layer has hidden_dim units.
    layer_form: Literal["conic", "rect"]
    # Length of each input feature vector.
    in_features: int
    # Number of network outputs.
    out_features: int


class Tox21SNNClassifier(nn.Module):
    """A self-normalizing neural network (SNN) that scores molecules for Tox21 toxicity.

    The network consumes precomputed molecule feature vectors (not raw SMILES).
    ``forward`` returns raw outputs; ``predict`` applies a sigmoid to turn them
    into positive-class probabilities.
    """

    def __init__(self, config: "SNNConfig"):
        """Build the SNN described by ``config``.

        Args:
            config (SNNConfig): architecture hyper-parameters. ``n_layers`` must be
                at least 1 and ``layer_form`` must be ``"conic"`` or ``"rect"``.

        Raises:
            ValueError: if ``config.n_layers`` < 1 or ``config.layer_form`` is
                not one of the supported forms.
        """
        super().__init__()

        # With n_layers == 0 the width list has a single slot, so out_features
        # would silently overwrite in_features; reject it (and unknown forms,
        # which used to fall through to "rect") up front.
        if config.n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {config.n_layers}")
        if config.layer_form not in ("conic", "rect"):
            raise ValueError(
                f"layer_form must be 'conic' or 'rect', got {config.layer_form!r}"
            )

        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        widths = self._layer_widths(config)

        layers = []
        # One Linear -> SELU -> AlphaDropout block per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                [
                    nn.Linear(in_features=fan_in, out_features=fan_out),
                    nn.SELU(),
                    nn.AlphaDropout(p=config.dropout),
                ]
            )
        # NOTE(review): the last hidden block already projects to out_features,
        # so this output layer maps out_features -> out_features. It is kept so
        # the nn.Sequential layout (and saved state_dict keys/shapes) is unchanged.
        layers.append(
            nn.Linear(
                in_features=config.out_features, out_features=config.out_features
            )
        )

        self.model = nn.Sequential(*layers)
        self.config = config

        self.reset_parameters()

    @staticmethod
    def _layer_widths(config: "SNNConfig") -> list:
        """Return the ``n_layers + 1`` widths ``[in_features, ..., out_features]``."""
        n_layers = config.n_layers
        if config.layer_form == "conic":
            # Entry i is hidden_dim * (out_features / hidden_dim) ** ((i - 1) / n_layers),
            # truncated to int; entries 0 and n_layers are overwritten below.
            ratio = np.power(config.out_features / config.hidden_dim, 1 / n_layers)
            raw = (config.hidden_dim * np.power(ratio, range(-1, n_layers))).astype(int)
            widths = [int(w) for w in raw]
        else:
            widths = [config.hidden_dim] * (n_layers + 1)
        widths[0] = config.in_features
        widths[n_layers] = config.out_features
        return widths

    def reset_parameters(self):
        """Zero all biases and draw all weight matrices from a LeCun-normal distribution.

        ``kaiming_normal_`` with ``mode="fan_in"`` and ``nonlinearity="linear"``
        (gain 1) samples with std = 1 / sqrt(fan_in), i.e. LeCun-normal.
        """
        for param in self.model.parameters():
            if param.dim() == 1:  # biases
                nn.init.constant_(param, 0)
            else:  # weight matrices
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw network outputs (no sigmoid applied) for ``x``."""
        return self.model(x)

    def load_model(self, path: str):
        """Load weights from a checkpoint file and switch to evaluation mode.

        The checkpoint must be a mapping whose ``"model"`` entry holds this
        module's state dict; tensors are mapped to CPU while loading.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(
            path, weights_only=False, map_location=torch.device("cpu")
        )
        self.load_state_dict(checkpoint["model"])
        self.eval()

    @torch.no_grad()
    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict positive-class probabilities from molecule features.

        The module is put in eval mode for the call (so AlphaDropout is inactive
        and results are deterministic) and its previous train/eval mode is
        restored afterwards.

        Args:
            features (torch.Tensor): 2-D feature tensor whose last dimension is
                ``config.in_features``; presumably (n_molecules, in_features) —
                confirm against callers.

        Returns:
            np.ndarray: sigmoid of the network outputs, as a CPU array.
        """
        # NOTE: assert-based validation is stripped under `python -O`.
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D torch.tensor. Current shape: {features.shape}"

        was_training = self.training
        self.eval()
        try:
            return torch.sigmoid(self.model(features)).cpu().numpy()
        finally:
            self.train(was_training)