Spaces:
Running
Running
| import gradio as gr | |
| import argparse | |
| import os | |
| from UltraFlow.models.sbap import * | |
| from UltraFlow import commons | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# Directory that holds the YAML config and the checkpoint loaded below.
model_dir = './workdir/gradio/'
# Checkpoint file name; joined with model_dir in get_models().
checkpoint = 'checkpointbest_valid_1.ckp'
# Running count of scoring requests; incremented on every call to test().
total_num = 0
def get_config(model_dir):
    """Load the inference configuration stored in ``model_dir``.

    Reads ``affinity_default.yaml`` from ``model_dir`` via
    ``commons.get_config_easydict``, forces the device to CPU and hands
    ``config.seed`` to ``commons.set_seed``.

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.

    Returns:
        The configuration object returned by ``commons.get_config_easydict``
        with its ``device`` attribute set to ``'cpu'``.
    """
    yaml_path = os.path.join(model_dir, 'affinity_default.yaml')
    config = commons.get_config_easydict(yaml_path)

    # The demo is pinned to CPU; GPU selection through commons.get_device
    # (driven by config.train.gpus / gpu_memory_need) is intentionally off.
    config.device = 'cpu'

    # Set the random seed from the configuration.
    commons.set_seed(config.seed)
    return config
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Compute the feature widths of the ligand, protein and interaction graphs.

    Widths are read from the second axis of the graphs' ``'h'`` node features
    and ``'e'`` edge features, then adjusted according to the data options.

    Args:
        lig_graph: Ligand graph exposing ``ndata['h']``, ``edata['e']`` and,
            when bond features are enabled, ``edata['bond_type']``.
        prot_graph: Protein graph exposing ``ndata['h']`` and ``edata['e']``.
        model_config: Configuration whose ``data`` section provides the
            ``add_chemical_bond_feats`` and ``use_mean_node_features`` flags.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    data_cfg = model_config.data

    # Ligand edges optionally gain the width of the bond-type features.
    lig_edge_dim = lig_graph.edata['e'].shape[1]
    if data_cfg.add_chemical_bond_feats:
        lig_edge_dim += lig_graph.edata['bond_type'].shape[1]
    pro_edge_dim = prot_graph.edata['e'].shape[1]

    # Both node widths grow by 5 when mean node features are enabled.
    node_extra = 5 if data_cfg.use_mean_node_features else 0
    lig_node_dim = lig_graph.ndata['h'].shape[1] + node_extra
    pro_node_dim = prot_graph.ndata['h'].shape[1] + node_extra

    # The interaction-edge width is a fixed constant.
    inter_edge_dim = 15

    return lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim
