from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
import numpy as np
import pandas as pd
from datasets import load_dataset

def create_clean_smiles(smiles_list: list[str]) -> tuple[list[str], np.ndarray]:
    """
    Clean and canonicalize SMILES strings while staying in SMILES space.
    Returns (list of cleaned SMILES, mask of valid SMILES).
    """
    clean_smis = []
    valid_mask = []

    cleaner = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()

    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                valid_mask.append(False)
                continue

            # Cleanup and tautomer canonicalization
            mol = rdMolStandardize.Cleanup(mol, cleaner)
            mol = tautomer_enumerator.Canonicalize(mol)

            # Canonical SMILES output
            clean_smi = Chem.MolToSmiles(mol, canonical=True)
            clean_smis.append(clean_smi)
            valid_mask.append(True)

        except Exception as e:
            print(f"Failed to clean {smi}: {e}")
            valid_mask.append(False)

    return clean_smis, np.array(valid_mask, dtype=bool)
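
# A minimal sanity check for create_clean_smiles (illustrative only; the
# example SMILES are arbitrary and not part of this repo):
#
#   cleaned, mask = create_clean_smiles(["CCO", "c1ccccc1O", "not_a_smiles"])
#   # cleaned holds the two standardized molecules; mask -> [True, True, False]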


def clean_smiles_in_csv(input_csv: str, output_csv: str, smiles_col: str = "smiles", target_cols: list[str] | None = None):
    """
    Reads a CSV, cleans SMILES, and saves only valid cleaned rows with all target columns to a new CSV.
    """
    # Load dataset
    df = pd.read_csv(input_csv)
    if smiles_col not in df.columns:
        raise ValueError(f"'{smiles_col}' column not found in CSV.")

    # Infer target columns if not specified
    if target_cols is None:
        target_cols = [c for c in df.columns if c != smiles_col]

    # Validate target columns
    missing_targets = [c for c in target_cols if c not in df.columns]
    if missing_targets:
        raise ValueError(f"Missing target columns in CSV: {missing_targets}")

    # Also keep the split column, if present and not already requested
    keep_cols = list(target_cols)
    if "split" in df.columns and "split" not in keep_cols:
        keep_cols.append("split")

    # Clean SMILES
    clean_smis, valid_mask = create_clean_smiles(df[smiles_col].tolist())

    # Keep only valid rows
    df_clean = df.loc[valid_mask, keep_cols].copy()
    df_clean.insert(0, smiles_col, clean_smis)  # smiles first column

    # Save cleaned dataset
    df_clean.to_csv(output_csv, index=False)
    print(f"✅ Cleaned dataset saved to '{output_csv}' ({len(df_clean)} valid molecules).")
    return valid_mask


def get_tox21_split(token, cvfold=None):
    """
    Load the ml-jku/tox21 dataset from the Hugging Face Hub. With cvfold=None,
    return the original train/validation splits; otherwise merge them and
    re-split on the CVfold column, holding out the given fold for validation.
    """
    ds = load_dataset("ml-jku/tox21", token=token)
    
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df
        }
    
    combined_df = pd.concat([train_df, val_df], ignore_index=True)

    # Create new splits by holding out the requested CV fold
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]

    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}


def get_combined_dataset_csv(token, save_path):
    """
    Build one CSV containing train/val/test splits from CV fold 4. No
    separate test set is used here, so the validation fold is duplicated
    as the test split.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    test_df = val_df.copy()

    # Label each split
    train_df["split"] = "train"
    val_df["split"] = "val"
    test_df["split"] = "test"

    # Combine all into one DataFrame
    combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

    # Save to a new CSV
    combined_df.to_csv(save_path, index=False)

def get_combined_dataset_with_testset_csv(token, save_path, testset_path):
    """
    Build one CSV containing train/val splits from CV fold 4 plus an
    external test set loaded from testset_path.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    test_df = pd.read_csv(testset_path)

    # Label each split
    train_df["split"] = "train"
    val_df["split"] = "val"
    test_df["split"] = "test"

    # Combine all into one DataFrame
    combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
    
    # Save to a new CSV
    combined_df.to_csv(save_path, index=False)
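

# Usage sketch, assuming a Hugging Face token in the HF_TOKEN environment
# variable; the output file names are placeholders, not paths from this repo:
if __name__ == "__main__":
    import os

    token = os.environ.get("HF_TOKEN")

    # Write the combined raw splits, then standardize the SMILES column and
    # drop molecules RDKit cannot parse.
    get_combined_dataset_csv(token, "tox21_raw.csv")
    clean_smiles_in_csv("tox21_raw.csv", "tox21_clean.csv", smiles_col="smiles")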