Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| import json | |
| import gradio as gr | |
| from gradio_molecule3d import Molecule3D | |
| import torch | |
| from pinder.core import get_pinder_location | |
| get_pinder_location() | |
| from pytorch_lightning import LightningModule | |
| import torch | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import torchmetrics | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.nn import global_mean_pool | |
| from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
| from torch_scatter import scatter | |
| from torch.nn import Module | |
| import pinder.core as pinder | |
| pinder.__version__ | |
| from torch_geometric.loader import DataLoader | |
| from pinder.core.loader.dataset import get_geo_loader | |
| from pinder.core import download_dataset | |
| from pinder.core import get_index | |
| from pinder.core import get_metadata | |
| from pathlib import Path | |
| import pandas as pd | |
| from pinder.core import PinderSystem | |
| import torch | |
| from pinder.core.loader.dataset import PPIDataset | |
| from pinder.core.loader.geodata import NodeRepresentation | |
| import pickle | |
| from pinder.core import get_index, PinderSystem | |
| from torch_geometric.data import HeteroData | |
| import os | |
| from enum import Enum | |
| import numpy as np | |
| import torch | |
| import lightning.pytorch as pl | |
| from numpy.typing import NDArray | |
| from torch_geometric.data import HeteroData | |
| from pinder.core.index.system import PinderSystem | |
| from pinder.core.loader.structure import Structure | |
| from pinder.core.utils import constants as pc | |
| from pinder.core.utils.log import setup_logger | |
| from pinder.core.index.system import _align_monomers_with_mask | |
| from pinder.core.loader.structure import Structure | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.nn import global_mean_pool | |
| from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
| from torch_scatter import scatter | |
| from torch.nn import Module | |
| import time | |
| from torch_geometric.nn import global_max_pool | |
| import copy | |
| import inspect | |
| import warnings | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| from torch import Tensor | |
| from torch_geometric.data import Data, Dataset, HeteroData | |
| from torch_geometric.data.feature_store import FeatureStore | |
| from torch_geometric.data.graph_store import GraphStore | |
| from torch_geometric.loader import ( | |
| LinkLoader, | |
| LinkNeighborLoader, | |
| NeighborLoader, | |
| NodeLoader, | |
| ) | |
| from torch_geometric.loader.dataloader import DataLoader | |
| from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes | |
| from torch_geometric.sampler import BaseSampler, NeighborSampler | |
| from torch_geometric.typing import InputEdges, InputNodes | |
| try: | |
| from lightning.pytorch import LightningDataModule as PLLightningDataModule | |
| no_pytorch_lightning = False | |
| except (ImportError, ModuleNotFoundError): | |
| PLLightningDataModule = object | |
| no_pytorch_lightning = True | |
| from lightning.pytorch.callbacks import ModelCheckpoint | |
| from lightning.pytorch.loggers.tensorboard import TensorBoardLogger | |
| from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
| from torch_geometric.data.lightning.datamodule import LightningDataset | |
| from pytorch_lightning.loggers.wandb import WandbLogger | |
def get_system(system_id: str) -> PinderSystem:
    """Instantiate the ``PinderSystem`` identified by ``system_id``.

    Args:
        system_id: Identifier of the system, as found in the ``id`` column
            of the PINDER index.

    Returns:
        The ``PinderSystem`` built from ``system_id``.
    """
    system = PinderSystem(system_id)
    return system
| from Bio import PDB | |
| from Bio.PDB.PDBIO import PDBIO | |
# To create the dataset, we used only the PINDER dataset, following the steps below:
# Module-level logger, created with pinder's logging helper.
log = setup_logger(__name__)

# torch-cluster is optional: it provides `knn_graph`, which is used to add
# k-nearest-neighbour edges. `torch_cluster_installed` records whether the
# import succeeded so that callers can skip edge construction otherwise.
try:
    from torch_cluster import knn_graph
except ImportError:
    torch_cluster_installed = False
    # Adjacent string literals are concatenated verbatim, so each sentence
    # carries its own trailing space to keep the message readable.
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
else:
    torch_cluster_installed = True
def _encode_names(names, table, unknown_idx, dtype):
    """Map each name to its index in ``table`` and return an ``(n, 1)`` tensor.

    Names missing from ``table`` are encoded as ``unknown_idx``.
    """
    codes = [table.get(name, unknown_idx) for name in names]
    # unsqueeze keeps the (n, 1) column layout, including for n == 0.
    return torch.tensor(codes, dtype=dtype).unsqueeze(-1)