def trans_device(data, device):
    """Move every non-list element of ``data`` to ``device``.

    Args:
        data: Iterable of batch elements. Elements that are ``list``
            instances are passed through untouched; every other element must
            provide a ``.to(device)`` method.
        device: Target device forwarded to each element's ``.to`` call.

    Returns:
        A new list with the elements in their original order.
    """
    moved = []
    for element in data:
        if isinstance(element, list):
            # Plain Python lists have no .to(); keep them as they are.
            moved.append(element)
        else:
            moved.append(element.to(device))
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Build the model input for one ligand/protein pair.

    Parses both files, constructs the ligand, protein and interaction graphs,
    records the resulting feature widths in ``model_config.model`` and packs
    everything into the batch layout passed to the model.

    Side effect: ``model_config.model.lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim`` and ``inter_edge_dim`` are overwritten.

    Args:
        model_config: Configuration object; its ``data`` section drives graph
            construction and its ``device`` attribute selects the target device.
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        List ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` with all non-list entries moved to
        ``model_config.device``.
    """
    data_cfg = model_config.data

    # Parse both files into coordinates, features and residue bookkeeping.
    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = commons.read_molecules_inference(
        ligand_path, protein_path, data_cfg.prot_graph_type, data_cfg.chaincut)

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut)

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut)
    # Attach residue numbers and chain indices to the protein graph.
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut)

    # Record the feature widths on the config. This must run before the
    # concatenations below, because load_graph_dim adds the extra widths itself.
    model_cfg = model_config.model
    (model_cfg.lig_node_dim, model_cfg.lig_edge_dim,
     model_cfg.pro_node_dim, model_cfg.pro_edge_dim,
     model_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        # Append the bond-type features to the ligand edge features.
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if data_cfg.use_mean_node_features:
        # Append the 'mu_r_norm' node features on both graphs.
        lig_graph.ndata['h'] = torch.cat(
            [lig_graph.ndata['h'], lig_graph.ndata['mu_r_norm']], dim=-1)
        prot_graph.ndata['h'] = torch.cat(
            [prot_graph.ndata['h'], prot_graph.ndata['mu_r_norm']], dim=-1)

    # Remaining batch fields are fixed values; presumably placeholders, since
    # no measured affinity exists at inference time -- confirm against the model.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    item = [0]
    assay_des = torch.zeros(0).unsqueeze(dim=0)
    IC50_f, K_f = [True], [True]

    batch = (lig_graph, prot_graph, inter_graph, label, item, assay_des, IC50_f, K_f)
    return trans_device(batch, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and load its checkpoint weights.

    The model class is looked up by name in this module's globals:
    ``model_config.model.model_type``, with an ``'_MTL'`` suffix when
    ``model_config.train.multi_task`` is truthy.

    Args:
        model_config: Configuration passed to the model constructor; also
            supplies ``device``.
        model_dir: Directory containing the checkpoint file.
        checkpoint: Checkpoint file name inside ``model_dir``.

    Returns:
        The model with the ``"model"`` entry of the checkpoint loaded as its
        state dict, switched to evaluation mode.
    """
    class_name = model_config.model.model_type
    if model_config.train.multi_task:
        class_name = class_name + '_MTL'

    # Resolve the class by name from the module namespace.
    model_cls = globals()[class_name]
    model = model_cls(model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])

    return model.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinities for one ligand/protein file pair.

    Uses the module-level ``model_config`` and ``model``.

    Args:
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        Tuple ``(IC50, K)`` of Python scalars taken from the model output.
    """
    batch = get_data(model_config, ligand_path, protein_path)
    outputs = model(batch, ASRP=False)

    # The model returns three values; only the middle (IC50, K) pair is used.
    _, affinity_pair, _ = outputs
    pred_ic50, pred_k = affinity_pair

    # NOTE(review): inference runs with autograd enabled; wrapping the call in
    # torch.no_grad() may save memory if the model needs no gradients -- confirm.
    return pred_ic50.item(), pred_k.item()
def test(ligand, protein):
    """Gradio callback: score one uploaded ligand/protein pair.

    Increments the module-level request counter, runs ``mbp_scoring`` on the
    uploaded files and formats both predictions with two decimals. This
    function never raises: any failure is logged and reported to the user as
    text in both output boxes.

    Args:
        ligand: Uploaded ligand file object whose ``.name`` is the file path,
            or ``None`` when nothing was uploaded.
        protein: Uploaded protein file object whose ``.name`` is the file
            path, or ``None`` when nothing was uploaded.

    Returns:
        Tuple of two strings: ``(IC50, K)`` on success, or the same error
        message twice on failure.
    """
    import traceback

    global total_num
    total_num = total_num + 1
    print(f'total num: {total_num}')

    # A missing upload would otherwise surface as a cryptic AttributeError.
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        print(f'error: {message}')
        return message, message

    try:
        IC50, K = mbp_scoring(ligand.name, protein.name)
        print(f'ligand file name: {os.path.basename(ligand.name)},'
              f' protein file name: {os.path.basename(protein.name)},'
              f' IC50: {IC50}, K: {K}')
        return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
    except Exception as e:
        # UI boundary: keep the demo alive, but log the full traceback and
        # hand the text boxes a readable string instead of an exception object.
        traceback.print_exc()
        message = f'{type(e).__name__}: {e}'
        return message, message
# Build the Gradio interface: two file inputs, two text outputs, a submit
# button wired to test(), and one bundled example.
# NOTE(review): indentation was lost in the source this was recovered from;
# the nesting below (only the two file inputs inside the Row) is assumed -- confirm.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction
## Welcome to the MBP demo !
- Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
- If you encounter any issues, please reach out to jiaxianyan@mail.ustc.edu.cn.
- All codes and data are available on the online platform https://github.com/jiaxianyan/MBP.
""")
    # Input files.
    with gr.Row():
        ligand = gr.File(label="Ligand 3D file. MBP utilizes openbabel to process ligand files and supports all file types that openbabel can read.")
        protein = gr.File(label="Protein 3D file. Currently, MBP only supports the pdb file type for protein files.")
    # Output boxes; test() returns one string for each.
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")
    submit_btn = gr.Button("Submit")
    # Clicking the button calls test(ligand, protein) and writes its two
    # return values to the IC50 and K boxes.
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")
    gr.Markdown("## Input Examples")
    # One bundled example pair, bound to the two file inputs; example
    # results are not cached.
    gr.Examples(
        examples=[['./workdir/gradio/1a0q_ligand.sdf','./workdir/gradio/1a0q_protein.pdb']],
        inputs=[ligand, protein],
        # outputs=[IC50, K],
        fn=test,
        cache_examples=False,
    )
# Module-level state read by mbp_scoring(): the configuration and the model.
model_config = get_config(model_dir)
# Run the bundled example through get_data() before building the model:
# get_data() writes the feature dimensions into model_config.model, and
# get_models() passes model_config to the model constructor. The returned
# batch is not referenced again at module level.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)
# Start the Gradio server; share=False is passed to launch().
demo.launch(share=False)