import numpy as np
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol
from datasets import load_dataset

def get_tox21_split(token, cvfold=None):
    """Loads the Tox21 splits from the Hugging Face Hub.

    If `cvfold` is given, the published train/validation splits are recombined
    and re-split so that rows with CVfold == cvfold form the validation set.
    """
    ds = load_dataset("tschouis/tox21", token=token)
    
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df
        }
    
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)

    # create new splits for the requested fold
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]

    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}

def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Create cleaned RDKit Mol objects from SMILES.
       Returns (list of mols, mask of valid mols).
    """
    clean_mol_mask = []
    mols = []

    # Standardizer components
    cleanup_params = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()

    for smi in smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                clean_mol_mask.append(False)
                continue

            # Cleanup and canonicalize
            mol = rdMolStandardize.Cleanup(mol, cleanup_params)
            mol = tautomer_enumerator.Canonicalize(mol)

            # Recompute canonical SMILES & reload
            can_smi = Chem.MolToSmiles(mol)
            mol = Chem.MolFromSmiles(can_smi)

            if mol is not None:
                mols.append(mol)
                clean_mol_mask.append(True)
            else:
                clean_mol_mask.append(False)

        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
            clean_mol_mask.append(False)

    return mols, np.array(clean_mol_mask, dtype=bool)
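
# Example sketch: the boolean mask aligns the cleaned mols with the input
# list, so `mask.sum() == len(mols)` and rows can be filtered with `df[mask]`:
#   mols, mask = create_clean_mol_objects(["CCO", "not_a_smiles"])
#   assert len(mols) == mask.sum()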


class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset of Tox21 molecules with per-task labels and masks."""

    def __init__(self, dataframe):
        super().__init__()
        data_list = []

        # Clean molecules & filter dataframe
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        dataframe = dataframe[clean_mask].reset_index(drop=True)

        # Now mols and dataframe are aligned, so we can zip
        for mol, (_, row) in zip(mols, dataframe.iterrows()):
            try:
                data = from_rdmol(mol)

                # Extract labels as a pandas Series
                drop_cols = ["ID","smiles","inchikey","sdftitle","order","set","CVfold"]
                labels = row.drop(drop_cols)

                # Mask for valid labels
                mask = ~labels.isna()

                # Explicit numeric conversion, replaces NaN with 0.0 safely
                labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values

                # Convert to tensors
                y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
                m = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)

                data.y = y
                data.mask = m

                data_list.append(data)

            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")

        # Collate into dataset
        self.data, self.slices = self.collate(data_list)


def get_graph_datasets(token, cvfold=4):
    """Builds the train and validation graph datasets for use in dataloaders.

    Args:
        token (str): Hugging Face access token used to download the dataset.
        cvfold (int, optional): CV fold held out for validation. Defaults to 4.

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: train and validation datasets.
    """
    datasets = get_tox21_split(token, cvfold=cvfold)
    train_df, val_df = datasets["train"], datasets["validation"]
    train_dataset = Tox21Dataset(train_df)
    val_dataset = Tox21Dataset(val_df)
    return train_dataset, val_dataset
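

# Minimal end-to-end sketch (assumes a valid Hugging Face token in the
# hypothetical HF_TOKEN environment variable; batch size is illustrative):
if __name__ == "__main__":
    import os

    from torch_geometric.loader import DataLoader

    hf_token = os.environ.get("HF_TOKEN")  # hypothetical env var for the token
    train_dataset, val_dataset = get_graph_datasets(hf_token)

    # Standard PyG dataloaders; per-graph `y` and `mask` of shape [1, n_tasks]
    # stack to [batch_size, n_tasks] in each batch.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    batch = next(iter(train_loader))
    print(batch)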