Spaces:
Sleeping
Sleeping
Sonja Topf
commited on
Commit
·
4efd766
1
Parent(s):
ad23c3f
fixed bug
Browse files- .gitignore +2 -1
- predict.py +14 -12
.gitignore
CHANGED
|
@@ -3,4 +3,5 @@ results.csv
|
|
| 3 |
predict copy.py
|
| 4 |
debug.py
|
| 5 |
__pycache__
|
| 6 |
-
tox21_test.csv
|
|
|
|
|
|
| 3 |
predict copy.py
|
| 4 |
debug.py
|
| 5 |
__pycache__
|
| 6 |
+
tox21_test.csv
|
| 7 |
+
predictions.json
|
predict.py
CHANGED
|
@@ -3,6 +3,7 @@ import csv
|
|
| 3 |
import subprocess
|
| 4 |
import pandas as pd
|
| 5 |
import logging
|
|
|
|
| 6 |
|
| 7 |
from src.preprocess import create_clean_smiles
|
| 8 |
|
|
@@ -22,9 +23,17 @@ def predict(smiles_list):
|
|
| 22 |
clean_smiles, valid_mask = create_clean_smiles(smiles_list)
|
| 23 |
|
| 24 |
# Mapping from cleaned to original for valid ones
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# tox21 targets
|
| 30 |
TARGET_NAMES = [
|
|
@@ -59,17 +68,10 @@ def predict(smiles_list):
|
|
| 59 |
predictions = {}
|
| 60 |
with open("./src/preds.csv", "r", newline="") as f:
|
| 61 |
reader = csv.DictReader(f)
|
|
|
|
| 62 |
target_names = [col for col in reader.fieldnames if col != "smiles"]
|
| 63 |
|
| 64 |
-
|
| 65 |
-
extra = [t for t in target_names if t not in TARGET_NAMES]
|
| 66 |
-
|
| 67 |
-
if missing:
|
| 68 |
-
logging.error(f"❌ Missing target columns in preds.csv: {missing}")
|
| 69 |
-
if extra:
|
| 70 |
-
logging.warning(f"⚠ Warning: Extra columns in preds.csv not expected: {extra}")
|
| 71 |
-
|
| 72 |
-
for row in reader:
|
| 73 |
clean_smi = row["smiles"]
|
| 74 |
original_smi = cleaned_to_original.get(clean_smi, clean_smi)
|
| 75 |
pred_dict = {t: float(row[t]) for t in target_names}
|
|
|
|
| 3 |
import subprocess
|
| 4 |
import pandas as pd
|
| 5 |
import logging
|
| 6 |
+
import json
|
| 7 |
|
| 8 |
from src.preprocess import create_clean_smiles
|
| 9 |
|
|
|
|
| 23 |
clean_smiles, valid_mask = create_clean_smiles(smiles_list)
|
| 24 |
|
| 25 |
# Mapping from cleaned to original for valid ones
|
| 26 |
+
originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
|
| 27 |
+
|
| 28 |
+
# sanity check (optional but nice to have)
|
| 29 |
+
if len(originals_valid) != len(clean_smiles):
|
| 30 |
+
raise ValueError(
|
| 31 |
+
f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# map cleaned → original
|
| 35 |
+
cleaned_to_original = dict(zip(clean_smiles, originals_valid))
|
| 36 |
+
print(len(cleaned_to_original.keys()))
|
| 37 |
|
| 38 |
# tox21 targets
|
| 39 |
TARGET_NAMES = [
|
|
|
|
| 68 |
predictions = {}
|
| 69 |
with open("./src/preds.csv", "r", newline="") as f:
|
| 70 |
reader = csv.DictReader(f)
|
| 71 |
+
rows = list(reader)
|
| 72 |
target_names = [col for col in reader.fieldnames if col != "smiles"]
|
| 73 |
|
| 74 |
+
for row in rows:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
clean_smi = row["smiles"]
|
| 76 |
original_smi = cleaned_to_original.get(clean_smi, clean_smi)
|
| 77 |
pred_dict = {t: float(row[t]) for t in target_names}
|