Sonja Topf committed on
Commit
338b5f6
·
1 Parent(s): 4efd766

removed logging

Browse files
Files changed (1) hide show
  1. predict.py +57 -63
predict.py CHANGED
@@ -17,73 +17,67 @@ def predict(smiles_list):
17
  Returns:
18
  dict: {smiles: {target_name: prediction_prob}}
19
  """
20
- logging.basicConfig(level=logging.INFO)
21
- try:
22
- # clean smiles
23
- clean_smiles, valid_mask = create_clean_smiles(smiles_list)
24
-
25
- # Mapping from cleaned to original for valid ones
26
- originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
27
-
28
- # sanity check (optional but nice to have)
29
- if len(originals_valid) != len(clean_smiles):
30
- raise ValueError(
31
- f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
32
- )
33
-
34
- # map cleaned → original
35
- cleaned_to_original = dict(zip(clean_smiles, originals_valid))
36
- print(len(cleaned_to_original.keys()))
37
-
38
- # tox21 targets
39
- TARGET_NAMES = [
40
- "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
41
- ]
42
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
- print(f"Received {len(smiles_list)} SMILES strings")
44
-
45
- # put smiles into csv
46
- with open("./src/smiles.csv", "w", newline="") as f:
47
- writer = csv.writer(f)
48
- writer.writerow(["smiles"]) # header
49
- for smi in clean_smiles:
50
- writer.writerow([smi])
51
- logging.info("here")
52
- # predict
53
- command = [
54
- "chemprop", "predict",
55
- "--test-path", "src/smiles.csv",
56
- "--model-path", "assets/best1.pt",
57
- "--smiles-columns", "smiles",
58
- "--preds-path", "src/preds.csv"
59
- ]
60
-
61
- # Run the command
62
- subprocess.run(command, check=True)
63
 
64
- # create results dictionary from predictions
 
65
 
66
- csv_path = "./src/preds.csv"
 
 
 
 
67
 
68
- predictions = {}
69
- with open("./src/preds.csv", "r", newline="") as f:
70
- reader = csv.DictReader(f)
71
- rows = list(reader)
72
- target_names = [col for col in reader.fieldnames if col != "smiles"]
73
 
74
- for row in rows:
75
- clean_smi = row["smiles"]
76
- original_smi = cleaned_to_original.get(clean_smi, clean_smi)
77
- pred_dict = {t: float(row[t]) for t in target_names}
78
- predictions[original_smi] = pred_dict
79
-
80
- # Add placeholder predictions for invalid SMILES
81
- for smi, is_valid in zip(smiles_list, valid_mask):
82
- if not is_valid:
83
- predictions[smi] = {t: 0.0 for t in TARGET_NAMES}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- except Exception as e:
86
- logging.error(f"Error: {e}")
87
- return
88
 
89
  return predictions
 
17
  Returns:
18
  dict: {smiles: {target_name: prediction_prob}}
19
  """
20
+ # clean smiles
21
+ clean_smiles, valid_mask = create_clean_smiles(smiles_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Mapping from cleaned to original for valid ones
24
+ originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
25
 
26
+ # sanity check (optional but nice to have)
27
+ if len(originals_valid) != len(clean_smiles):
28
+ raise ValueError(
29
+ f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
30
+ )
31
 
32
+ # map cleaned → original
33
+ cleaned_to_original = dict(zip(clean_smiles, originals_valid))
 
 
 
34
 
35
+ # tox21 targets
36
+ TARGET_NAMES = [
37
+ "NR-AhR","NR-AR","NR-AR-LBD","NR-Aromatase","NR-ER","NR-ER-LBD","NR-PPAR-gamma","SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
38
+ ]
39
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ print(f"Received {len(smiles_list)} SMILES strings")
41
+
42
+ # put smiles into csv
43
+ with open("./src/smiles.csv", "w", newline="") as f:
44
+ writer = csv.writer(f)
45
+ writer.writerow(["smiles"]) # header
46
+ for smi in clean_smiles:
47
+ writer.writerow([smi])
48
+ logging.info("here")
49
+ # predict
50
+ command = [
51
+ "chemprop", "predict",
52
+ "--test-path", "src/smiles.csv",
53
+ "--model-path", "assets/best1.pt",
54
+ "--smiles-columns", "smiles",
55
+ "--preds-path", "src/preds.csv"
56
+ ]
57
+
58
+ # Run the command
59
+ subprocess.run(command, check=True)
60
+
61
+ # create results dictionary from predictions
62
+
63
+ csv_path = "./src/preds.csv"
64
+
65
+ predictions = {}
66
+ with open("./src/preds.csv", "r", newline="") as f:
67
+ reader = csv.DictReader(f)
68
+ rows = list(reader)
69
+ target_names = [col for col in reader.fieldnames if col != "smiles"]
70
+
71
+ for row in rows:
72
+ clean_smi = row["smiles"]
73
+ original_smi = cleaned_to_original.get(clean_smi, clean_smi)
74
+ pred_dict = {t: float(row[t]) for t in target_names}
75
+ predictions[original_smi] = pred_dict
76
+
77
+ # Add placeholder predictions for invalid SMILES
78
+ for smi, is_valid in zip(smiles_list, valid_mask):
79
+ if not is_valid:
80
+ predictions[smi] = {t: 0.0 for t in TARGET_NAMES}
81
 
 
 
 
82
 
83
  return predictions