def structure2tensor(
    atom_coordinates: NDArray[np.double] | None = None,
    atom_types: NDArray[np.str_] | None = None,
    element_types: NDArray[np.str_] | None = None,
    residue_coordinates: NDArray[np.double] | None = None,
    residue_ids: NDArray[np.int_] | None = None,
    residue_types: NDArray[np.str_] | None = None,
    chain_ids: NDArray[np.str_] | None = None,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert per-atom / per-residue structure arrays into torch tensors.

    Every argument is optional; only the arguments that are not ``None``
    produce an entry (under the argument's own name) in the result.

    Args:
        atom_coordinates: Atom coordinates, converted to a tensor as-is.
        atom_types: Atom names, encoded through ``pc.ALL_ATOM_POSNS``;
            unknown names get ``max(index) + 1``.
        element_types: Element symbols, encoded through ``pc.ELE2NUM``;
            unknown symbols get ``pc.ELE2NUM["other"]``.
        residue_coordinates: Residue coordinates, converted as-is.
        residue_ids: Residue ids, converted to ``dtype``.
        residue_types: Residue names, encoded through ``pc.AA_TO_INDEX``;
            unknown names get ``max(index) + 1``.
        chain_ids: Chain identifiers; entries equal to ``"L"`` become 1,
            everything else 0.
        dtype: dtype of every returned tensor.

    Returns:
        Mapping from property name to tensor. The three ``*_types`` entries
        have shape ``(n, 1)``; ``chain_ids`` has shape ``(n,)``.
    """
    property_dict = {}

    if atom_types is not None:
        unknown_atom_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
        property_dict["atom_types"] = _encode_names(
            atom_types, pc.ALL_ATOM_POSNS, unknown_atom_idx, dtype
        )

    if element_types is not None:
        property_dict["element_types"] = _encode_names(
            element_types, pc.ELE2NUM, pc.ELE2NUM["other"], dtype
        )

    if residue_types is not None:
        unknown_residue_idx = max(pc.AA_TO_INDEX.values()) + 1
        property_dict["residue_types"] = _encode_names(
            residue_types, pc.AA_TO_INDEX, unknown_residue_idx, dtype
        )

    if atom_coordinates is not None:
        property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)

    if residue_coordinates is not None:
        property_dict["residue_coordinates"] = torch.tensor(
            residue_coordinates, dtype=dtype
        )

    if residue_ids is not None:
        property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)

    if chain_ids is not None:
        # Coerce to a string array first: comparing a plain Python list with
        # "L" yields a scalar False, which would silently label every chain 0.
        is_ligand = np.asarray(chain_ids, dtype=str) == "L"
        chain_labels = torch.zeros(len(chain_ids), dtype=dtype)
        chain_labels[torch.from_numpy(is_ligand)] = 1
        property_dict["chain_ids"] = chain_labels

    return property_dict
class NodeRepresentation(Enum):
    """Named levels at which a structure can be represented as graph nodes.

    NOTE(review): this definition rebinds the ``NodeRepresentation`` name
    imported from ``pinder.core.loader.geodata`` earlier in the file.
    """

    # Each member's value is the lowercase form of its name.
    Surface = "surface"
    Atom = "atom"
    Residue = "residue"
class PairedPDB(HeteroData):  # type: ignore
    """Heterogeneous graph pairing an apo (unbound) complex with its holo target.

    The graph has two node types, ``"ligand"`` and ``"receptor"``. For each:

    * ``x``   - encoded atom types of the apo structure,
    * ``pos`` - atom coordinates of the apo structure,
    * ``y``   - atom coordinates of the holo structure,
    * ``edge_index`` - k-nearest-neighbour graph over ``pos`` (optional).
    """

    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple = (Structure, Structure, Structure),
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a graph from a ``(holo, apo, ...)`` tuple of structures.

        Args:
            tupal: Tuple whose first element is the holo structure and whose
                second element is the apo structure. Any further elements
                are ignored. The default is only a placeholder of classes
                and is not a usable input.
            add_edges: Whether to add k-nearest-neighbour edges.
            k: Number of neighbours per node when adding edges.

        Returns:
            The graph produced by :meth:`from_structure_pair`.
        """
        holo, apo = tupal[0], tupal[1]
        return cls.from_structure_pair(holo=holo, apo=apo, add_edges=add_edges, k=k)

    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a graph from a holo structure and its apo counterpart.

        Args:
            holo: Bound-state structure; supplies the ``y`` coordinates.
            apo: Unbound-state structure; supplies ``x`` and ``pos``.
            add_edges: Whether to add k-nearest-neighbour edges over ``pos``.
                Edges are only added when torch-cluster is installed.
            k: Number of neighbours per node when adding edges.

        Returns:
            A populated ``PairedPDB`` graph.
        """
        graph = cls()

        holo_r_props, holo_l_props = cls._receptor_ligand_props(holo)
        apo_r_props, apo_l_props = cls._receptor_ligand_props(apo)

        # Inputs come from the apo structure, targets from the holo one.
        # NOTE(review): this presumes apo and holo atoms are in matching
        # order and number - confirm against how the structures are masked.
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        # Without torch-cluster the graph is returned without edge_index.
        if add_edges and torch_cluster_installed:
            for node_type in ("ligand", "receptor"):
                graph[node_type].edge_index = knn_graph(graph[node_type].pos, k=k)

        return graph

    @staticmethod
    def _receptor_ligand_props(structure: Structure):
        """Split ``structure`` into receptor and ligand tensor dictionaries.

        Returns a ``(receptor_props, ligand_props)`` pair, each the result of
        ``structure2tensor`` for that half of the structure.

        NOTE(review): the split is positional - it assumes all chain ``"R"``
        rows come before the remaining rows. Confirm this ordering.
        """
        calpha = structure.filter("atom_name", mask=["CA"])

        # Atom-level and C-alpha-level arrays have different lengths, so each
        # needs its own receptor count. Slicing the C-alpha arrays with the
        # atom count would put ligand residues into the receptor slice.
        n_receptor_atoms = int((structure.dataframe["chain_id"] == "R").sum())
        n_receptor_residues = int((calpha.dataframe["chain_id"] == "R").sum())

        atom_array = structure.atom_array
        calpha_array = calpha.atom_array

        def props(atom_slice: slice, residue_slice: slice):
            return structure2tensor(
                atom_coordinates=structure.coords[atom_slice],
                atom_types=atom_array.atom_name[atom_slice],
                element_types=atom_array.element[atom_slice],
                residue_coordinates=calpha.coords[residue_slice],
                residue_types=calpha_array.res_name[residue_slice],
                residue_ids=calpha_array.res_id[residue_slice],
            )

        receptor_props = props(
            slice(None, n_receptor_atoms), slice(None, n_receptor_residues)
        )
        ligand_props = props(
            slice(n_receptor_atoms, None), slice(n_receptor_residues, None)
        )
        return receptor_props, ligand_props
# Row positions (within the apo-filtered frames) that delimit each subset.
# NOTE(review): these bounds are hard-coded; presumably they match the size of
# one specific PINDER index release. An IndexError here means the filtered
# frame has fewer rows than expected.
_TRAIN_STOP = 10000
_VAL_EXTRA_STOP = 10908
_TEST_EXTRA_STOP = 11816
_HOLDOUT_COUNT = 342


def _apo_pairs_only(frame):
    """Return a copy of the rows of ``frame`` flagged both ``apo_R`` and ``apo_L``."""
    has_both = (frame["apo_R"] == True) & (frame["apo_L"] == True)  # noqa: E712
    return frame[has_both].copy()


def _masked_apo_complexes(frame, start, stop):
    """Build masked bound/unbound complexes for rows ``start``..``stop - 1``.

    Each row's ``id`` is loaded with ``get_system`` and converted with
    ``create_masked_bound_unbound_complexes(monomer_types=["apo"],
    renumber_residues=True)``.
    """
    return [
        get_system(frame.id.iloc[row]).create_masked_bound_unbound_complexes(
            monomer_types=["apo"], renumber_residues=True
        )
        for row in range(start, stop)
    ]


index = get_index()
train = index[index.split == "train"].copy()
val = index[index.split == "val"].copy()
test = index[index.split == "test"].copy()

train_filtered = _apo_pairs_only(train)
val_filtered = _apo_pairs_only(val)
test_filtered = _apo_pairs_only(test)

train_apo = _masked_apo_complexes(train_filtered, 0, _TRAIN_STOP)
# The train rows past _TRAIN_STOP are held out to enlarge val and test.
train_new_apo11 = _masked_apo_complexes(train_filtered, _TRAIN_STOP, _VAL_EXTRA_STOP)
# NOTE(review): this subset previously had its keyword arguments commented
# out, so it fell back to the method's defaults while every other subset used
# monomer_types=["apo"], renumber_residues=True. It now matches the others;
# confirm the old call was not a deliberate workaround.
train_new_apo12 = _masked_apo_complexes(
    train_filtered, _VAL_EXTRA_STOP, _TEST_EXTRA_STOP
)
val_new_apo1 = _masked_apo_complexes(val_filtered, 0, _HOLDOUT_COUNT)
test_new_apo1 = _masked_apo_complexes(test_filtered, 0, _HOLDOUT_COUNT)

val_apo = val_new_apo1 + train_new_apo11
test_apo = test_new_apo1 + train_new_apo12
| import pickle | |
| # with open("train_apo.pkl", "wb") as file: | |
| # pickle.dump(train_apo, file) | |
| # with open("val_apo.pkl", "wb") as file: | |
| # pickle.dump(val_apo, file) | |
| # with open("test_apo.pkl", "wb") as file: | |
| # pickle.dump(test_apo, file) | |
| # with open("train_apo.pkl", "rb") as file: | |
| # train_apo = pickle.load(file) | |
| # with open("val_apo.pkl", "rb") as file: | |
| # val_apo = pickle.load(file) | |
| # with open("test_apo.pkl", "rb") as file: | |
| # test_apo = pickle.load(file) | |
| # # %% | |
# Turn every (holo, apo, ...) tuple into a PairedPDB graph, one list per split.
train_geo = [PairedPDB.from_tuple_system(complexes) for complexes in train_apo]
val_geo = [PairedPDB.from_tuple_system(complexes) for complexes in val_apo]
test_geo = [PairedPDB.from_tuple_system(complexes) for complexes in test_apo]
| # # %% | |
| # Train= [] | |
| # for i in range(0,len(train_geo)): | |
| # data = HeteroData() | |
| # data["ligand"].x = train_geo[i]["ligand"].x | |
| # data['ligand'].y = train_geo[i]["ligand"].y | |
| # data["ligand"].pos = train_geo[i]["ligand"].pos | |
| # data["ligand","ligand"].edge_index = train_geo[i]["ligand"] | |
| # data["receptor"].x = train_geo[i]["receptor"].x | |
| # data['receptor'].y = train_geo[i]["receptor"].y | |
| # data["receptor"].pos = train_geo[i]["receptor"].pos | |
| # data["receptor","receptor"].edge_index = train_geo[i]["receptor"] | |
| # #torch.save(data, f"./data/processed/train_sample_{i}.pt") | |
| # Train.append(data) | |
| from torch_geometric.data import HeteroData | |
| import torch_sparse | |
| from torch_geometric.edge_index import to_sparse_tensor | |
| import torch | |
def _with_sparse_edges(graph):
    """Copy ``graph`` into a fresh ``HeteroData`` with converted edge indices.

    For both the ``"ligand"`` and ``"receptor"`` node types, ``x``, ``y`` and
    ``pos`` are copied over unchanged. The node-level ``edge_index`` is passed
    through ``to_sparse_tensor`` with square ``sparse_sizes`` equal to the
    node count, and stored on the ``(node_type, node_type)`` edge type.

    Args:
        graph: A graph whose ``"ligand"`` and ``"receptor"`` stores provide
            ``x``, ``y``, ``pos`` and ``edge_index``.

    Returns:
        The newly built ``HeteroData`` object.
    """
    data = HeteroData()
    for node_type in ("ligand", "receptor"):
        source = graph[node_type]

        # Node features, targets and positions.
        data[node_type].x = source.x
        data[node_type].y = source.y
        data[node_type].pos = source.pos

        # Edges within the node type, converted from the raw edge index.
        num_nodes = source.num_nodes
        data[node_type, node_type].edge_index = to_sparse_tensor(
            source["edge_index"], sparse_sizes=(num_nodes, num_nodes)
        )
    return data


# One converted graph per sample, for each split.
Train1 = [_with_sparse_edges(graph) for graph in train_geo]
Val1 = [_with_sparse_edges(graph) for graph in val_geo]
Test1 = [_with_sparse_edges(graph) for graph in test_geo]
| # with open("Train.pkl", "wb") as file: | |
| # pickle.dump(Train, file) | |
| # with open("Val.pkl", "wb") as file: | |
| # pickle.dump(Val, file) | |
| # with open("Test.pkl", "wb") as file: | |
| # pickle.dump(Test, file) | |
| # with open("Train1.pkl", "rb") as file: | |
| # Train= pickle.load(file) | |
| # with open("Val.pkl", "rb") as file: | |
| # Val = pickle.load(file) | |
| # with open("Test.pkl", "rb") as file: | |
| # Test = pickle.load(file) |