from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import numpy as np
import pandas as pd
from datasets import load_dataset
def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Clean and canonicalize SMILES strings while staying in SMILES space.
    Returns (cleaned SMILES for the valid inputs only, boolean mask over the
    input list marking which entries were valid).
    """
    clean_smis = []
    valid_mask = []
    cleaner = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                valid_mask.append(False)
                continue
            # Cleanup and tautomer canonicalization
            mol = rdMolStandardize.Cleanup(mol, cleaner)
            mol = tautomer_enumerator.Canonicalize(mol)
            # Canonical SMILES output
            clean_smi = Chem.MolToSmiles(mol, canonical=True)
            clean_smis.append(clean_smi)
            valid_mask.append(True)
        except Exception as e:
            print(f"Failed to clean {smi}: {e}")
            valid_mask.append(False)
    return clean_smis, np.array(valid_mask, dtype=bool)
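
# Illustrative usage of create_clean_smiles (hypothetical inputs; the exact
# canonical SMILES produced depend on the installed RDKit version):
#   smis, mask = create_clean_smiles(["CCO", "not_a_smiles", "c1ccccc1O"])
#   smis holds cleaned SMILES for the valid inputs only, while mask is
#   [True, False, True], aligned with the original input list.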
def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: list[str] | None = None):
    """
    Reads a CSV, cleans SMILES, and saves only valid cleaned rows with all target columns to a new CSV.
    """
    # Load dataset
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")
    # Infer target columns if not specified
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]
    # Keep the split column when present, without duplicating it if it was
    # already inferred as a target column
    keep_cols = list(target_cols)
    if "split" in df.columns and "split" not in keep_cols:
        keep_cols.append("split")
    # Validate target columns
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")
    # Clean SMILES
    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())
    # Keep only valid rows
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # SMILES as the first column
    # Save cleaned dataset
    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
    return valid_mask
def get_tox21_split(token, cvfold=None):
    """
    Load the ml-jku/tox21 dataset and return its train/validation DataFrames.
    If cvfold is given, re-split the combined data so that rows with
    CVfold == cvfold form the validation set; molecules (by InChIKey) that
    occur in validation are removed from train.
    """
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df,
        }
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    # create new splits
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}
def get_combined_dataset_csv(token, save_path):
    """
    Build a single CSV with train/val/test splits from the Tox21 data,
    using CV fold 4 as validation and reusing it as the test split.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    # Reuse the validation fold as the test split (duplicated rows)
    test_df = val_df.copy()
    # Add split column
    train_df["split"] = "train"
    val_df["split"] = "val"
    test_df["split"] = "test"
    # Combine all into one DataFrame
    combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
    # Save to a new CSV
    combined_df.to_csv(save_path, index=False)
def get_combined_dataset_with_testset_csv(token, save_path, testset_path):
    """
    Same as get_combined_dataset_csv, but with the test split read from an
    external CSV instead of duplicating the validation fold.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    test_df = pd.read_csv(testset_path)
    # Add split column
    train_df["split"] = "train"
    val_df["split"] = "val"
    test_df["split"] = "test"
    # Combine all into one DataFrame
    combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
    # Save to a new CSV
    combined_df.to_csv(save_path, index=False)
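
# Minimal end-to-end sketch, assuming a valid Hugging Face token with access
# to ml-jku/tox21; the token and both file paths below are placeholders, not
# part of the original module.
if __name__ == "__main__":
    hf_token = "hf_..."  # placeholder token
    raw_csv = "tox21_combined.csv"          # hypothetical output path
    clean_csv = "tox21_combined_clean.csv"  # hypothetical output path

    # Build the combined train/val/test CSV, then standardize its SMILES
    # column and drop any rows RDKit cannot parse.
    get_combined_dataset_csv(hf_token, raw_csv)
    clean_smiles_in_csv(raw_csv, clean_csv, smiles_col="smiles")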