import csv
import os
import subprocess

import torch  # not referenced below; kept because other code may rely on it

from src.preprocess import create_clean_smiles

# Tox21 target names; used to build placeholder entries for invalid SMILES.
TARGET_NAMES = (
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
)

# Paths are relative to the current working directory, as in the CLI call.
_DATA_DIR = "data"
_SMILES_CSV = "data/smiles.csv"
_PREDS_CSV = "data/preds.csv"
_MODEL_PATH = "checkpoints/model.pt"

# Probability reported for SMILES that could not be cleaned.
_PLACEHOLDER_PROB = 0.5


def _write_smiles_csv(clean_smiles):
    """Write the cleaned SMILES to the chemprop input CSV (one per row)."""
    os.makedirs(_DATA_DIR, exist_ok=True)
    with open(_SMILES_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"])  # header
        writer.writerows([smi] for smi in clean_smiles)


def _run_chemprop():
    """Invoke the chemprop CLI; raises CalledProcessError on a non-zero exit."""
    command = [
        "chemprop", "predict",
        "--test-path", _SMILES_CSV,
        "--model-path", _MODEL_PATH,
        "--smiles-columns", "smiles",
        "--preds-path", _PREDS_CSV,
    ]
    subprocess.run(command, check=True)


def _read_prediction_rows():
    """Return a list of (cleaned_smiles, {target: prob}) from the preds CSV.

    Assumes the predictions file has a ``smiles`` column and that every other
    column holds a numeric prediction -- TODO confirm against the chemprop
    version in use.
    """
    with open(_PREDS_CSV, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # fieldnames is None when the file is completely empty.
        target_names = [col for col in (reader.fieldnames or []) if col != "smiles"]
    return [
        (row["smiles"], {t: float(row[t]) for t in target_names})
        for row in rows
    ]


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The valid SMILES are cleaned, written to ``data/smiles.csv``, scored by
    the ``chemprop predict`` command line tool, and the results are read back
    from ``data/preds.csv``. SMILES flagged as invalid by
    ``create_clean_smiles`` receive a placeholder probability of 0.5 for every
    entry of ``TARGET_NAMES``.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}, keyed by the original
        (un-cleaned) input strings.

    Raises:
        ValueError: if the number of cleaned SMILES does not match the number
            of inputs flagged as valid.
        subprocess.CalledProcessError: if the chemprop command fails.
    """
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)

    # Originals that survived cleaning, in the same order as clean_smiles.
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]

    # sanity check: the two lists are zipped below and must line up
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    # Several distinct inputs may clean to the same SMILES, so keep every
    # original per cleaned string instead of letting the last one win.
    cleaned_to_originals = {}
    for cleaned, original in zip(clean_smiles, originals_valid):
        cleaned_to_originals.setdefault(cleaned, []).append(original)

    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}

    # Nothing to score: skip the external call rather than feeding chemprop
    # a header-only file.
    if clean_smiles:
        _write_smiles_csv(clean_smiles)
        _run_chemprop()
        for clean_smi, pred_dict in _read_prediction_rows():
            # Fall back to the cleaned string if it is not a known key.
            for original_smi in cleaned_to_originals.get(clean_smi, [clean_smi]):
                # Copy so callers mutating one entry do not affect another.
                predictions[original_smi] = dict(pred_dict)

    # Add placeholder predictions for invalid SMILES
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: _PLACEHOLDER_PROB for t in TARGET_NAMES}

    return predictions