import torch
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol

from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The GIN checkpoint at ``./checkpoints/model.pt`` is loaded on every call.
    Each SMILES is then scored independently, one graph per forward pass.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. Each probability is the
        sigmoid of the model's logit for that target. If a SMILES cannot be
        converted or scored, every target for that SMILES is set to the
        neutral value 0.5 and the failure is reported on stdout.
        Duplicate SMILES in the input collapse to a single key.
    """
    set_seed(0)

    # tox21 targets; order must match the model's output columns
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]
    # Neutral probability used when a molecule cannot be scored.
    FALLBACK_PROB = 0.5

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model; num_classes is derived from TARGET_NAMES so the two cannot drift.
    # NOTE(review): num_features=9 presumably matches the node-feature width
    # produced by from_rdmol — confirm against training code.
    model = GIN(
        num_features=9,
        num_classes=len(TARGET_NAMES),
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./checkpoints/model.pt"
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()

    predictions = {}
    # Gradients are never needed for inference; enter no_grad once for the whole loop.
    with torch.no_grad():
        for smiles in smiles_list:
            try:
                # Convert SMILES to graph; an unparseable SMILES presumably yields an
                # empty list here, making mol[0] raise and triggering the fallback.
                mol, _ = create_clean_mol_objects([smiles])
                data = from_rdmol(mol[0]).to(DEVICE)
                batch = Batch.from_data_list([data])

                # Forward pass
                logits = model(batch.x, batch.edge_index, batch.batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()

                # Map predictions to targets
                predictions[smiles] = {
                    t: float(p) for t, p in zip(TARGET_NAMES, probs)
                }
            except Exception as e:
                # Best-effort: a single bad SMILES must not abort the whole batch.
                # Fill with the neutral probability 0.5 (not zeros), and report the
                # failure so it is not indistinguishable from a real prediction.
                print(f"Failed to predict for SMILES {smiles!r}: {e}")
                predictions[smiles] = {t: FALLBACK_PROB for t in TARGET_NAMES}

    return predictions