import numpy as np
import pandas as pd
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize


def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Clean and canonicalize SMILES strings while staying in SMILES space.
    Returns (list of cleaned SMILES, boolean mask of valid SMILES). The list
    contains only the valid entries, aligned with the True positions of the mask.
    """
    clean_smis = []
    valid_mask = []
    # Standardization helpers: cleanup parameters and tautomer canonicalizer
    cleanup_params = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                valid_mask.append(False)
                continue
            # Standardize the molecule, then pick its canonical tautomer
            mol = rdMolStandardize.Cleanup(mol, cleanup_params)
            mol = tautomer_enumerator.Canonicalize(mol)
            # Canonical SMILES output
            clean_smis.append(Chem.MolToSmiles(mol, canonical=True))
            valid_mask.append(True)
        except Exception as e:
            print(f"Failed to clean {smi}: {e}")
            valid_mask.append(False)
    return clean_smis, np.array(valid_mask, dtype=bool)
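
# Illustrative sketch of how create_clean_smiles behaves. The inputs below are
# arbitrary examples (not from this project's data), and the exact canonical
# forms can vary with the RDKit version:
#
#     smis, mask = create_clean_smiles(["CCO", "not-a-smiles", "Oc1ccccc1"])
#     # smis -> ['CCO', 'Oc1ccccc1']      (only valid, canonicalized entries)
#     # mask -> array([ True, False,  True])
#     # Note: len(smis) == mask.sum(), so smis lines up with the True positions.
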
def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: list[str] | None = None):
    """
    Read a CSV, clean its SMILES column, and save only the valid cleaned rows
    (with all target columns and the split column) to a new CSV.
    """
    # Load dataset
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")
    # Infer target columns if not specified
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]
    # Validate target columns
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")
    # Keep the split column alongside the targets without duplicating it
    keep_cols = list(target_cols)
    if "split" in df.columns and "split" not in keep_cols:
        keep_cols.append("split")
    # Clean SMILES
    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())
    # Keep only valid rows; clean_smis is already aligned with those rows
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # cleaned SMILES as first column
    # Save cleaned dataset
    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
def get_tox21_split(token, cvfold=None):
    """
    Load the tschouis/tox21 dataset and return its train/validation DataFrames.
    If cvfold is given, re-split the combined data by the CVfold column instead.
    """
    ds = load_dataset("tschouis/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {"train": train_df, "validation": val_df}
    # Create new splits from the requested cross-validation fold
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # Exclude train molecules that also occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}
def get_combined_dataset_csv(token, save_path):
    """
    Build a single CSV containing train/val/test splits from the Tox21
    CV fold 4 re-split. The test split is a copy of the validation split.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    test_df = val_df.copy()  # test split duplicates validation
    # Add split column
    train_df["split"] = "train"
    val_df["split"] = "val"
    test_df["split"] = "test"
    # Combine all into one DataFrame
    combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
    # Save to a new CSV
    combined_df.to_csv(save_path, index=False)
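
# Minimal end-to-end driver (a sketch, not part of the original pipeline: the
# HF_TOKEN environment variable, the intermediate file names, and the "smiles"
# column name are all assumptions):
if __name__ == "__main__":
    import os

    token = os.environ.get("HF_TOKEN")
    raw_csv = "tox21_combined.csv"          # hypothetical intermediate path
    clean_csv = "tox21_combined_clean.csv"  # hypothetical output path
    get_combined_dataset_csv(token, raw_csv)
    clean_smiles_in_csv(raw_csv, clean_csv, smiles_col="smiles")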