# NOTE(review): this header was non-Python residue from a Hugging Face file-viewer
# page (commit hashes f484830/f0bc9a8, file size 4,210 bytes, gutter line numbers).
# Replaced with this comment so the module parses.
import numpy as np
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import Chem
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol
from datasets import load_dataset
def get_tox21_split(token, cvfold=None):
    """Load the Tox21 train/validation split from the Hugging Face hub.

    Args:
        token: Hugging Face access token forwarded to ``load_dataset``.
        cvfold: if given, re-split the combined data by cross-validation
            fold: rows whose ``CVfold`` equals this value become the
            validation set, everything else the training set. ``None``
            keeps the hub's original train/validation split.

    Returns:
        dict with ``"train"`` and ``"validation"`` pandas DataFrames.
    """
    ds = load_dataset("tschouis/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df
        }
    # Re-split the pooled data by the requested CV fold.
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    # Convert once (the original converted twice); CVfold is stored as float.
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # Exclude train molecules that also occur in the validation split
    # (same InChIKey ⇒ same molecule, avoids leakage).
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}
def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Standardize SMILES strings into cleaned RDKit Mol objects.

    Each SMILES is parsed, run through RDKit's standardizer (cleanup +
    tautomer canonicalization), then round-tripped through canonical
    SMILES to verify it still parses.

    Args:
        smiles: SMILES strings to standardize.

    Returns:
        Tuple of (list of successfully cleaned mols, boolean mask over the
        input marking which entries survived). ``mask.sum() == len(mols)``.
    """
    params = rdMolStandardize.CleanupParameters()
    enumerator = rdMolStandardize.TautomerEnumerator()
    mols: list[Chem.Mol] = []
    mask: list[bool] = []
    for smi in smiles:
        kept = False
        try:
            candidate = Chem.MolFromSmiles(smi)
            if candidate is not None:
                # Cleanup, canonical tautomer, then round-trip through
                # canonical SMILES as a final sanity check.
                candidate = rdMolStandardize.Cleanup(candidate, params)
                candidate = enumerator.Canonicalize(candidate)
                roundtrip = Chem.MolFromSmiles(Chem.MolToSmiles(candidate))
                if roundtrip is not None:
                    mols.append(roundtrip)
                    kept = True
        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
        mask.append(kept)
    return mols, np.array(mask, dtype=bool)
class Tox21Dataset(InMemoryDataset):
    """In-memory PyG dataset built from a Tox21 pandas DataFrame.

    Each row's SMILES is standardized, converted to a graph via
    ``from_rdmol``, and annotated with a label tensor ``y`` and a boolean
    ``mask`` marking which assay labels were actually measured (non-NaN).
    """

    # Non-label columns stripped from each row before label extraction.
    _META_COLUMNS = ["ID","smiles","inchikey","sdftitle","order","set","CVfold"]

    def __init__(self, dataframe):
        super().__init__()
        # Standardize molecules; drop rows whose SMILES failed cleaning so
        # that mols and the surviving dataframe rows line up one-to-one.
        mols, keep = create_clean_mol_objects(dataframe["smiles"].tolist())
        frame = dataframe[keep].reset_index(drop=True)
        graphs = []
        for mol, (_, row) in zip(mols, frame.iterrows()):
            try:
                graph = from_rdmol(mol)
                labels = row.drop(self._META_COLUMNS)
                # True where a label was actually measured (non-NaN).
                measured = ~labels.isna()
                # Numeric conversion; unmeasured slots become 0.0 (masked out).
                values = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values
                graph.y = torch.tensor(values, dtype=torch.float).unsqueeze(0)
                graph.mask = torch.tensor(measured.values, dtype=torch.bool).unsqueeze(0)
                graphs.append(graph)
            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")
        # Collate the per-molecule graphs into the InMemoryDataset storage.
        self.data, self.slices = self.collate(graphs)
def get_graph_datasets(token, cvfold=4):
    """Build train/validation graph datasets usable with PyG dataloaders.

    Args:
        token: Hugging Face access token forwarded to ``get_tox21_split``.
        cvfold: cross-validation fold held out for validation
            (default 4, preserving the previous hard-coded behavior).

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: (train_dataset, val_dataset)
    """
    # Previous docstring documented a nonexistent ``filepath`` argument,
    # and the line carried a trailing '|' scrape artifact — both fixed.
    splits = get_tox21_split(token, cvfold=cvfold)
    train_dataset = Tox21Dataset(splits["train"])
    val_dataset = Tox21Dataset(splits["validation"])
    return train_dataset, val_dataset