"""
This file includes a predict function for the Tox21 classifier.

As an input it takes a list of SMILES strings and it outputs a nested
dictionary with SMILES and target names as keys.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
from collections import defaultdict

import numpy as np
import torch

from src.data import create_descriptors
from src.model import Tox21SNNClassifier, SNNConfig
from src.utils import load_pickle, KNOWN_DESCR

# ---------------------------------------------------------------------------------------

ECDFS_PATH = "data/ecdfs.pkl"
FEATURE_SELECTION_PATH = "data/feat_selection.npz"
SCALER1_PATH = "data/scaler1.pkl"
SCALER2_PATH = "data/scaler2.pkl"
CHECKPOINT_PATH = "checkpoints/noble-butterfly-499_last.pth"

ECFP_RADIUS = 3
ECFP_FPSIZE = 8192


def predict(
    smiles_list: list[str], default_prediction: float = 0.5
) -> dict[str, dict[str, float]]:
    """Applies the classifier to a list of SMILES strings.

    Molecules whose feature row is entirely NaN after preprocessing are
    treated as "could not be cleaned" and receive ``default_prediction`` for
    every target instead of a model output.

    Args:
        smiles_list (list[str]): list of SMILES strings
        default_prediction (float): value reported for every target of a
            molecule that could not be cleaned. Defaults to 0.5.

    Returns:
        dict: nested prediction dictionary, following
            ``{smiles: {target_name: prediction}}``. If the same SMILES string
            occurs more than once in ``smiles_list``, it maps to a single
            entry.
    """
    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = defaultdict(dict)
    if len(smiles_list) == 0:
        # Nothing to featurize or score; skip loading the artifacts entirely.
        return predictions

    # preprocessing pipeline
    scaler1 = load_pickle(SCALER1_PATH)
    scaler2 = load_pickle(SCALER2_PATH)
    ecdfs = load_pickle(ECDFS_PATH)
    feature_selection = np.load(FEATURE_SELECTION_PATH)
    print(f"Loaded scaler1 from: {SCALER1_PATH}")
    print(f"Loaded scaler2 from: {SCALER2_PATH}")
    print(f"Loaded ecdfs from: {ECDFS_PATH}")
    print(f"Loaded feature selection from: {FEATURE_SELECTION_PATH}")

    features = create_descriptors(
        smiles_list,
        ecdfs=ecdfs,
        feature_selection=feature_selection,
        radius=ECFP_RADIUS,
        fpsize=ECFP_FPSIZE,
    )["features"]

    # Stack the descriptor groups column-wise in the order given by KNOWN_DESCR,
    # then apply scaler1 -> tanh -> scaler2.
    features = np.concatenate([features[descr] for descr in KNOWN_DESCR], axis=1)
    features = scaler1.transform(features)
    features = np.tanh(features)
    features = scaler2.transform(features)
    print("Created descriptors for molecules.")

    # A row that is NaN in every column marks a molecule that was dropped
    # during cleaning; compute the mask once and reuse it.
    is_failed = np.isnan(features).all(axis=1)
    is_clean = ~is_failed
    print(f"{(is_failed.sum())} molecules removed during cleaning")

    dataset = torch.utils.data.TensorDataset(torch.FloatTensor(features[is_clean]))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=False, num_workers=0
    )

    # setup model
    cfg = SNNConfig(
        hidden_dim=512,
        n_layers=8,
        dropout=0.05,
        layer_form="rect",
        in_features=features.shape[1],
        out_features=12,
    )
    model = Tox21SNNClassifier(cfg)
    model.load_model(CHECKPOINT_PATH)
    model.eval()
    print(f"Loaded model from {CHECKPOINT_PATH}")

    tasks = list(model.tasks)

    print("Create predictions:")
    with torch.no_grad():
        batch_preds = [model.predict(batch[0]) for batch in loader]

    if batch_preds:
        preds = np.concatenate(batch_preds, axis=0)
    else:
        # No molecule survived cleaning: np.concatenate would raise on an
        # empty list, so use an empty (0, n_tasks) array and let every
        # molecule fall back to default_prediction below.
        preds = np.empty((0, len(tasks)), dtype=float)

    for i, target in enumerate(tasks):
        # Scatter the model outputs back to the original molecule order,
        # filling the uncleanable molecules with the default value.
        target_preds = np.empty_like(is_clean, dtype=float)
        target_preds[~is_clean] = default_prediction
        target_preds[is_clean] = preds[:, i]
        for smiles, pred in zip(smiles_list, target_preds):
            predictions[smiles][target] = float(pred)

    return predictions