Sonja Topf commited on
Commit
4efd766
·
1 Parent(s): ad23c3f
Files changed (2) hide show
  1. .gitignore +2 -1
  2. predict.py +14 -12
.gitignore CHANGED
@@ -3,4 +3,5 @@ results.csv
3
  predict copy.py
4
  debug.py
5
  __pycache__
6
- tox21_test.csv
 
 
3
  predict copy.py
4
  debug.py
5
  __pycache__
6
+ tox21_test.csv
7
+ predictions.json
predict.py CHANGED
@@ -3,6 +3,7 @@ import csv
3
  import subprocess
4
  import pandas as pd
5
  import logging
 
6
 
7
  from src.preprocess import create_clean_smiles
8
 
@@ -22,9 +23,17 @@ def predict(smiles_list):
22
  clean_smiles, valid_mask = create_clean_smiles(smiles_list)
23
 
24
  # Mapping from cleaned to original for valid ones
25
- cleaned_to_original = {
26
- clean: orig for clean, orig, valid in zip(clean_smiles, smiles_list, valid_mask) if valid
27
- }
 
 
 
 
 
 
 
 
28
 
29
  # tox21 targets
30
  TARGET_NAMES = [
@@ -59,17 +68,10 @@ def predict(smiles_list):
59
  predictions = {}
60
  with open("./src/preds.csv", "r", newline="") as f:
61
  reader = csv.DictReader(f)
 
62
  target_names = [col for col in reader.fieldnames if col != "smiles"]
63
 
64
- missing = [t for t in TARGET_NAMES if t not in target_names]
65
- extra = [t for t in target_names if t not in TARGET_NAMES]
66
-
67
- if missing:
68
- logging.error(f"❌ Missing target columns in preds.csv: {missing}")
69
- if extra:
70
- logging.warning(f"⚠ Warning: Extra columns in preds.csv not expected: {extra}")
71
-
72
- for row in reader:
73
  clean_smi = row["smiles"]
74
  original_smi = cleaned_to_original.get(clean_smi, clean_smi)
75
  pred_dict = {t: float(row[t]) for t in target_names}
 
3
  import subprocess
4
  import pandas as pd
5
  import logging
6
+ import json
7
 
8
  from src.preprocess import create_clean_smiles
9
 
 
23
  clean_smiles, valid_mask = create_clean_smiles(smiles_list)
24
 
25
  # Mapping from cleaned to original for valid ones
26
+ originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
27
+
28
+ # sanity check (optional but nice to have)
29
+ if len(originals_valid) != len(clean_smiles):
30
+ raise ValueError(
31
+ f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
32
+ )
33
+
34
+ # map cleaned → original
35
+ cleaned_to_original = dict(zip(clean_smiles, originals_valid))
36
+ print(len(cleaned_to_original.keys()))
37
 
38
  # tox21 targets
39
  TARGET_NAMES = [
 
68
  predictions = {}
69
  with open("./src/preds.csv", "r", newline="") as f:
70
  reader = csv.DictReader(f)
71
+ rows = list(reader)
72
  target_names = [col for col in reader.fieldnames if col != "smiles"]
73
 
74
+ for row in rows:
 
 
 
 
 
 
 
 
75
  clean_smi = row["smiles"]
76
  original_smi = cleaned_to_original.get(clean_smi, clean_smi)
77
  pred_dict = {t: float(row[t]) for t in target_names}