"""Build a PyTorch Geometric dataset from a CSV of SMILES strings and labels."""

from typing import Optional

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol

# Columns of the input frame that are not label columns; every remaining
# column is treated as a label.
_METADATA_COLUMNS = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]


def _standardize_smiles(
    smi: str,
    cleaner: rdMolStandardize.CleanupParameters,
    tautomer_enumerator: rdMolStandardize.TautomerEnumerator,
) -> Optional[Chem.Mol]:
    """Return the standardized molecule for *smi*, or None if it cannot be parsed.

    May raise whatever RDKit raises during parsing or standardization; the
    caller decides how to report such failures.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None

    # Cleanup and canonicalize the tautomer.
    mol = rdMolStandardize.Cleanup(mol, cleaner)
    mol = tautomer_enumerator.Canonicalize(mol)

    # Recompute the canonical SMILES and reload it; the reload may itself
    # return None, which the caller treats as an invalid molecule.
    return Chem.MolFromSmiles(Chem.MolToSmiles(mol))


def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Create cleaned RDKit Mol objects from SMILES.

    Args:
        smiles: SMILES strings to parse and standardize.

    Returns:
        A pair ``(mols, mask)``. ``mols`` holds only the molecules that were
        standardized successfully, in input order. ``mask`` is a boolean array
        with one entry per input SMILES, True where that input produced a
        molecule, so ``mask.sum() == len(mols)``.
    """
    # Standardizer components are built once and shared by all molecules.
    cleaner = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()

    mols = []
    clean_mol_mask = []
    for smi in smiles:
        try:
            mol = _standardize_smiles(smi, cleaner, tautomer_enumerator)
        except Exception as e:
            # Best effort: report the failure and mark the entry as invalid.
            print(f"Failed to standardize {smi}: {e}")
            mol = None

        if mol is not None:
            mols.append(mol)
        # Exactly one mask entry per input keeps the mask aligned with `smiles`.
        clean_mol_mask.append(mol is not None)

    return mols, np.array(clean_mol_mask, dtype=bool)


class Tox21Dataset(InMemoryDataset):
    """In-memory graph dataset built from a dataframe of SMILES and labels.

    Each stored graph carries two extra attributes:

    * ``y``: float tensor of shape ``(1, num_label_columns)`` with missing
      labels replaced by 0.0.
    * ``mask``: bool tensor of the same shape, True where the label is a
      usable numeric value.
    """

    def __init__(self, dataframe: pd.DataFrame):
        """Convert every parsable row of *dataframe* into a graph.

        Args:
            dataframe: Must contain a ``"smiles"`` column. The columns listed
                in ``_METADATA_COLUMNS`` are dropped when present; all other
                columns are used as labels.
        """
        super().__init__()

        # Clean molecules and keep only the rows whose SMILES was usable, so
        # that `mols` and `dataframe` are aligned row for row.
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        dataframe = dataframe[clean_mask].reset_index(drop=True)

        # Extract all labels at once. errors="ignore" tolerates input files
        # that lack some of the metadata columns.
        label_frame = dataframe.drop(columns=_METADATA_COLUMNS, errors="ignore")
        numeric_labels = label_frame.apply(pd.to_numeric, errors="coerce")

        # The mask is derived AFTER numeric coercion so that non-numeric cells
        # (coerced to NaN) are flagged as missing rather than passed off as a
        # valid 0.0 label.
        valid = numeric_labels.notna().to_numpy(dtype=bool)
        values = numeric_labels.fillna(0.0).to_numpy(dtype=float)

        data_list = []
        for mol, smi, label_row, mask_row in zip(mols, dataframe["smiles"], values, valid):
            try:
                data = from_rdmol(mol)
                data.y = torch.tensor(label_row, dtype=torch.float).unsqueeze(0)
                data.mask = torch.tensor(mask_row, dtype=torch.bool).unsqueeze(0)
            except Exception as e:
                print(f"Skipping molecule {smi} due to error: {e}")
                continue
            data_list.append(data)

        # Collate into dataset
        self.data, self.slices = self.collate(data_list)


def get_graph_dataset(filepath: str) -> Tox21Dataset:
    """Return an InMemoryDataset that can be used in dataloaders.

    Args:
        filepath (str): the filepath of the data csv

    Returns:
        Tox21Dataset: dataset for dataloaders
    """
    df = pd.read_csv(filepath)
    return Tox21Dataset(df)