from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import numpy as np
import pandas as pd
from datasets import load_dataset


def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Clean and canonicalize SMILES strings while staying in SMILES space.

    Each input is parsed, passed through ``rdMolStandardize.Cleanup``,
    tautomer-canonicalized and written back as a canonical SMILES string.

    Args:
        smiles_list: Input SMILES strings.

    Returns:
        A tuple ``(clean_smis, valid_mask)``. ``valid_mask`` is a boolean
        array with one entry per input. ``clean_smis`` holds the cleaned
        SMILES of the valid inputs only, in input order, so its length equals
        ``valid_mask.sum()`` rather than ``len(smiles_list)``.
    """
    clean_smis = []
    valid_mask = []

    # Both objects are loop-invariant, so they are built once up front.
    cleaner = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()

    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # Unparseable input: mark invalid, emit no cleaned SMILES.
                valid_mask.append(False)
                continue

            # Cleanup and tautomer canonicalization
            mol = rdMolStandardize.Cleanup(mol, cleaner)
            mol = tautomer_enumerator.Canonicalize(mol)

            # Canonical SMILES output
            clean_smi = Chem.MolToSmiles(mol, canonical=True)
        except Exception as e:
            # Deliberately broad best-effort: a single molecule that fails
            # must not abort the whole batch; it is reported and skipped.
            print(f"Failed to clean {smi}: {e}")
            valid_mask.append(False)
        else:
            clean_smis.append(clean_smi)
            valid_mask.append(True)

    return clean_smis, np.array(valid_mask, dtype=bool)


def clean_smiles_in_csv(
    input_csv: str,
    output_csv: str,
    smiles_col: str = "smiles",
    target_cols: list[str] | None = None,
) -> np.ndarray:
    """
    Read a CSV, clean its SMILES, and save only the valid rows to a new CSV.

    The output has the cleaned SMILES as its first column, followed by the
    target columns and, when the input has one, the ``split`` column.

    Args:
        input_csv: Path of the CSV to read.
        output_csv: Path the cleaned CSV is written to (without the index).
        smiles_col: Name of the column holding the SMILES strings.
        target_cols: Columns to carry over. When ``None``, every column
            except ``smiles_col`` is used.

    Returns:
        Boolean mask over the rows of the input CSV; ``True`` marks rows
        whose SMILES could be cleaned and were therefore written out.

    Raises:
        ValueError: If ``smiles_col`` or any requested target column is not
            present in the CSV.
    """
    # Load dataset
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")

    # Infer target columns if not specified
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]

    # Validate target columns
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")

    # The SMILES column is re-inserted below from the cleaned values, so it
    # must not also be selected here (DataFrame.insert rejects duplicates).
    keep_cols = [c for c in target_cols if c != smiles_col]
    # Carry the split label along, but only when it exists and is not already
    # selected: inferred target columns include it, and appending it again
    # would write a duplicated "split" column.
    if "split" in df.columns and "split" not in keep_cols:
        keep_cols.append("split")

    # Clean SMILES
    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())

    # Keep only valid rows; clean_smis has exactly one entry per valid row.
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # smiles first column

    # Save cleaned dataset
    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
    return valid_mask


def get_tox21_split(token, cvfold=None):
    """
    Load the "ml-jku/tox21" dataset and return train/validation DataFrames.

    Args:
        token: Forwarded to ``load_dataset`` as its ``token`` argument.
        cvfold: When ``None``, the dataset's own "train" and "validation"
            splits are returned unchanged. Otherwise both splits are pooled
            and re-split on the ``CVfold`` column: rows whose ``CVfold``
            equals ``float(cvfold)`` become the validation set, all others
            the training set.

    Returns:
        Dict with the keys ``"train"`` and ``"validation"`` mapping to
        pandas DataFrames. In the re-split case both frames have a fresh
        0-based index.
    """
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {"train": train_df, "validation": val_df}

    combined_df = pd.concat([train_df, val_df], ignore_index=True)

    # create new splits
    cvfold = float(cvfold)
    in_val_fold = combined_df["CVfold"] == cvfold
    val_df = combined_df[in_val_fold]
    train_df = combined_df[~in_val_fold]

    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {
        "train": train_df.reset_index(drop=True),
        "validation": val_df.reset_index(drop=True),
    }


def _save_combined_splits(train_df, val_df, test_df, save_path):
    """Label each frame with its split name, stack them and write one CSV."""
    # assign() returns labelled copies, leaving the callers' frames untouched.
    combined_df = pd.concat(
        [
            train_df.assign(split="train"),
            val_df.assign(split="val"),
            test_df.assign(split="test"),
        ],
        ignore_index=True,
    )
    combined_df.to_csv(save_path, index=False)


def get_combined_dataset_csv(token, save_path, cvfold=4):
    """
    Write the train/val/test splits as one CSV with a ``split`` column.

    No separate test set is loaded here: the validation rows are written a
    second time, labelled "test".

    Args:
        token: Forwarded to ``get_tox21_split``.
        save_path: Destination of the combined CSV (written without index).
        cvfold: Fold used as validation set; forwarded to
            ``get_tox21_split``. Defaults to 4.
    """
    datasets = get_tox21_split(token, cvfold=cvfold)
    val_df = datasets["validation"]
    _save_combined_splits(datasets["train"], val_df, val_df, save_path)


def get_combined_dataset_with_testset_csv(token, save_path, testset_path, cvfold=4):
    """
    Write the train/val splits plus an external test set as one CSV.

    Args:
        token: Forwarded to ``get_tox21_split``.
        save_path: Destination of the combined CSV (written without index).
        testset_path: CSV file whose rows are labelled "test".
        cvfold: Fold used as validation set; forwarded to
            ``get_tox21_split``. Defaults to 4.
    """
    datasets = get_tox21_split(token, cvfold=cvfold)
    test_df = pd.read_csv(testset_path)
    _save_combined_splits(datasets["train"], datasets["validation"], test_df, save_path)