Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import json | |
| import os | |
| from gradio_molecule3d import Molecule3D # Import Molecule3D component | |
| from rdkit import Chem | |
| from rdkit.Chem import Descriptors, Draw, QED | |
class VirtualScreeningBOApp:
    """Gradio app that collects pairwise ligand preferences over several rounds.

    Each round ("BO iteration") presents a list of ligand pairs. For each pair the
    user sees 2D depictions and RDKit descriptors of both ligands, plus a 3D view
    of a fixed protein, and records which ligand they prefer. Once every pair of
    the round has been answered the user can advance: the results are written to
    ``comparison_results_iter_<k>.json`` and a fresh set of random pairs is drawn.
    After ``max_iterations`` rounds a completion screen is shown.

    NOTE(review): new pairs are drawn uniformly at random by
    ``_generate_new_pairs``; no Bayesian-optimisation model is visible here.
    NOTE(review): all state lives on this single instance, so every browser
    session served by the same process shares one set of answers.
    """

    # Properties-table column layouts. The names after "Ligand" double as the
    # keys of the dict returned by ``compute_properties``.
    _HEADERS_WITH_SMILES = ["Ligand", "SMILES", "MW", "LogP", "TPSA", "QED"]
    _HEADERS_NO_SMILES = ["Ligand", "MW", "LogP", "TPSA", "QED"]

    # HTML shown once the final round has been completed.
    _FINAL_MESSAGE = """
<div style="text-align: center; font-size: 24px; margin-top: 50px;">
<strong>The BO process has been completed.</strong><br>
Thank you for your input!
</div>
"""

    def __init__(
        self,
        ligands,
        initial_pairs,
        protein_pdb_path,  # Hardcoded PDB file path
        max_iterations=3,
        comparisons_per_iteration=2,
        show_smiles=True,
    ):
        """Store the configuration and read the protein structure file.

        Args:
            ligands: mapping of ligand ID -> SMILES string.
            initial_pairs: iterable of ``(ligand_id_a, ligand_id_b)`` pairs shown
                in the first round (``None`` is treated as "no pairs").
            protein_pdb_path: path of the PDB file shown in the 3D viewer.
            max_iterations: number of rounds before the completion screen.
            comparisons_per_iteration: number of pairs drawn for each later round.
            show_smiles: whether the properties table includes a SMILES column.
        """
        self.ligands = ligands
        # Normalise to tuples so pairs can be used as dict keys in completed_pairs.
        self.current_pairs = [tuple(pair) for pair in (initial_pairs or [])]
        # pair tuple -> 1 if Ligand A was preferred, else 0 (current round only).
        self.completed_pairs = {}
        # Every submitted preference across all rounds, in submission order.
        self.comparison_results = []
        self.bo_iteration = 0  # zero-based index of the current round
        self.is_completed = False
        self.max_iterations = max_iterations
        self.comparisons_per_iteration = comparisons_per_iteration
        self.protein_pdb_path = protein_pdb_path
        self.protein_pdb_data = self._read_pdb_file()  # None if unreadable
        self.show_smiles = show_smiles
        # Assigned by build_app(); initialised so it always exists.
        self.protein_view = None
        self.app = None

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _read_pdb_file(self):
        """Return the text of ``self.protein_pdb_path``, or None if it cannot be read."""
        try:
            with open(self.protein_pdb_path, "r") as f:
                return f.read()
        except FileNotFoundError:
            print(f"Error: Protein PDB file not found at {self.protein_pdb_path}")
        except OSError as err:
            # e.g. a directory or a permission problem: degrade to "no protein"
            # instead of failing construction.
            print(f"Error: could not read protein PDB file {self.protein_pdb_path}: {err}")
        return None

    def _iteration_status(self):
        """Return a string like '**Iteration**: 1/3' (1-based for user display)."""
        return f"**Iteration**: {self.bo_iteration + 1}/{self.max_iterations}"

    def compute_properties(self, smiles):
        """Compute basic RDKit descriptors for ``smiles``.

        Returns a dict with keys SMILES, MW, LogP, TPSA and QED; numeric values
        are rounded to 2 decimals. When the SMILES is missing or cannot be
        parsed, the numeric values are None and SMILES is "N/A" or
        "Invalid SMILES" respectively.
        """
        empty = {"MW": None, "LogP": None, "TPSA": None, "QED": None}
        # Chem is imported unconditionally at file level; the check is defensive.
        if Chem is None or smiles is None:
            return {"SMILES": smiles if smiles else "N/A", **empty}
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {"SMILES": "Invalid SMILES", **empty}
        return {
            "SMILES": smiles,
            "MW": round(Descriptors.MolWt(mol), 2),
            "LogP": round(Descriptors.MolLogP(mol), 2),
            "TPSA": round(Descriptors.TPSA(mol), 2),
            "QED": round(QED.qed(mol), 2),
        }

    def _mol_to_image(self, ligand_name):
        """Return a 300x300 RGB array depicting the ligand, or None if it can't be drawn."""
        if Chem is None:
            return None
        smiles = self.ligands.get(ligand_name, "")
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return np.array(Draw.MolToImage(mol, size=(300, 300)))

    def _properties_table(self, propsA, propsB):
        """Return ``(headers, rows)`` for the two-row properties table.

        The SMILES column is included only when ``self.show_smiles`` is set.
        """
        if self.show_smiles:
            headers = list(self._HEADERS_WITH_SMILES)
        else:
            headers = list(self._HEADERS_NO_SMILES)
        columns = headers[1:]  # header names are the property keys
        rows = [
            ["Ligand A"] + [propsA[c] for c in columns],
            ["Ligand B"] + [propsB[c] for c in columns],
        ]
        return headers, rows

    def _empty_table_update(self):
        """Return a Dataframe update that clears the properties table."""
        return gr.update(value=[], headers=list(self._HEADERS_NO_SMILES))

    def _render_pair(self, pair):
        """Return ``(image_a, image_b, table_update)`` for a ``(id_a, id_b)`` pair."""
        ligandA_id, ligandB_id = pair
        headers, rows = self._properties_table(
            self.compute_properties(self.ligands[ligandA_id]),
            self.compute_properties(self.ligands[ligandB_id]),
        )
        return (
            self._mol_to_image(ligandA_id),
            self._mol_to_image(ligandB_id),
            gr.update(value=rows, headers=headers),
        )

    def _pair_labels(self):
        """Return dropdown labels 'Pair N ✔' / 'Pair N (Pending)' for the current pairs."""
        return [
            f"Pair {i + 1} ✔" if pair in self.completed_pairs else f"Pair {i + 1} (Pending)"
            for i, pair in enumerate(self.current_pairs)
        ]

    def _arrow(self, pair):
        """Return '>' / '<' for an answered pair (A preferred / B preferred), else 'vs'."""
        if pair not in self.completed_pairs:
            return "vs"
        return ">" if self.completed_pairs[pair] == 1 else "<"

    def _all_pairs_completed(self):
        """True when every current pair has a recorded preference.

        Membership is checked per pair rather than by comparing counts, so a
        pair listed twice cannot block progress.
        """
        return all(pair in self.completed_pairs for pair in self.current_pairs)

    def _generate_new_pairs(self, n):
        """Randomly pick up to ``n`` distinct pairs of different ligands.

        Each unordered pair appears at most once per round. When fewer than
        ``n`` distinct pairs exist, all of them are returned. Fewer than two
        ligands yields ``[]`` rather than looping forever. The A/B order inside
        each pair is random.
        """
        keys = list(self.ligands.keys())
        if len(keys) < 2 or n <= 0:
            return []
        target = min(n, len(keys) * (len(keys) - 1) // 2)
        seen = set()
        pairs = []
        # Rejection sampling keeps memory O(n) even for large ligand libraries.
        while len(pairs) < target:
            a, b = random.sample(keys, 2)
            key = frozenset((a, b))
            if key not in seen:
                seen.add(key)
                pairs.append((a, b))
        return pairs

    def _save_results(self):
        """Write all preferences collected so far to a per-iteration JSON file."""
        filename = f"comparison_results_iter_{self.bo_iteration}.json"
        with open(filename, "w") as f:
            json.dump(self.comparison_results, f, indent=4)
        print(f"Results of iteration {self.bo_iteration} saved to {filename}")

    def get_pair_index(self, pair_label):
        """Parse 'Pair X (Pending)' or 'Pair X ✔' => integer X-1 (zero-based).

        Returns -1 for None, empty, or unparseable labels.
        """
        if not pair_label:
            return -1
        try:
            return int(pair_label.split()[1]) - 1
        except (IndexError, ValueError):
            return -1

    def _completion_updates(self):
        """Mark the process finished and return the completion-screen updates.

        The keys name the UI element each update targets. Each handler picks
        the subset that matches its own Gradio ``outputs`` list.
        """
        self.is_completed = True
        return {
            "iteration": "",
            "message": self._FINAL_MESSAGE,
            "img_a": gr.update(visible=False),
            "img_b": gr.update(visible=False),
            "table": gr.update(visible=False),
            "dropdown": gr.update(visible=False),
            "dropdown_value": None,
            "bo_btn": gr.update(visible=False),
            "submit_btn": gr.update(visible=False),
            "radio": gr.update(visible=False),
            "protein": gr.update(visible=False),
        }

    def _is_finished(self):
        """True once the completion screen has been reached."""
        return self.is_completed or self.bo_iteration >= self.max_iterations

    # --------------------------------------------------------------------------
    # Gradio event methods
    # --------------------------------------------------------------------------
    def show_initial(self):
        """Render the first pair of the current round (``app.load`` handler).

        Returns 8 values for: iteration status, message, image A, image B,
        table, dropdown, preference radio, BO button.
        """
        if self._is_finished():
            done = self._completion_updates()
            return (
                done["iteration"], done["message"], done["img_a"], done["img_b"],
                done["table"], done["dropdown"], done["radio"], done["bo_btn"],
            )
        iteration_str = self._iteration_status()
        if not self.current_pairs:
            return (
                iteration_str,
                "No pairs available.",
                None,
                None,
                self._empty_table_update(),
                gr.update(choices=[], value=""),  # Set dropdown to empty
                gr.update(),
                gr.update(),
            )
        labels = self._pair_labels()
        first = self.current_pairs[0]
        imgA, imgB, table = self._render_pair(first)
        msg = f"**Currently selected**: {labels[0]} => **Ligand A** {self._arrow(first)} **Ligand B**"
        return (
            iteration_str,
            msg,
            imgA,
            imgB,
            table,
            gr.update(choices=labels, value=labels[0]),
            gr.update(),
            # On a page reload after answering every pair, keep the BO button usable.
            gr.update(interactive=self._all_pairs_completed()),
        )

    def update_view_on_dropdown(self, pair_label):
        """Render the pair chosen in the dropdown (dropdown ``change`` handler).

        Returns 5 values for: iteration status, message, image A, image B, table.
        """
        if self._is_finished():
            done = self._completion_updates()
            return (done["iteration"], done["message"], done["img_a"], done["img_b"], done["table"])
        iteration_str = self._iteration_status()
        if not pair_label or pair_label.strip() == "":
            # Show the first two ligands of the library as a placeholder.
            names = list(self.ligands.keys())
            return (
                iteration_str,
                "Please select a valid pair",
                self._mol_to_image(names[0]) if len(names) > 0 else None,
                self._mol_to_image(names[1]) if len(names) > 1 else None,
                self._empty_table_update(),
            )
        idx = self.get_pair_index(pair_label)
        if not 0 <= idx < len(self.current_pairs):
            return (
                iteration_str,
                f"Invalid pair: {pair_label}",
                None,
                None,
                self._empty_table_update(),
            )
        pair = self.current_pairs[idx]
        imgA, imgB, table = self._render_pair(pair)
        label = self._pair_labels()[idx]
        msg = f"**Currently selected**: {label} => **Ligand A** {self._arrow(pair)} **Ligand B**"
        return (iteration_str, msg, imgA, imgB, table)

    def show_pair(self, preference, pair_label):
        """Record the preference for the selected pair and move to the next pending one.

        Returns 8 values for: iteration status, message, image A, image B,
        table, dropdown (choices + value), dropdown (value), BO button.
        The BO button is enabled once every pair of the round is answered.
        """
        iteration_str = self._iteration_status()
        if not self.current_pairs:
            # Nothing to record; enable the BO button so the user can move on.
            return (
                iteration_str,
                "No pairs available.",
                None,
                None,
                self._empty_table_update(),
                gr.update(choices=[], value=""),
                "",
                gr.update(interactive=True),
            )
        idx = self.get_pair_index(pair_label)
        if not 0 <= idx < len(self.current_pairs):
            idx = 0  # fall back to the first pair
        pair = self.current_pairs[idx]
        ligandA_id, ligandB_id = pair
        self.comparison_results.append({
            "Iteration": self.bo_iteration,
            "Pair": (ligandA_id, ligandB_id),
            "Preference": preference,
        })
        print(f"Logged preference: Iter={self.bo_iteration}, Pair=({ligandA_id}, {ligandB_id}), Choice={preference}")
        if preference == "Ligand A":
            self.completed_pairs[pair] = 1
            old_pair_str = "**Ligand A** > **Ligand B**"
        else:
            self.completed_pairs[pair] = 0
            old_pair_str = "**Ligand B** > **Ligand A**"

        labels = self._pair_labels()
        next_idx = next(
            (i for i, p in enumerate(self.current_pairs) if p not in self.completed_pairs),
            None,
        )
        # Show the next pending pair; when none is left, keep showing the pair just answered.
        shown_idx = idx if next_idx is None else next_idx
        shown_pair = self.current_pairs[shown_idx]
        imgA, imgB, table = self._render_pair(shown_pair)
        dropdown_val = labels[shown_idx]
        current_selection_msg = (
            f"**Currently selected**: {dropdown_val} => "
            f"**Ligand A** {self._arrow(shown_pair)} **Ligand B**"
        )
        selection_msg = (
            f"You chose {old_pair_str} for Pair {idx + 1}.<br>"
            f"{current_selection_msg}"
        )
        print(selection_msg)
        return (
            iteration_str,
            selection_msg,
            imgA,
            imgB,
            table,
            gr.update(choices=labels, value=dropdown_val),
            dropdown_val,
            gr.update(interactive=next_idx is None),
        )

    def start_bo_iteration(self):
        """Save the finished round and start the next one (BO button handler).

        Returns 11 values for: iteration status, message, image A, image B,
        table, dropdown (choices + value), dropdown (value), BO button, submit
        button, preference radio, protein viewer.
        """
        if self.is_completed:
            return self.finish_bo_process()
        iteration_str = self._iteration_status()
        if not self._all_pairs_completed():
            # Leave the current view untouched; only report and keep the button disabled.
            return (
                iteration_str,
                "Please complete all pairs first!",
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(interactive=False, value="Next BO Iteration"),
                gr.update(),
                gr.update(),
                gr.update(),
            )
        # Persist this round (including the last one) before moving on.
        self._save_results()
        if self.bo_iteration >= self.max_iterations - 1:
            return self.finish_bo_process()

        self.bo_iteration += 1
        self.current_pairs = self._generate_new_pairs(self.comparisons_per_iteration)
        self.completed_pairs = {}
        iteration_str = self._iteration_status()
        labels = self._pair_labels()
        default_val = labels[0] if labels else ""
        display_iter = self.bo_iteration + 1  # 1-based, matching _iteration_status
        if self.current_pairs:
            imgA, imgB, table = self._render_pair(self.current_pairs[0])
            msg = (
                f"Starting iteration {display_iter}/{self.max_iterations} with {len(self.current_pairs)} new pairs.<br>"
                f"**Currently selected**: {default_val} => **Ligand A** vs **Ligand B**"
            )
        else:
            msg = f"Starting iteration {display_iter}/{self.max_iterations}, but no new pairs?"
            imgA = None
            imgB = None
            table = self._empty_table_update()
        return (
            iteration_str,
            msg,
            imgA,
            imgB,
            table,
            gr.update(choices=labels, value=default_val),
            default_val,
            # With no pairs there is nothing to answer, so allow advancing.
            gr.update(interactive=not self.current_pairs),
            gr.update(visible=True),  # Submit button
            gr.update(visible=True),  # Preference radio
            gr.update(),  # Protein viewer unchanged
        )

    def finish_bo_process(self):
        """Show the completion screen.

        Returns the same 11-value layout as ``start_bo_iteration``. Every widget,
        including the protein viewer, is hidden via ``gr.update(visible=False)``
        through its output slot.
        """
        done = self._completion_updates()
        return (
            done["iteration"],  # Clear iteration status
            done["message"],  # Show completion message
            done["img_a"],
            done["img_b"],
            done["table"],
            done["dropdown"],
            done["dropdown_value"],
            done["bo_btn"],
            done["submit_btn"],
            done["radio"],
            done["protein"],
        )

    def build_app(self):
        """Build the Gradio Blocks UI, wire its events, and store it in ``self.app``."""
        with gr.Blocks() as app:
            gr.Markdown("## Virtual Screening BO App")
            iteration_status_text = gr.Markdown(value="Loading...", label="Iteration Status")
            # NOTE(review): messages written here use Markdown-style **bold**,
            # which gr.HTML displays literally; consider gr.Markdown or <b> tags.
            current_selection_text = gr.HTML(value="Initializing...", label="Current Selection")
            with gr.Row():
                pair_dropdown = gr.Dropdown(
                    label="Select a Pair",
                    allow_custom_value=True,
                    interactive=True,
                )
                preference_radio = gr.Radio(
                    ["Ligand A", "Ligand B"],
                    label="Your preference",
                    value="Ligand A",
                )
                bo_btn = gr.Button(value="Next BO Iteration", interactive=False)
            with gr.Row():
                with gr.Column():
                    out_imgA = gr.Image(label="Ligand A", width=650, height=325)
                    out_imgB = gr.Image(label="Ligand B", width=650, height=325)
                with gr.Column():
                    if self.protein_pdb_data:
                        self.protein_view = Molecule3D(
                            label="Protein Structure",
                            reps=[
                                {
                                    "model": 0,
                                    "chain": "",
                                    "resname": "",
                                    "style": "cartoon",
                                    "color": "spectrum",
                                    "residue_range": "",
                                    "around": 0,
                                    "byres": False,
                                    "visible": True,
                                }
                            ],
                            # Pass the PDB file path, not the data
                            value=self.protein_pdb_path,
                        )
                    else:
                        # If PDB data not found, display a message
                        self.protein_view = gr.Markdown("**Protein PDB file not found.**")
            out_table = gr.Dataframe(
                headers=list(self._HEADERS_WITH_SMILES),
                label="Properties",
            )
            submit_btn = gr.Button("Submit Preference")

            # Each handler's return tuple must match its outputs list exactly.
            app.load(
                fn=self.show_initial,
                inputs=None,
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_imgA,
                    out_imgB,
                    out_table,
                    pair_dropdown,
                    preference_radio,
                    bo_btn,
                ],
            )
            pair_dropdown.change(
                fn=self.update_view_on_dropdown,
                inputs=pair_dropdown,
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_imgA,
                    out_imgB,
                    out_table,
                ],
            )
            bo_btn.click(
                fn=self.start_bo_iteration,
                inputs=[],
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_imgA,
                    out_imgB,
                    out_table,
                    pair_dropdown,
                    pair_dropdown,
                    bo_btn,
                    submit_btn,
                    preference_radio,
                    self.protein_view,  # hidden on completion
                ],
            )
            submit_btn.click(
                fn=self.show_pair,
                inputs=[preference_radio, pair_dropdown],
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_imgA,
                    out_imgB,
                    out_table,
                    pair_dropdown,
                    pair_dropdown,
                    bo_btn,
                ],
            )
        self.app = app

    def launch(self, **kwargs):
        """Build the UI on first use and launch it; ``kwargs`` are passed to ``Blocks.launch``."""
        if self.app is None:
            self.build_app()
        self.app.launch(**kwargs)
# ---------------------------
# Example Usage
# ---------------------------
if __name__ == "__main__":
    # Small demo library: ligand ID -> SMILES.
    ligands = {
        "L1": "CCN(CC)CC(=O)c1ccccc1N(C)C",
        "L2": "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "L3": "Cn1c(=O)c2c(ncnc2N(C)C)n(C)c1=O",
        "L4": "CC(=O)Oc1ccccc1C(=O)O",
        "L5": "CCCC",
    }
    initial_pairs = [("L1", "L2"), ("L3", "L4")]
    # Hardcoded PDB file path
    protein_pdb = "1syn.pdb"  # Update this path as needed
    # Fail fast when the structure file is missing. SystemExit with a string
    # prints it to stderr and exits with status 1; unlike the site-provided
    # `exit()`, it is always available (e.g. under `python -S`).
    if not os.path.isfile(protein_pdb):
        raise SystemExit(f"Error: Protein PDB file not found at {protein_pdb}")
    app = VirtualScreeningBOApp(
        ligands=ligands,
        initial_pairs=initial_pairs,
        protein_pdb_path=protein_pdb,  # Provide the PDB file path
        max_iterations=2,
        comparisons_per_iteration=2,
        show_smiles=False,
    )
    app.launch(share=True, debug=True)