File size: 1,919 Bytes
f484830
 
 
 
1ce331f
 
 
f484830
4e8b53d
f484830
 
 
 
 
 
 
 
 
a78b381
f484830
 
e4f4cf0
f484830
 
 
 
 
 
f0bc9a8
f484830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f4cf0
f484830
 
f0bc9a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol
import torch

from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed

def predict(smiles_list):
    """
    Predict Tox21 toxicity targets for a list of SMILES strings.

    Each SMILES is passed through ``create_clean_mol_objects``, converted to a
    graph with ``from_rdmol`` and scored by the GIN checkpoint stored at
    ``./checkpoints/model.pt``. A sigmoid is applied to the model's logits to
    turn them into per-target probabilities.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. A SMILES that fails
        anywhere in conversion or inference maps every target to 0.5 (an
        uninformative probability) instead of aborting the whole request.
        Duplicate SMILES collapse into a single key.
    """
    set_seed(0)
    # Tox21 targets. They are zipped position-wise with the model outputs,
    # so this order must match the column order used at training time.
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model; the output width is tied to TARGET_NAMES so the two
    # cannot drift apart.
    model = GIN(
        num_features=9,
        num_classes=len(TARGET_NAMES),
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./checkpoints/model.pt"
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the installed torch version supports it and the checkpoint is a
    # plain state_dict.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()
    predictions = {}

    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for smiles in smiles_list:
            try:
                # Convert SMILES to a single-graph batch
                mol, _ = create_clean_mol_objects([smiles])
                data = from_rdmol(mol[0]).to(DEVICE)
                batch = Batch.from_data_list([data])

                # Forward pass
                logits = model(batch.x, batch.edge_index, batch.batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()

                # Map predictions to targets
                predictions[smiles] = {
                    t: float(p) for t, p in zip(TARGET_NAMES, probs)
                }
            except Exception as exc:
                # Best-effort: one bad SMILES must not fail the whole request,
                # so report it and fall back to 0.5 for every target.
                print(f"Prediction failed for SMILES {smiles!r}: {exc}")
                predictions[smiles] = {t: 0.5 for t in TARGET_NAMES}

    return predictions