Spaces:
Sleeping
Sleeping
| """ | |
| This file contains a predict function for the Tox21 dataset. | |
| It takes a list of SMILES strings as input and returns a nested dictionary | |
| keyed first by SMILES and then by target name. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| from src.preprocess import create_descriptors | |
| from src.model import Tox21SNNClassifier, SNNConfig | |
| from src.utils import load_pickle | |
| # --------------------------------------------------------------------------------------- | |
def predict(smiles_list: list[str]) -> dict[str, dict[str, float]]:
    """Apply the Tox21 SNN classifier to a list of SMILES strings.

    Molecules that could not be cleaned, or whose descriptor row contains NaN
    values, receive the placeholder prediction 0.5 for every target.

    Args:
        smiles_list (list[str]): list of SMILES strings

    Returns:
        dict: nested prediction dictionary, following
            {'<smiles>': {'<target>': <pred>}}. Duplicate SMILES map to a
            single entry (the last occurrence wins).
    """
    fallback_pred = 0.5
    print(f"Received {len(smiles_list)} SMILES strings")

    # preprocessing pipeline
    ecdfs_path = "assets/ecdfs.pkl"
    scaler_path = "assets/scaler.pkl"
    # NOTE: these are pickles shipped with the app; never point the paths at
    # untrusted files, since unpickling can execute arbitrary code.
    ecdfs = load_pickle(ecdfs_path)
    scaler = load_pickle(scaler_path)
    print(f"Loaded ecdfs from {ecdfs_path}")
    print(f"Loaded scaler from {scaler_path}")

    descriptors = ["rdkit_descr_quantiles", "tox"]
    features, mol_mask = create_descriptors(
        smiles_list,
        ecdfs=ecdfs,
        scaler=scaler,
        descriptors=descriptors,
    )
    print(f"Created descriptors {descriptors} for molecules.")
    print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning")

    # setup model
    # `features` is used row-per-molecule below (NaN check along axis=1, row
    # masking, TensorDataset over rows), so the input width is its column count.
    cfg = SNNConfig(
        hidden_dim=1024,
        n_layers=8,
        dropout=0.05,
        layer_form="conic",
        in_features=features.shape[1],
        out_features=12,
    )
    model = Tox21SNNClassifier(cfg)
    model_path = "assets/snn_best.pth"
    model.load_model(model_path)
    model.eval()
    print(f"Loaded model from {model_path}")

    # make predictions
    # Row of `features` for each input SMILES; only meaningful where mol_mask is
    # True. Assumes `features` holds one row per cleaned molecule, in input order.
    feat_indices = np.cumsum(mol_mask) - 1
    # Rows containing any NaN descriptor are not sent through the network...
    valid_rows = ~np.isnan(features).any(axis=1)
    # ...so `preds` has fewer rows than `features`: map feature rows to pred rows.
    pred_indices = np.cumsum(valid_rows) - 1

    preds = np.empty((0, len(model.tasks)))
    if valid_rows.any():  # np.concatenate([]) would raise on an empty loader
        dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(features[valid_rows])
        )
        loader = torch.utils.data.DataLoader(
            dataset, 128, shuffle=False, num_workers=0
        )
        with torch.no_grad():
            # Each TensorDataset batch is a 1-element sequence; pass the tensor.
            # NOTE(review): assumes model.predict takes a plain input tensor.
            preds = np.concatenate([model.predict(x) for (x,) in loader], axis=0)

    predictions: dict[str, dict[str, float]] = {}
    for smiles, is_clean, j in zip(smiles_list, mol_mask, feat_indices):
        # Short-circuit matters: j is meaningless (possibly -1) when not clean.
        usable = is_clean and valid_rows[j]
        predictions[smiles] = {
            target: float(preds[pred_indices[j], i]) if usable else fallback_pred
            for i, target in enumerate(model.tasks)
        }
    return predictions