tox21_gin_classifier / src /preprocess.py
Sonja Topf
initial commit
f484830
raw
history blame
3.2 kB
import numpy as np
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import Chem
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol
def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Standardize a list of SMILES strings into RDKit Mol objects.

    Each SMILES is parsed, cleaned with the RDKit standardizer, reduced to
    its canonical tautomer, and then round-tripped through canonical SMILES.
    Inputs that fail at any stage are skipped (best-effort; errors printed).

    Args:
        smiles: SMILES strings to standardize.

    Returns:
        A pair ``(mols, mask)`` where ``mols`` holds the successfully
        standardized molecules and ``mask`` is a boolean array of the same
        length as ``smiles`` marking which inputs survived.
    """
    valid_flags: list[bool] = []
    cleaned_mols: list[Chem.Mol] = []
    # Build the standardizer components once and reuse them for every input.
    cleanup_params = rdMolStandardize.CleanupParameters()
    taut_enum = rdMolStandardize.TautomerEnumerator()
    for smi in smiles:
        ok = False
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is not None:
                # Cleanup, pick the canonical tautomer, then reload from
                # canonical SMILES for a consistent representation.
                mol = rdMolStandardize.Cleanup(mol, cleanup_params)
                mol = taut_enum.Canonicalize(mol)
                mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
                if mol is not None:
                    cleaned_mols.append(mol)
                    ok = True
        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
        valid_flags.append(ok)
    return cleaned_mols, np.array(valid_flags, dtype=bool)
class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset built from a Tox21 dataframe.

    Expects a dataframe with a ``smiles`` column, optional metadata columns
    (``ID``, ``inchikey``, ...), and one column per assay label. SMILES are
    standardized first; rows whose SMILES fail cleaning are dropped. Each
    resulting graph carries:

        data.y    -- (1, n_tasks) float tensor; missing labels become 0.0
        data.mask -- (1, n_tasks) bool tensor; True where a label is present
    """

    # Metadata columns that are not assay labels.
    _META_COLS = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]

    def __init__(self, dataframe):
        super().__init__()
        data_list = []
        # Clean molecules & filter dataframe so mols and rows stay aligned
        # for the zip below.
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        dataframe = dataframe[clean_mask].reset_index(drop=True)
        for mol, (_, row) in zip(mols, dataframe.iterrows()):
            try:
                data = from_rdmol(mol)
                # errors="ignore": tolerate dataframes missing some metadata
                # columns instead of raising KeyError (which previously made
                # the broad except below silently skip every molecule).
                labels = row.drop(self._META_COLS, errors="ignore")
                # Coerce to numeric FIRST, then derive the validity mask from
                # the coerced values — otherwise non-numeric junk would be
                # masked as valid while its label silently became 0.0.
                numeric = pd.to_numeric(labels, errors="coerce")
                mask = numeric.notna()
                values = numeric.fillna(0.0).astype(float).values
                data.y = torch.tensor(values, dtype=torch.float).unsqueeze(0)
                data.mask = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
                data_list.append(data)
            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")
        # Collate graphs into the single storage layout InMemoryDataset uses.
        self.data, self.slices = self.collate(data_list)
def get_graph_dataset(filepath: str):
    """Load a Tox21 CSV from disk and wrap it in a ``Tox21Dataset``.

    Args:
        filepath (str): the filepath of the data csv

    Returns:
        Tox21Dataset: dataset for dataloaders
    """
    return Tox21Dataset(pd.read_csv(filepath))