File size: 2,498 Bytes
9f78de0
 
 
 
2e511fb
9f78de0
 
 
 
 
 
 
 
 
 
 
338b5f6
 
9f78de0
338b5f6
 
9f78de0
338b5f6
 
 
 
 
9f78de0
338b5f6
 
9f78de0
338b5f6
 
 
 
 
 
 
 
40e74c5
338b5f6
 
 
 
 
 
 
40e74c5
0d7dfdb
40e74c5
 
338b5f6
 
 
 
 
 
 
40e74c5
338b5f6
 
e517ec0
338b5f6
 
 
 
 
 
 
 
 
 
 
5567bdb
 
 
ad23c3f
9f78de0
40e74c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import csv
import subprocess

from src.preprocess import create_clean_smiles

def predict(smiles_list):
    """
    Predict Tox21 toxicity-target probabilities for a list of SMILES strings.

    Cleaned SMILES are written to data/smiles.csv, scored with the
    ``chemprop predict`` CLI against checkpoints/model.pt, and the
    per-target probabilities are read back from data/preds.csv.

    Args:
        smiles_list (list[str]): SMILES strings; may contain invalid entries.

    Returns:
        dict: {original_smiles: {target_name: prediction_prob}}.
              Invalid SMILES receive a 0.5 placeholder for every target.
              Note: duplicate input SMILES collapse to one dict key.

    Raises:
        ValueError: if the cleaning step or the chemprop output row count
            does not match the number of valid inputs.
        subprocess.CalledProcessError: if the chemprop CLI exits non-zero.
    """
    # Tox21 assay targets — used for placeholder rows on invalid SMILES.
    TARGET_NAMES = [
        "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
        "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
    ]

    print(f"Received {len(smiles_list)} SMILES strings")

    # Canonicalize/clean the input; valid_mask flags which originals survived.
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # Sanity check: one cleaned SMILES per surviving original.
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs "
            f"{len(clean_smiles)} cleaned SMILES"
        )

    predictions = {}

    # Only invoke chemprop when there is something to score; running the CLI
    # on a header-only CSV would fail needlessly.
    if clean_smiles:
        # Write the cleaned SMILES for the chemprop CLI.
        # NOTE(review): assumes ./data already exists — confirm or create upstream.
        with open("./data/smiles.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["smiles"])  # header expected by --smiles-columns
            writer.writerows([smi] for smi in clean_smiles)

        # List form (shell=False) avoids shell interpretation of SMILES content.
        command = [
            "chemprop", "predict",
            "--test-path", "data/smiles.csv",
            "--model-path", "checkpoints/model.pt",
            "--smiles-columns", "smiles",
            "--preds-path", "data/preds.csv",
        ]
        subprocess.run(command, check=True)

        # Read the predictions back; every non-smiles column is a target.
        with open("./data/preds.csv", "r", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
            target_names = [col for col in reader.fieldnames if col != "smiles"]

        if len(rows) != len(originals_valid):
            raise ValueError(
                f"chemprop returned {len(rows)} rows for "
                f"{len(originals_valid)} submitted SMILES"
            )

        # Map each prediction row back to its ORIGINAL SMILES by position —
        # rows come back in the order they were written. (A cleaned->original
        # dict would silently drop entries when two different inputs
        # canonicalize to the same cleaned SMILES.)
        for row, original_smi in zip(rows, originals_valid):
            predictions[original_smi] = {t: float(row[t]) for t in target_names}

    # Placeholder (maximally uncertain) predictions for invalid SMILES.
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TARGET_NAMES}

    return predictions