tox21_gin_classifier / predict.py
Sonja Topf
README, utils to src folder
1ce331f
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol
import torch
from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed
def predict(smiles_list):
    """
    Predict Tox21 toxicity probabilities for a list of SMILES strings.

    The GIN checkpoint at ``./checkpoints/model.pt`` is loaded on every call
    and each SMILES string is converted to a graph and scored on its own.

    Args:
        smiles_list (list[str]): SMILES strings to score.

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. Keys are the input
        strings, so duplicate SMILES collapse into a single entry. Each
        probability is a Python float produced by a sigmoid over the model
        logits. Any molecule that fails preprocessing or the forward pass
        gets a neutral 0.5 for every target instead of aborting the run.
    """
    set_seed(0)

    # tox21 targets; position i is paired with output unit i of the model.
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]
    # Probability reported for every target when a molecule cannot be scored.
    FALLBACK_PROB = 0.5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Received {len(smiles_list)} SMILES strings")

    # Setup model. These hyperparameters must match the ones the checkpoint was
    # trained with; num_features=9 presumably matches the per-atom feature
    # width produced by from_rdmol — confirm if the featurisation changes.
    model = GIN(
        num_features=9,
        num_classes=len(TARGET_NAMES),
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./checkpoints/model.pt"
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()

    predictions = {}
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for smiles in smiles_list:
            try:
                # Convert SMILES to graph. An unparsable SMILES surfaces here as
                # an exception (e.g. empty mol list or a None molecule).
                mols, _ = create_clean_mol_objects([smiles])
                data = from_rdmol(mols[0])
                # Collate first, then move, so that x, edge_index AND the batch
                # assignment vector all end up on DEVICE.
                batch = Batch.from_data_list([data]).to(DEVICE)
                logits = model(batch.x, batch.edge_index, batch.batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()
            except Exception as exc:
                # Best-effort by design: one bad molecule must not abort the
                # whole run, but the failure is reported rather than hidden.
                # The fallback is a neutral 0.5 probability, not zero.
                print(f"Prediction failed for SMILES {smiles!r}: {exc}")
                predictions[smiles] = {t: FALLBACK_PROB for t in TARGET_NAMES}
            else:
                # Map predictions to targets by position.
                predictions[smiles] = {
                    t: float(p) for t, p in zip(TARGET_NAMES, probs)
                }
    return predictions