File size: 3,677 Bytes
9af3c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
This file defines a self-normalizing neural network (SNN) classifier for Tox21.
The classifier takes a 2D tensor of precomputed molecule features and returns the
predicted probabilities for each output unit as a NumPy array.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from typing import Literal

from dataclasses import dataclass

import numpy as np

import torch
import torch.nn as nn

from .utils import TASKS


# ---------------------------------------------------------------------------------------
@dataclass()
class SNNConfig:
    """Architecture hyper-parameters for a Tox21 SNN classifier.

    Fields are positional in the order listed below.
    """

    # Width of the hidden layers (starting width when layer_form is "conic").
    hidden_dim: int
    # Number of hidden blocks (Linear + activation + dropout).
    n_layers: int
    # Dropout probability.
    dropout: float
    # "conic": geometrically shrinking widths; "rect": constant width.
    layer_form: Literal["conic", "rect"]
    # Size of each input feature vector.
    in_features: int
    # Number of output units.
    out_features: int


class Tox21SNNClassifier(nn.Module):
    """Self-normalizing neural network (SNN) classifier for the Tox21 tasks.

    The network is a plain stack of ``nn.Linear`` layers with SELU activations
    and AlphaDropout between them. Weights are LeCun-normal initialised and
    biases are zero. It maps a 2D batch of precomputed molecule features to raw
    logits (``forward``) or to sigmoid probabilities (``predict``).
    """

    def __init__(self, config: "SNNConfig"):
        """Build the layer stack described by ``config``.

        Args:
            config (SNNConfig): architecture hyper-parameters. With
                ``layer_form == "conic"`` the hidden widths shrink geometrically
                from ``hidden_dim`` towards ``out_features``. Any other value
                uses ``hidden_dim`` for every hidden layer.

        Note:
            The stack ends with a linear layer mapping
            ``out_features -> out_features``. This layout, and the module
            indices inside ``self.model``, are kept identical to the original
            implementation so its saved state_dicts still load.
        """
        super().__init__()

        # NOTE(review): TASKS is imported from .utils; presumably the list of
        # Tox21 target names — confirm.
        self.tasks = TASKS
        self.num_tasks = len(TASKS)

        # Both modules are parameter-free, so a single instance of each is
        # reused in every hidden block.
        activation = nn.SELU()
        dropout = nn.AlphaDropout(p=config.dropout)

        n_layers = config.n_layers
        if config.layer_form == "conic":
            # Widths are hidden_dim * ratio**k for k = -1, 0, ..., n_layers - 1.
            # Truncation via astype(int), rather than rounding, is kept so the
            # widths stay identical to the original implementation.
            # NOTE: n_layers == 0 raises ZeroDivisionError here.
            ratio = np.power(config.out_features / config.hidden_dim, 1 / n_layers)
            widths = (config.hidden_dim * np.power(ratio, range(-1, n_layers))).astype(int)
            # Plain Python ints for the nn.Linear size arguments.
            n_hidden = [int(width) for width in widths]
        else:
            n_hidden = [config.hidden_dim] * (n_layers + 1)

        n_hidden[0] = config.in_features
        n_hidden[n_layers] = config.out_features
        # Layer i maps sizes[i] -> sizes[i + 1]. Repeating the output size at
        # the end makes the final layer out_features -> out_features.
        sizes = n_hidden + [n_hidden[n_layers]]

        layers = []
        for idx in range(n_layers + 1):
            fc = nn.Linear(in_features=sizes[idx], out_features=sizes[idx + 1])
            if idx < n_layers:
                layers.extend([fc, activation, dropout])
            else:  # output layer: no activation / dropout
                layers.append(fc)

        self.model = nn.Sequential(*layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Zero every bias and LeCun-normal-initialise every weight matrix."""
        for param in self.model.parameters():
            if param.dim() == 1:
                # biases
                nn.init.constant_(param, 0)
            else:
                # fan_in mode with linear gain gives std = 1/sqrt(fan_in),
                # i.e. LeCun-normal, as recommended for SELU networks.
                nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the layer stack for the feature batch ``x``."""
        return self.model(x)

    def load_model(self, path: str, map_location=None):
        """Load weights from a checkpoint file and switch to evaluation mode.

        Args:
            path (str): checkpoint file. It must contain a dict whose
                ``"model"`` entry holds this module's ``state_dict``.
            map_location: forwarded to ``torch.load``. Pass ``"cpu"``, for
                example, to load a GPU-saved checkpoint on a CPU-only machine.
                ``None`` keeps the devices stored in the checkpoint (the
                previous behaviour).
        """
        # weights_only=True restricts unpickling to tensors and plain containers.
        checkpoint = torch.load(path, map_location=map_location, weights_only=True)
        self.load_state_dict(checkpoint["model"])
        self.eval()

    @torch.no_grad()
    def predict(self, features: torch.Tensor) -> np.ndarray:
        """Predict positive-class probabilities for a batch of molecules.

        Args:
            features (torch.Tensor): 2D tensor of molecule features, presumably
                shaped ``(n_molecules, in_features)``.

        Returns:
            np.ndarray: sigmoid of the model logits, one column per output unit.

        Raises:
            ValueError: if ``features`` is not 2-dimensional.

        Note:
            This method does not call ``eval()``. AlphaDropout is only inactive
            if the module is already in eval mode (``load_model`` sets it).
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(features.shape) != 2:
            raise ValueError(
                f"Function expects 2D torch.tensor. Current shape: {features.shape}"
            )

        logits = self.model(features)
        return torch.sigmoid(logits).cpu().numpy()