antoniaebner committed
Commit 9af3c0c · Parent: 506cb3a
Files changed (10)
  1. .gitignore +1 -0
  2. Dockerfile +16 -0
  3. README.md +93 -2
  4. app.py +78 -0
  5. predict.py +83 -0
  6. requirements.txt +10 -0
  7. src/__init__.py +0 -0
  8. src/model.py +122 -0
  9. src/preprocess.py +285 -0
  10. src/utils.py +444 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.11
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Tox21 Snn Classifier
+ title: Tox21 SNN Classifier
  emoji: 🌖
  colorFrom: green
  colorTo: pink
@@ -9,4 +9,95 @@ license: apache-2.0
  short_description: Self-Normalizing Neural Network Baseline for Tox21
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Tox21 SNN Classifier
+
+ This repository hosts a Hugging Face Space that provides an example API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/tschouis/tox21_leaderboard).
+
+ In this example, we train a self-normalizing neural network (SNN) classifier on the Tox21 targets and save the trained model in the `assets/` folder.
+
+ **Important:** For leaderboard submission, your Space does not need to include training code. It only needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a prediction dictionary with SMILES and targets as keys (see the sketch after this diff). Any preprocessing of SMILES strings must therefore be executed on the fly during inference.
+
+ # Repository Structure
+ - `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
+ - `app.py` - FastAPI application wrapper (can be used as-is).
+ - `src/` - Core model & preprocessing logic:
+   - `preprocess.py` - SMILES preprocessing pipeline
+   - `model.py` - SNN classifier wrapper
+   - `train.py` - Script to train the classifier
+   - `utils.py` - Constants and helper functions
+
+ # Quickstart with Spaces
+
+ You can easily adapt this project in your own Hugging Face account:
+
+ - Open this Space on Hugging Face.
+ - Click "Duplicate this Space" (top-right corner).
+ - Modify `src/` for your preprocessing pipeline and model class.
+ - Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.
+
+ That’s it: your model will be available as an API endpoint for the Tox21 Leaderboard.
+
+ # Installation
+ To run (and train) the model, clone the repository and install dependencies:
+
+ ```bash
+ git clone https://huggingface.co/spaces/tschouis/tox21_xgboost_classifier
+ cd tox21_xgboost_classifier
+
+ conda create -n tox21_snn_cls python=3.11
+ conda activate tox21_snn_cls
+ pip install -r requirements.txt
+ ```
+
+ # Training
+
+ To train the SNN model from scratch:
+
+ ```bash
+ python -m src.train
+ ```
+
+ This will:
+
+ 1. Load and preprocess the Tox21 training dataset.
+ 2. Train an SNN classifier.
+ 3. Save the trained model to the `assets/` folder.
+ 4. Evaluate the trained classifier on the validation split.
+
+ # Inference
+
+ For inference, you only need `predict.py`.
+
+ Example usage inside Python:
+
+ ```python
+ from predict import predict
+
+ smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
+ results = predict(smiles_list)
+
+ print(results)
+ ```
+
+ The output will be a nested dictionary mapping each SMILES string to the predicted probability for each of the twelve Tox21 targets:
+
+ ```python
+ {
+     "CCO": {"NR-AR": 0.02, "NR-AR-LBD": 0.11, ..., "SR-p53": 0.05},
+     "c1ccccc1": {"NR-AR": 0.87, "NR-AR-LBD": 0.34, ..., "SR-p53": 0.61},
+     "CC(=O)O": {"NR-AR": 0.01, "NR-AR-LBD": 0.03, ..., "SR-p53": 0.02}
+ }
+ ```
+
+ # Notes
+
+ - Only adapting `predict.py` for your model inference is required for leaderboard submission.
+ - Training (`src/train.py`) is provided for reproducibility.
+ - Preprocessing (here inside `src/preprocess.py`) must be applied at inference time, not only during training.
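The `predict()` contract described in the README above boils down to the following skeleton — a minimal sketch of the signature the leaderboard expects, with the model-specific steps left as placeholders (only `TASKS` is taken from this repo; everything else here is illustrative):

```python
# Minimal sketch of the predict() skeleton expected by the leaderboard.
# Input: list of SMILES strings; output: {smiles: {target: probability}}.
from src.utils import TASKS  # the 12 Tox21 target names


def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    predictions: dict[str, dict[str, float]] = {}
    for smiles in smiles_list:
        # 1. preprocess the SMILES string on the fly (cleaning, featurization, ...)
        # 2. run your model on the resulting features
        # 3. fall back to a default score if preprocessing fails
        score = 0.5  # placeholder; replace with your model's per-target output
        predictions[smiles] = {target: score for target in TASKS}
    return predictions
```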
app.py ADDED
@@ -0,0 +1,78 @@
+ """
+ This is the main entry point for the FastAPI application.
+ The app handles the request to predict toxicity for a list of SMILES strings.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies and global variable definition
+ import os
+ from typing import List, Dict, Optional
+ from fastapi import FastAPI, Header, HTTPException
+ from pydantic import BaseModel, Field
+
+ from predict import predict as predict_func
+
+ API_KEY = os.getenv("API_KEY")  # set via Space Secrets (not enforced in this example app)
+
+
+ # ---------------------------------------------------------------------------------------
+ class Request(BaseModel):
+     smiles: List[str] = Field(min_items=1, max_items=1000)
+
+
+ class Response(BaseModel):
+     predictions: dict
+     model_info: Dict[str, str] = {}
+
+
+ app = FastAPI(title="toxicity-api")
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Toxicity Prediction API",
+         "endpoints": {
+             "/metadata": "GET - API metadata and capabilities",
+             "/healthz": "GET - Health check",
+             "/predict": "POST - Predict toxicity for SMILES",
+         },
+         "usage": "Send POST to /predict with {'smiles': ['your_smiles_here']} and Authorization header",
+     }
+
+
+ @app.get("/metadata")
+ def metadata():
+     return {
+         "name": "SNN",
+         "version": "1.0.0",
+         "max_batch_size": 256,
+         "tox_endpoints": [
+             "NR-AR",
+             "NR-AR-LBD",
+             "NR-AhR",
+             "NR-Aromatase",
+             "NR-ER",
+             "NR-ER-LBD",
+             "NR-PPAR-gamma",
+             "SR-ARE",
+             "SR-ATAD5",
+             "SR-HSE",
+             "SR-MMP",
+             "SR-p53",
+         ],
+     }
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"ok": True}
+
+
+ @app.post("/predict", response_model=Response)
+ def predict(request: Request):
+     predictions = predict_func(request.smiles)
+     return {
+         "predictions": predictions,
+         "model_info": {"name": "snn_classifier", "version": "1.0.0"},
+     }
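Once the Space is running, the `/predict` endpoint can be exercised with a small client. A minimal sketch using `requests` — the Space URL is a placeholder, and the `Authorization` header only matters if you actually enforce the `API_KEY` secret:

```python
# Hypothetical client for the /predict endpoint; replace the URL with your Space's.
import requests

SPACE_URL = "https://<user>-<space>.hf.space"  # placeholder

response = requests.post(
    f"{SPACE_URL}/predict",
    json={"smiles": ["CCO", "c1ccccc1"]},
    headers={"Authorization": "Bearer <API_KEY>"},  # only if configured
    timeout=60,
)
response.raise_for_status()
print(response.json()["predictions"])
```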
predict.py ADDED
@@ -0,0 +1,83 @@
+ """
+ This file includes a predict function for Tox21.
+ It takes a list of SMILES as input and outputs a nested dictionary with
+ SMILES and target names as keys.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from collections import defaultdict
+
+ import numpy as np
+
+ import torch
+
+ from src.preprocess import create_descriptors
+ from src.model import Tox21SNNClassifier, SNNConfig
+ from src.utils import load_pickle
+
+ # ---------------------------------------------------------------------------------------
+
+
+ def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
+     """Applies the classifier to a list of SMILES strings. Returns prediction=0.5 for
+     any molecule that could not be cleaned.
+
+     Args:
+         smiles_list (list[str]): list of SMILES strings
+
+     Returns:
+         dict: nested prediction dictionary, following {'<smiles>': {'<target>': <pred>}}
+     """
+     print(f"Received {len(smiles_list)} SMILES strings")
+
+     # preprocessing pipeline
+     ecdfs_path = "assets/ecdfs.pkl"
+     scaler_path = "assets/scaler.pkl"
+     ecdfs = load_pickle(ecdfs_path)
+     scaler = load_pickle(scaler_path)
+     print(f"Loaded ecdfs from {ecdfs_path}")
+     print(f"Loaded scaler from {scaler_path}")
+
+     descriptors = ["rdkit_descr_quantiles", "tox"]
+     features, mol_mask = create_descriptors(
+         smiles_list,
+         ecdfs=ecdfs,
+         scaler=scaler,
+         descriptors=descriptors,
+     )
+     print(f"Created descriptors {descriptors} for molecules.")
+     print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")
+
+     # setup model
+     cfg = SNNConfig(
+         hidden_dim=1024,
+         n_layers=8,
+         dropout=0.05,
+         layer_form="conic",
+         in_features=features.shape[1],
+         out_features=12,
+     )
+
+     model = Tox21SNNClassifier(cfg)
+     model_path = "assets/snn_best.pth"
+     model.load_model(model_path)
+     model.eval()
+     print(f"Loaded model from {model_path}")
+
+     # make predictions
+     predictions = defaultdict(dict)
+     # map each input SMILES to the row index of its features
+     # (the index is only valid where mol_mask is True)
+     feat_indices = np.cumsum(mol_mask) - 1
+
+     mask = ~np.isnan(features).any(axis=1)
+     dataset = torch.utils.data.TensorDataset(torch.FloatTensor(features[mask]))
+     loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0)
+
+     with torch.no_grad():
+         preds = np.concatenate([model.predict(batch[0]) for batch in loader], axis=0)
+
+     for i, target in enumerate(model.tasks):
+         for smiles, is_clean, j in zip(smiles_list, mol_mask, feat_indices):
+             predictions[smiles][target] = float(preds[j, i]) if is_clean else 0.5
+     return predictions
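The realignment between the input list and the filtered feature matrix is the subtle step here: `np.cumsum(mol_mask) - 1` gives, for every input position, the row its features occupy after the removed molecules are dropped. A small standalone sketch of how it behaves:

```python
# Illustration of the cumsum index trick used above (not part of the pipeline).
import numpy as np

mol_mask = np.array([True, False, True, True])  # 2nd molecule failed cleaning
feat_indices = np.cumsum(mol_mask) - 1          # -> [0, 0, 1, 2]

# Where mol_mask is True, feat_indices points at the right row of the
# compacted feature matrix; where it is False the index is meaningless
# and the caller falls back to the default prediction of 0.5.
for pos, (is_clean, j) in enumerate(zip(mol_mask, feat_indices)):
    print(pos, is_clean, j if is_clean else None)
```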
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn[standard]
+ statsmodels
+ rdkit
+ numpy
+ scikit-learn==1.7.1
+ joblib
+ tabulate
+ datasets
+ torch==2.8.0
src/__init__.py ADDED
File without changes
src/model.py ADDED
@@ -0,0 +1,122 @@
+ """
+ This file includes a self-normalizing neural network (SNN) model for Tox21.
+ It takes molecule feature vectors as input and outputs a probability for each
+ of the twelve Tox21 targets.
+ """
+
+ # ---------------------------------------------------------------------------------------
+ # Dependencies
+ from typing import Literal
+
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from .utils import TASKS
+
+
+ # ---------------------------------------------------------------------------------------
+ @dataclass
+ class SNNConfig:
+     hidden_dim: int
+     n_layers: int
+     dropout: float
+     layer_form: Literal["conic", "rect"]
+     in_features: int
+     out_features: int
+
+
+ class Tox21SNNClassifier(nn.Module):
+     """An SNN classifier that assigns toxicity scores to a given SMILES string's features."""
+
+     def __init__(self, config: SNNConfig):
+         """Initialize an SNN classifier for the 12 Tox21 tasks.
+
+         Args:
+             config (SNNConfig): architecture hyperparameters (width, depth, dropout, ...).
+         """
+         super().__init__()
+
+         self.tasks = TASKS
+         self.num_tasks = len(TASKS)
+
+         activation = nn.SELU()
+         dropout = nn.AlphaDropout(p=config.dropout)
+
+         # layer widths: geometrically shrinking ("conic") or constant ("rect")
+         n_hidden = (
+             (
+                 config.hidden_dim
+                 * np.power(
+                     np.power(
+                         config.out_features / config.hidden_dim, 1 / (config.n_layers)
+                     ),
+                     range(-1, config.n_layers),
+                 )
+             ).astype(int)
+             if config.layer_form == "conic"
+             else [config.hidden_dim] * (config.n_layers + 1)
+         )
+
+         n_hidden[0] = config.in_features
+         n_hidden[config.n_layers] = config.out_features
+
+         layers = []
+         for l in range(config.n_layers + 1):
+             fc = nn.Linear(
+                 in_features=n_hidden[l],
+                 out_features=(
+                     n_hidden[config.n_layers]
+                     if l == config.n_layers
+                     else n_hidden[l + 1]
+                 ),
+             )
+             if l < config.n_layers:
+                 block = [
+                     fc,
+                     activation,
+                     dropout,
+                 ]
+             else:  # last layer
+                 block = [fc]
+             layers.extend(block)
+
+         self.model = nn.Sequential(*layers)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         for param in self.model.parameters():
+             # biases zero
+             if len(param.shape) == 1:
+                 nn.init.constant_(param, 0)
+             # others using lecun-normal initialization
+             else:
+                 nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="linear")
+
+     def forward(self, x) -> torch.Tensor:
+         return self.model(x)
+
+     def load_model(self, path: str):
+         self.load_state_dict(torch.load(path, weights_only=True)["model"])
+         self.eval()
+
+     @torch.no_grad()
+     def predict(self, features: torch.Tensor) -> np.ndarray:
+         """Predicts probabilities for all Tox21 targets from molecule features.
+
+         Args:
+             features (torch.Tensor): molecule features used for prediction
+
+         Returns:
+             np.ndarray: predicted probability for the positive class per target
+         """
+         assert (
+             len(features.shape) == 2
+         ), f"Function expects 2D torch.Tensor. Current shape: {features.shape}"
+
+         return torch.sigmoid(self.model(features)).detach().cpu().numpy()
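The `conic` branch above builds a geometric progression of layer widths from `hidden_dim` down to `out_features`, and the two endpoints are then overwritten with the true input and output sizes. A quick standalone sketch of what it produces for the config used in `predict.py` (the `in_features` value here is illustrative, since it depends on the descriptor set):

```python
# Sketch: layer widths produced by the "conic" rule for the predict.py config.
import numpy as np

hidden_dim, n_layers, in_features, out_features = 1024, 8, 1613, 12  # in_features illustrative

ratio = np.power(out_features / hidden_dim, 1 / n_layers)        # geometric decay factor
n_hidden = (hidden_dim * np.power(ratio, range(-1, n_layers))).astype(int)
n_hidden[0] = in_features          # overwrite head with the real input size
n_hidden[n_layers] = out_features  # ...and tail with the number of targets

print(n_hidden)  # widths of the n_layers + 1 linear layers, e.g. [1613, 1024, 587, ..., 12]
```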
src/preprocess.py ADDED
@@ -0,0 +1,285 @@
+ # pipeline taken from https://huggingface.co/spaces/ml-jku/mhnfs/blob/main/src/data_preprocessing/create_descriptors.py
+
+ """
+ This file includes the data preprocessing for Tox21.
+ It takes a list of SMILES as input and outputs the feature matrix that the
+ model consumes, together with a mask marking molecules that survived cleaning.
+ """
+
+ import json
+ from typing import Iterable
+
+ import numpy as np
+
+ from sklearn.preprocessing import StandardScaler
+ from statsmodels.distributions.empirical_distribution import ECDF
+
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
+ from rdkit.Chem.rdchem import Mol
+
+ from .utils import (
+     KNOWN_DESCR,
+     USED_200_DESCR,
+     Standardizer,
+     write_pickle,
+ )
+
+
+ def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
+     """This function creates cleaned RDKit mol objects from a list of SMILES.
+
+     Args:
+         smiles (list[str]): list of SMILES
+
+     Returns:
+         list[Mol]: list of cleaned molecules
+         np.ndarray[bool]: mask that contains False at index `i`, if the molecule in
+             `smiles` at index `i` could not be cleaned and was removed.
+     """
+     sm = Standardizer(canon_taut=True)
+
+     clean_mol_mask = list()
+     mols = list()
+     for smile in smiles:
+         mol = Chem.MolFromSmiles(smile)
+         standardized_mol, _ = sm.standardize_mol(mol)
+         is_cleaned = standardized_mol is not None
+         clean_mol_mask.append(is_cleaned)
+         if not is_cleaned:
+             continue
+         can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
+         mols.append(can_mol)
+
+     return mols, np.array(clean_mol_mask)
+
+
+ def create_ecfp_fps(mols: list[Mol]) -> np.ndarray:
+     """This function creates ECFP fingerprints for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: ECFP fingerprints of molecules
+     """
+     ecfps = list()
+
+     for mol in mols:
+         fp_sparse_vec = rdFingerprintGenerator.GetCountFPs(
+             [mol], fpType=rdFingerprintGenerator.MorganFP
+         )[0]
+         fp = np.zeros((0,), np.int8)
+         DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
+
+         ecfps.append(fp)
+
+     return np.array(ecfps)
+
+
+ def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
+     maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
+     return np.array(maccs)
+
+
+ def get_tox_patterns(filepath: str):
+     """Parses the tox SMARTS patterns defined in tox_smarts.json.
+
+     Args:
+         filepath (str): path to the JSON file with (name, SMARTS) pairs
+
+     Returns:
+         list: one (patterns, negations, merge_any) triple per SMARTS entry
+     """
+     # load patterns
+     with open(filepath) as f:
+         smarts_list = [s[1] for s in json.load(f)]
+
+     # Code does not work for this case
+     assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
+
+     # Chem.MolFromSmarts takes a long time, so it pays off to parse all the smarts
+     # first and then use them for all molecules. This gives a huge speedup over
+     # existing code. The result is a list of patterns, whether to negate each match
+     # result, and how to join them to obtain one boolean value.
+     all_patterns = []
+     for smarts in smarts_list:
+         patterns = []  # list of smarts-patterns
+         # one value per pattern above; negates the match results later
+         negations = []
+
+         if " AND " in smarts:
+             smarts = smarts.split(" AND ")
+             merge_any = False  # If an ' AND ' is found all 'subsmarts' have to match
+         else:
+             # If there is an ' OR ' present it's enough if any of the 'subsmarts' match.
+             # This also accumulates smarts where neither ' OR ' nor ' AND ' occur.
+             smarts = smarts.split(" OR ")
+             merge_any = True
+
+         # for all subsmarts check if they are preceded by 'NOT '
+         for s in smarts:
+             neg = s.startswith("NOT ")
+             if neg:
+                 s = s[4:]
+             patterns.append(Chem.MolFromSmarts(s))
+             negations.append(neg)
+
+         all_patterns.append((patterns, negations, merge_any))
+     return all_patterns
+
+
+ def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
+     """Matches the tox patterns against each molecule. Returns a boolean array."""
+     tox_data = []
+     for mol in mols:
+         mol_features = []
+         for patts, negations, merge_any in patterns:
+             matches = [mol.HasSubstructMatch(p) for p in patts]
+             matches = [m != n for m, n in zip(matches, negations)]
+             if merge_any:
+                 pres = any(matches)
+             else:
+                 pres = all(matches)
+             mol_features.append(pres)
+
+         tox_data.append(np.array(mol_features))
+
+     return np.array(tox_data)
+
+
+ def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
+     """This function creates RDKit descriptors for a list of molecules.
+
+     Args:
+         mols (list[Mol]): list of molecules
+
+     Returns:
+         np.ndarray: RDKit descriptors of molecules
+     """
+     rdkit_descriptors = list()
+
+     for mol in mols:
+         descrs = []
+         for _, descr_calc_fn in Descriptors._descList:
+             descrs.append(descr_calc_fn(mol))
+
+         descrs = np.array(descrs)
+         descrs = descrs[USED_200_DESCR]
+         rdkit_descriptors.append(descrs)
+
+     return np.array(rdkit_descriptors)
+
+
+ def create_quantiles(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
+     """Create quantile values for the given features using per-column ECDFs.
+
+     Args:
+         raw_features (np.ndarray): values to put into quantiles
+         ecdfs (list): ECDFs to use, one per feature column
+
+     Returns:
+         np.ndarray: computed quantiles
+     """
+     quantiles = np.zeros_like(raw_features)
+
+     for column in range(raw_features.shape[1]):
+         raw_values = raw_features[:, column].reshape(-1)
+         ecdf = ecdfs[column]
+         q = ecdf(raw_values)
+         quantiles[:, column] = q
+
+     return quantiles
+
+
+ def fill(features, mask, value=np.nan):
+     """Expands `features` to one row per input molecule, writing `value` into the
+     rows where `mask` is True (i.e. molecules removed during cleaning)."""
+     n_mols = len(mask)
+     n_features = features.shape[1]
+
+     data = np.zeros(shape=(n_mols, n_features))
+     data.fill(value)
+     data[~mask] = features
+     return data
+
+
+ def normalize_features(
+     raw_features,
+     scaler=None,
+     save_scaler_path: str = "",
+     verbose=True,
+ ):
+     if scaler is None:
+         scaler = StandardScaler()
+         scaler.fit(raw_features)
+         if verbose:
+             print("Fitted the StandardScaler")
+         if save_scaler_path:
+             write_pickle(save_scaler_path, scaler)
+             if verbose:
+                 print(f"Saved the StandardScaler under {save_scaler_path}")
+
+     # Normalize feature vectors
+     normalized_features = scaler.transform(raw_features)
+     if verbose:
+         print("Normalized molecule features")
+     return normalized_features, scaler
+
+
+ def create_descriptors(
+     smiles,
+     ecdfs=None,
+     scaler=None,
+     descriptors: Iterable = KNOWN_DESCR,
+ ):
+     # Create cleaned rdkit mol objects
+     mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
+     print("Cleaned molecules")
+
+     features = []
+     if "ecfps" in descriptors:
+         # Create fingerprints and descriptors
+         ecfps = create_ecfp_fps(mols)
+         # expand using mol_mask
+         ecfps = fill(ecfps, ~clean_mol_mask)
+         features.append(ecfps)
+         print("Created ECFP fingerprints")
+
+     if "rdkit_descr_quantiles" in descriptors:
+         rdkit_descrs = create_rdkit_descriptors(mols)
+         print("Created RDKit descriptors")
+
+         # Create and save ecdfs
+         if ecdfs is None:
+             print("Create ECDFs")
+             ecdfs = []
+             for column in range(rdkit_descrs.shape[1]):
+                 raw_values = rdkit_descrs[:, column].reshape(-1)
+                 ecdfs.append(ECDF(raw_values))
+
+         # Create quantiles
+         rdkit_descr_quantiles = create_quantiles(rdkit_descrs, ecdfs)
+         # expand using mol_mask
+         rdkit_descr_quantiles = fill(rdkit_descr_quantiles, ~clean_mol_mask)
+         features.append(rdkit_descr_quantiles)
+         print("Created quantiles of RDKit descriptors")
+
+     if "maccs" in descriptors:
+         maccs = create_maccs_keys(mols)
+         maccs = fill(maccs, ~clean_mol_mask)
+         features.append(maccs)
+         print("Created MACCS keys")
+
+     if "tox" in descriptors:
+         tox_patterns = get_tox_patterns("assets/tox_smarts.json")
+         tox = create_tox_features(mols, tox_patterns)
+         tox = fill(tox, ~clean_mol_mask)
+         features.append(tox)
+         print("Created Tox features")
+
+     # concatenate features
+     raw_features = np.concatenate(features, axis=1)
+
+     # normalize with scaler if scaler is passed, else create scaler
+     features, scaler = normalize_features(
+         raw_features,
+         scaler=scaler,
+         verbose=True,
+     )
+
+     return features, clean_mol_mask
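The quantile featurization above relies on `statsmodels`' `ECDF`: each descriptor column is mapped through the empirical CDF fitted on training data, so at inference every raw value lands in [0, 1]. A minimal sketch with made-up values:

```python
# Sketch of the per-column ECDF quantile transform used in create_quantiles().
import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF

train_col = np.array([0.1, 0.4, 0.4, 2.0, 5.0])  # one descriptor column (training)
ecdf = ECDF(train_col)                           # fit once, persist (assets/ecdfs.pkl)

new_col = np.array([0.05, 0.4, 10.0])            # same descriptor at inference
print(ecdf(new_col))                             # -> [0.0, 0.6, 1.0]
```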
src/utils.py ADDED
@@ -0,0 +1,444 @@
+ ## These MolStandardizer classes are due to Paolo Tosco
+ ## They were taken from the FS-Mol GitHub
+ ## (https://github.com/microsoft/FS-Mol/blob/main/fs_mol/preprocessing/utils/
+ ## standardizer.py)
+ ## They ensure that a sequence of standardization operations are applied
+ ## https://gist.github.com/ptosco/7e6b9ab9cc3e44ba0919060beaed198e
+
+ import os
+ import pickle
+
+ from rdkit import Chem
+ from rdkit.Chem.MolStandardize import rdMolStandardize
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ PAD_VALUE = -100
+
+ TASKS = [
+     "NR-AR",
+     "NR-AR-LBD",
+     "NR-AhR",
+     "NR-Aromatase",
+     "NR-ER",
+     "NR-ER-LBD",
+     "NR-PPAR-gamma",
+     "SR-ARE",
+     "SR-ATAD5",
+     "SR-HSE",
+     "SR-MMP",
+     "SR-p53",
+ ]
+
+ KNOWN_DESCR = ["ecfps", "rdkit_descr_quantiles", "maccs", "tox"]
+
+ # Indices of the 200 RDKit descriptors used by the model: 0-16 and 25-207
+ # (written as ranges for readability; identical to the original explicit list).
+ USED_200_DESCR = list(range(0, 17)) + list(range(25, 208))
+
+
+ class Standardizer:
+     """
+     Simple wrapper class around rdkit Standardizer.
+     """
+
+     DEFAULT_CANON_TAUT = False
+     DEFAULT_METAL_DISCONNECT = False
+     MAX_TAUTOMERS = 100
+     MAX_TRANSFORMS = 100
+     MAX_RESTARTS = 200
+     PREFER_ORGANIC = True
+
+     def __init__(
+         self,
+         metal_disconnect=None,
+         canon_taut=None,
+     ):
+         """
+         Constructor.
+         All parameters are optional.
+         :param metal_disconnect: if True, metallorganic complexes are
+                                  disconnected
+         :param canon_taut: if True, molecules are converted to their
+                            canonical tautomer
+         """
+         super().__init__()
+         if metal_disconnect is None:
+             metal_disconnect = self.DEFAULT_METAL_DISCONNECT
+         if canon_taut is None:
+             canon_taut = self.DEFAULT_CANON_TAUT
+         self._canon_taut = canon_taut
+         self._metal_disconnect = metal_disconnect
+         self._taut_enumerator = None
+         self._uncharger = None
+         self._lfrag_chooser = None
+         self._metal_disconnector = None
+         self._normalizer = None
+         self._reionizer = None
+         self._params = None
+
+     @property
+     def params(self):
+         """Return the MolStandardize CleanupParameters."""
+         if self._params is None:
+             self._params = rdMolStandardize.CleanupParameters()
+             self._params.maxTautomers = self.MAX_TAUTOMERS
+             self._params.maxTransforms = self.MAX_TRANSFORMS
+             self._params.maxRestarts = self.MAX_RESTARTS
+             self._params.preferOrganic = self.PREFER_ORGANIC
+             self._params.tautomerRemoveSp3Stereo = False
+         return self._params
+
+     @property
+     def canon_taut(self):
+         """Return whether tautomer canonicalization will be done."""
+         return self._canon_taut
+
+     @property
+     def metal_disconnect(self):
+         """Return whether metallorganic complexes will be disconnected."""
+         return self._metal_disconnect
+
+     @property
+     def taut_enumerator(self):
+         """Return the TautomerEnumerator object."""
+         if self._taut_enumerator is None:
+             self._taut_enumerator = rdMolStandardize.TautomerEnumerator(self.params)
+         return self._taut_enumerator
+
+     @property
+     def uncharger(self):
+         """Return the Uncharger object."""
+         if self._uncharger is None:
+             self._uncharger = rdMolStandardize.Uncharger()
+         return self._uncharger
+
+     @property
+     def lfrag_chooser(self):
+         """Return the LargestFragmentChooser object."""
+         if self._lfrag_chooser is None:
+             self._lfrag_chooser = rdMolStandardize.LargestFragmentChooser(
+                 self.params.preferOrganic
+             )
+         return self._lfrag_chooser
+
+     @property
+     def metal_disconnector(self):
+         """Return the MetalDisconnector object."""
+         if self._metal_disconnector is None:
+             self._metal_disconnector = rdMolStandardize.MetalDisconnector()
+         return self._metal_disconnector
+
+     @property
+     def normalizer(self):
+         """Return the Normalizer object."""
+         if self._normalizer is None:
+             self._normalizer = rdMolStandardize.Normalizer(
+                 self.params.normalizationsFile, self.params.maxRestarts
+             )
+         return self._normalizer
+
+     @property
+     def reionizer(self):
+         """Return the Reionizer object."""
+         if self._reionizer is None:
+             self._reionizer = rdMolStandardize.Reionizer(self.params.acidbaseFile)
+         return self._reionizer
+
+     def charge_parent(self, mol_in):
+         """Sequentially apply a series of MolStandardize operations:
+         * MetalDisconnector
+         * Normalizer
+         * Reionizer
+         * LargestFragmentChooser
+         * Uncharger
+         The net result is that a desalted, normalized, neutral
+         molecule with implicit Hs is returned.
+         """
+         params = Chem.RemoveHsParameters()
+         params.removeAndTrackIsotopes = True
+         mol_in = Chem.RemoveHs(mol_in, params, sanitize=False)
+         if self._metal_disconnect:
+             mol_in = self.metal_disconnector.Disconnect(mol_in)
+         normalized = self.normalizer.normalize(mol_in)
+         Chem.SanitizeMol(normalized)
+         normalized = self.reionizer.reionize(normalized)
+         Chem.AssignStereochemistry(normalized)
+         normalized = self.lfrag_chooser.choose(normalized)
+         normalized = self.uncharger.uncharge(normalized)
+         # need this to reassess aromaticity on things like
+         # cyclopentadienyl, tropylium, azolium, etc.
+         Chem.SanitizeMol(normalized)
+         return Chem.RemoveHs(Chem.AddHs(normalized))
+
+     def standardize_mol(self, mol_in):
+         """
+         Standardize a single molecule.
+         :param mol_in: a Chem.Mol
+         :return: * (standardized Chem.Mol, n_taut) tuple
+                    if success. n_taut will be negative if
+                    tautomer enumeration was aborted due
+                    to reaching a limit
+                  * (None, error_msg) if failure
+         This calls self.charge_parent() and, if self._canon_taut
+         is True, runs tautomer canonicalization.
+         """
+         n_tautomers = 0
+         if isinstance(mol_in, Chem.Mol):
+             name = None
+             try:
+                 name = mol_in.GetProp("_Name")
+             except KeyError:
+                 pass
+             if not name:
+                 name = "NONAME"
+         else:
+             error = f"Expected SMILES or Chem.Mol as input, got {str(type(mol_in))}"
+             return None, error
+         try:
+             mol_out = self.charge_parent(mol_in)
+         except Exception as e:
+             error = f"charge_parent FAILED: {str(e).strip()}"
+             return None, error
+         if self._canon_taut:
+             try:
+                 res = self.taut_enumerator.Enumerate(mol_out, False)
+             except TypeError:
+                 # we are still on the pre-2021 RDKit API
+                 res = self.taut_enumerator.Enumerate(mol_out)
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+             n_tautomers = len(res)
+             if hasattr(res, "status"):
+                 completed = (
+                     res.status == rdMolStandardize.TautomerEnumeratorStatus.Completed
+                 )
+             else:
+                 # we are still on the pre-2021 RDKit API
+                 completed = len(res) < 1000
+             if not completed:
+                 n_tautomers = -n_tautomers
+             try:
+                 mol_out = self.taut_enumerator.PickCanonical(res)
+             except AttributeError:
+                 # we are still on the pre-2021 RDKit API
+                 mol_out = max(
+                     [(self.taut_enumerator.ScoreTautomer(m), m) for m in res]
+                 )[1]
+             except Exception as e:
+                 # something else went wrong
+                 error = f"canon_taut FAILED: {str(e).strip()}"
+                 return None, error
+         mol_out.SetProp("_Name", name)
+         return mol_out, n_tautomers
+
+
+ def load_pickle(path: str):
+     with open(path, "rb") as file:
+         content = pickle.load(file)
+     return content
+
+
+ def write_pickle(path: str, obj: object):
+     with open(path, "wb") as file:
+         pickle.dump(obj, file)
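The `standardize_mol()` return contract (mol plus tautomer count on success, `None` plus an error string on failure) is what `create_cleaned_mol_objects()` in `src/preprocess.py` relies on. A small sketch of driving it directly:

```python
# Sketch: using the Standardizer directly (mirrors src/preprocess.py usage).
from rdkit import Chem
from src.utils import Standardizer

sm = Standardizer(canon_taut=True)

mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")  # sodium acetate: salt gets stripped
standardized, info = sm.standardize_mol(mol)

if standardized is None:
    print(f"cleaning failed: {info}")           # info is the error message
else:
    print(Chem.MolToSmiles(standardized), info)  # info is the tautomer count
```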