Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| import json | |
| import gradio as gr | |
| from gradio_molecule3d import Molecule3D | |
| import torch | |
| from pinder.core import get_pinder_location | |
| get_pinder_location() | |
| from pytorch_lightning import LightningModule | |
| import torch | |
| import lightning.pytorch as pl | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import torchmetrics | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.nn import global_mean_pool | |
| from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
| from torch_scatter import scatter | |
| from torch.nn import Module | |
| import pinder.core as pinder | |
| pinder.__version__ | |
| from torch_geometric.loader import DataLoader | |
| from pinder.core.loader.dataset import get_geo_loader | |
| from pinder.core import download_dataset | |
| from pinder.core import get_index | |
| from pinder.core import get_metadata | |
| from pathlib import Path | |
| import pandas as pd | |
| from pinder.core import PinderSystem | |
| import torch | |
| from pinder.core.loader.dataset import PPIDataset | |
| from pinder.core.loader.geodata import NodeRepresentation | |
| import pickle | |
| from pinder.core import get_index, PinderSystem | |
| from torch_geometric.data import HeteroData | |
| import os | |
| from enum import Enum | |
| import numpy as np | |
| import torch | |
| import lightning.pytorch as pl | |
| from numpy.typing import NDArray | |
| from torch_geometric.data import HeteroData | |
| from pinder.core.index.system import PinderSystem | |
| from pinder.core.loader.structure import Structure | |
| from pinder.core.utils import constants as pc | |
| from pinder.core.utils.log import setup_logger | |
| from pinder.core.index.system import _align_monomers_with_mask | |
| from pinder.core.loader.structure import Structure | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.nn import global_mean_pool | |
| from torch.nn import Sequential, Linear, BatchNorm1d, ReLU | |
| from torch_scatter import scatter | |
| from torch.nn import Module | |
| import time | |
| from torch_geometric.nn import global_max_pool | |
| import copy | |
| import inspect | |
| import warnings | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| from torch import Tensor | |
| from torch_geometric.data import Data, Dataset, HeteroData | |
| from torch_geometric.data.feature_store import FeatureStore | |
| from torch_geometric.data.graph_store import GraphStore | |
| from torch_geometric.loader import ( | |
| LinkLoader, | |
| LinkNeighborLoader, | |
| NeighborLoader, | |
| NodeLoader, | |
| ) | |
| from torch_geometric.loader.dataloader import DataLoader | |
| from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes | |
| from torch_geometric.sampler import BaseSampler, NeighborSampler | |
| from torch_geometric.typing import InputEdges, InputNodes | |
| try: | |
| from lightning.pytorch import LightningDataModule as PLLightningDataModule | |
| no_pytorch_lightning = False | |
| except (ImportError, ModuleNotFoundError): | |
| PLLightningDataModule = object | |
| no_pytorch_lightning = True | |
| from lightning.pytorch.callbacks import ModelCheckpoint | |
| from lightning.pytorch.loggers.tensorboard import TensorBoardLogger | |
| from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
| from torch_geometric.data.lightning.datamodule import LightningDataset | |
| from pytorch_lightning.loggers.wandb import WandbLogger | |
def get_system(system_id: str) -> PinderSystem:
    """Construct the ``PinderSystem`` for the given identifier.

    Args:
        system_id: Identifier handed unchanged to the ``PinderSystem`` constructor.

    Returns:
        The newly constructed ``PinderSystem`` instance.
    """
    system = PinderSystem(system_id)
    return system
| from Bio import PDB | |
| from Bio.PDB.PDBIO import PDBIO | |
| from pinder.core.structure.atoms import atom_array_from_pdb_file | |
| from pathlib import Path | |
| from pinder.eval.dockq.biotite_dockq import BiotiteDockQ | |
def extract_coordinates_from_pdb(filename, atom_name="CA"):
    """
    Parse a PDB file and collect the coordinates of every atom it contains.

    Atoms are visited model by model, chain by chain and residue by residue.

    Parameters:
        filename (str): Path to the PDB file.
        atom_name (str): Accepted for interface compatibility. The current
            implementation does not filter on it, so atoms of every name are returned.

    Returns:
        list of list: One ``[x, y, z]`` entry per atom, in traversal order.
    """
    pdb_structure = PDB.PDBParser(QUIET=True).get_structure("structure", filename)
    # Nested comprehension mirrors the model -> chain -> residue -> atom hierarchy.
    return [
        [atom.coord[0], atom.coord[1], atom.coord[2]]
        for model in pdb_structure
        for chain in model
        for residue in chain
        for atom in residue
    ]
# Module-level logger used for import-time diagnostics.
log = setup_logger(__name__)

# torch-cluster is optional: record whether `knn_graph` could be imported so that
# graph construction can decide whether k-NN edges are available.
try:
    from torch_cluster import knn_graph
except ImportError:
    # Each fragment ends with a space so the sentences do not run together
    # once the adjacent string literals are concatenated.
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
else:
    torch_cluster_installed = True
def structure2tensor(
    atom_coordinates: NDArray[np.double] | None = None,
    atom_types: NDArray[np.str_] | None = None,
    element_types: NDArray[np.str_] | None = None,
    residue_coordinates: NDArray[np.double] | None = None,
    residue_ids: NDArray[np.int_] | None = None,
    residue_types: NDArray[np.str_] | None = None,
    chain_ids: NDArray[np.str_] | None = None,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert per-atom / per-residue arrays into a dictionary of tensors.

    Only the arguments that are not ``None`` produce an entry; each entry is
    stored under the name of the argument it came from.

    Args:
        atom_coordinates: Converted as-is with ``torch.tensor(..., dtype=dtype)``.
        atom_types: Names mapped through ``pc.ALL_ATOM_POSNS``; unknown names get
            ``max(index) + 1``. Stored as an ``(n, 1)`` column.
        element_types: Names mapped through ``pc.ELE2NUM``; unknown names get the
            ``"other"`` code. Stored as an ``(n, 1)`` column.
        residue_coordinates: Converted as-is with ``torch.tensor(..., dtype=dtype)``.
        residue_ids: Converted as-is with ``torch.tensor(..., dtype=dtype)``.
        residue_types: Names mapped through ``pc.AA_TO_INDEX``; unknown names get
            ``max(index) + 1``. Stored as an ``(n, 1)`` column.
        chain_ids: Encoded as a vector of zeros with a 1 wherever the id equals ``"L"``.
        dtype: Element type of every produced tensor.

    Returns:
        Mapping from property name to tensor.
    """

    def as_column(codes):
        # Build the (n, 1) float64 column first, then cast to the requested dtype.
        column = np.asarray(codes, dtype=np.float64).reshape(-1, 1)
        return torch.tensor(column).type(dtype)

    tensors = {}

    if atom_types is not None:
        unknown_atom = max(pc.ALL_ATOM_POSNS.values()) + 1
        tensors["atom_types"] = as_column(
            [pc.ALL_ATOM_POSNS.get(label, unknown_atom) for label in atom_types]
        )

    if element_types is not None:
        tensors["element_types"] = as_column(
            [pc.ELE2NUM.get(label, pc.ELE2NUM["other"]) for label in element_types]
        )

    if residue_types is not None:
        unknown_residue = max(pc.AA_TO_INDEX.values()) + 1
        tensors["residue_types"] = as_column(
            [pc.AA_TO_INDEX.get(label, unknown_residue) for label in residue_types]
        )

    # Numeric inputs need no lookup table, only a dtype conversion.
    if atom_coordinates is not None:
        tensors["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)

    if residue_coordinates is not None:
        tensors["residue_coordinates"] = torch.tensor(residue_coordinates, dtype=dtype)

    if residue_ids is not None:
        tensors["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)

    if chain_ids is not None:
        chain_flags = torch.zeros(len(chain_ids), dtype=dtype)
        chain_flags[chain_ids == "L"] = 1
        tensors["chain_ids"] = chain_flags

    return tensors
class NodeRepresentation(Enum):
    """Enumeration of node representation kinds: surface, atom and residue."""

    # NOTE(review): this definition rebinds the `NodeRepresentation` name that is
    # imported earlier from pinder.core.loader.geodata — confirm the shadowing is intended.
    Surface = "surface"
    Atom = "atom"
    Residue = "residue"
class PairedPDB(HeteroData):  # type: ignore
    """Heterogeneous graph pairing apo inputs with holo targets.

    The graph has two node types, ``"ligand"`` and ``"receptor"``. For each of them
    ``x`` holds the encoded atom types of the apo structure, ``pos`` the apo atom
    coordinates and ``y`` the holo atom coordinates.
    """

    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple = (Structure, Structure, Structure),
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build the graph from a tuple whose first two items are (holo, apo).

        Args:
            tupal: Tuple of structures; index 0 is used as holo and index 1 as apo.
                Any further items are ignored.
            add_edges: Whether to add k-NN edges (requires torch-cluster).
            k: Number of neighbours per node for the k-NN graph.

        Returns:
            The populated graph.
        """
        return cls.from_structure_pair(
            holo=tupal[0],
            apo=tupal[1],
            add_edges=add_edges,
            k=k,
        )

    @staticmethod
    def _chain_props(structure, calpha, rows):
        """Tensorise the rows selected by the slice ``rows`` of one structure."""
        # NOTE(review): the residue-level arrays come from the CA-only structure but
        # are sliced with the same atom-count based slice as the atom-level arrays.
        # The residue entries are not used when the graph is assembled — confirm
        # before relying on them.
        return structure2tensor(
            atom_coordinates=structure.coords[rows],
            atom_types=structure.atom_array.atom_name[rows],
            element_types=structure.atom_array.element[rows],
            residue_coordinates=calpha.coords[rows],
            residue_types=calpha.atom_array.res_name[rows],
            residue_ids=calpha.atom_array.res_id[rows],
        )

    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build the graph from a holo / apo structure pair.

        Each structure is split into a receptor part (the first ``r`` rows, where
        ``r`` is the number of rows whose ``chain_id`` is ``"R"``) and a ligand part
        (the remaining rows).
        NOTE(review): this assumes chain-R rows precede all other rows — confirm.

        Args:
            holo: Structure providing the target coordinates (``y``).
            apo: Structure providing the input features (``x``) and positions (``pos``).
            add_edges: Whether to add k-NN edges; silently skipped when
                torch-cluster is not installed.
            k: Number of neighbours per node for the k-NN graph.

        Returns:
            The populated graph.
        """
        graph = cls()
        holo_calpha = holo.filter("atom_name", mask=["CA"])
        apo_calpha = apo.filter("atom_name", mask=["CA"])

        # Number of rows belonging to the receptor chain in each structure.
        r_h = (holo.dataframe['chain_id'] == 'R').sum()
        r_a = (apo.dataframe['chain_id'] == 'R').sum()

        holo_r_props = cls._chain_props(holo, holo_calpha, slice(None, r_h))
        holo_l_props = cls._chain_props(holo, holo_calpha, slice(r_h, None))
        apo_r_props = cls._chain_props(apo, apo_calpha, slice(None, r_a))
        apo_l_props = cls._chain_props(apo, apo_calpha, slice(r_a, None))

        # Inputs come from the apo structure, targets from the holo structure.
        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            graph["ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)

        return graph
# create_graph builds the model input from two PDB files: pdb1 supplies the ligand
# nodes and pdb2 the receptor nodes. No ground-truth (holo) targets are attached.
def create_graph(pdb1, pdb2, k=5):
    r"""
    Create a heterogeneous graph from two PDB files, with the ligand and receptor
    as separate node types, each carrying coordinates and k-NN edges.

    For both node types ``x`` and ``pos`` hold the atom coordinates (``x`` is an
    independent copy). The k-NN edge index is stored twice: on the node store
    (``data[node_type].edge_index``) and on the self-relation edge store
    (``data[node_type, node_type].edge_index``).

    Args:
        pdb1 (str): PDB file path for ligand.
        pdb2 (str): PDB file path for receptor.
        k (int): Number of nearest neighbors for constructing the knn graph.

    Returns:
        HeteroData: A PyG HeteroData object containing ligand and receptor data.

    Raises:
        ValueError: If either PDB file yields no atoms.
    """
    # Extract coordinates from PDB files
    positions = {
        "ligand": torch.tensor(extract_coordinates_from_pdb(pdb1), dtype=torch.float),
        "receptor": torch.tensor(extract_coordinates_from_pdb(pdb2), dtype=torch.float),
    }
    for node_type, source in (("ligand", pdb1), ("receptor", pdb2)):
        if positions[node_type].numel() == 0:
            raise ValueError(f"No atoms found in {node_type} PDB file: {source}")

    data = HeteroData()

    # Node features: the coordinates themselves. clone() gives x its own storage
    # without re-wrapping an existing tensor in torch.tensor(), which would emit
    # a copy-construct warning.
    for node_type, pos in positions.items():
        data[node_type].x = pos.clone()
        data[node_type].pos = pos

    # k-NN connectivity within each molecule, stored on both the node store and
    # the (node_type, node_type) edge store.
    for node_type, pos in positions.items():
        edge_index = knn_graph(pos, k=k)
        data[node_type].edge_index = edge_index
        data[node_type, node_type].edge_index = edge_index

    return data
def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordinates_tensor):
    r"""
    Write a copy of a PDB file whose atom coordinates are replaced by new ones.

    Atoms are visited in model -> chain -> residue -> atom order and receive the
    rows of ``coordinates_tensor`` in that same order. Only the coordinates are
    modified; every other atom field is left exactly as parsed.

    Parameters:
    - input_filename (str): Path to the original PDB file.
    - output_filename (str): File name of the new PDB file. It is appended to
      ``"/tmp/"``, so the file is always written below ``/tmp``.
    - coordinates_tensor (torch.Tensor): Transformed coordinates, e.g. of shape
      (1, N, 3) or (N, 3); it is flattened to rows of three values.

    Returns:
    - str: Path of the written PDB file (``"/tmp/" + output_filename``).
    """
    # One [x, y, z] row per atom, independent of a leading batch dimension.
    new_coordinates = coordinates_tensor.reshape(-1, 3).tolist()

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_filename)

    atoms = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    ]

    # zip() below stops at the shorter sequence; make a mismatch visible instead
    # of silently leaving some atoms at their old position.
    if len(atoms) != len(new_coordinates):
        warnings.warn(
            f"{input_filename} has {len(atoms)} atoms but {len(new_coordinates)} "
            "coordinates were provided; only the first "
            f"{min(len(atoms), len(new_coordinates))} atoms are updated.",
            stacklevel=2,
        )

    for atom, (new_x, new_y, new_z) in zip(atoms, new_coordinates):
        # Use the setter rather than attribute assignment so the update goes
        # through the atom object's own API.
        atom.set_coord(np.array([new_x, new_y, new_z]))

    output_filename = "/tmp/" + output_filename

    # Save the updated structure to a new PDB file
    io = PDBIO()
    io.set_structure(structure)
    io.save(output_filename)

    # Return the path to the updated PDB file
    return output_filename
def merge_pdb_files(file1, file2, output_file):
    r"""
    Merges two PDB files by concatenating them.

    The terminating ``END`` record of the first file (and any blank lines after
    it) is removed so that the merged file contains a single terminator, namely
    whatever the second file ends with. If the first file does not end with an
    ``END`` record, none of its records are dropped.

    Parameters:
    - file1 (str): Path to the first PDB file (e.g., receptor).
    - file2 (str): Path to the second PDB file (e.g., ligand).
    - output_file (str): File name of the merged structure. It is appended to
      ``"/tmp/"``, so the file is always written below ``/tmp``.

    Returns:
    - str: Path of the merged file (``"/tmp/" + output_file``).
    """
    output_file = "/tmp/" + output_file

    # Read both inputs before opening the output so an output path that aliases
    # an input is not truncated before it has been read.
    with open(file1, 'r') as f1:
        first_lines = f1.readlines()
    with open(file2, 'r') as f2:
        second_content = f2.read()

    # Strip trailing blank lines, then the END record only if it is really there.
    while first_lines and not first_lines[-1].strip():
        first_lines.pop()
    if first_lines and first_lines[-1].strip() == "END":
        first_lines.pop()
    # Make sure the second file starts on a fresh line.
    if first_lines and not first_lines[-1].endswith("\n"):
        first_lines[-1] += "\n"

    with open(output_file, 'w') as outfile:
        outfile.writelines(first_lines)
        outfile.write(second_content)

    print(f"Merged PDB saved to {output_file}")
    return output_file
class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        r"""Message Passing Neural Network Layer

        Args:
            emb_dim: (int) - hidden dimension d
            edge_dim: (int) - edge feature dimension d_e
            aggr: (str) - aggregation function \oplus (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP \psi for computing messages m_ij
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: (2d + d_e) -> d
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

        # MLP \phi for computing updated node features h_i^{l+1}
        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        r"""
        The forward pass updates node features h via one round of message passing.

        As our MPNNLayer class inherits from the PyG MessagePassing parent class,
        we simply need to call the propagate() function which starts the
        message passing procedure: message() -> aggregate() -> update().

        Args:
            h: (n, d) - initial node features
            edge_index: (2, e) - edge list handed to propagate()
            edge_attr: (e, d_e) - edge features

        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
        return out

    def message(self, h_i, h_j, edge_attr):
        r"""Step (1) Message

        The message() function constructs one message per edge. message() can
        take any arguments that were initially passed to propagate();
        appending _i or _j to a node-level argument name (here h_i and h_j)
        asks PyG to index that argument by the two endpoints of each edge.

        Args:
            h_i: (e, d) - destination node features
            h_j: (e, d) - source node features
            edge_attr: (e, d_e) - edge features

        Returns:
            msg: (e, d) - messages m_ij passed through MLP \psi
        """
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        r"""Step (2) Aggregate

        The aggregate function aggregates the messages arriving at each node,
        according to the chosen aggregation function ('add' by default).

        Args:
            inputs: (e, d) - messages m_ij, one per edge
            index: (e,) - index of the node each message is aggregated into
            dim_size: (int, optional) - number of nodes n. Forwarding it to
                scatter() keeps the output at n rows even when the
                highest-indexed nodes receive no message; without it the
                concatenation with h in update() would fail for such graphs.

        Returns:
            aggr_out: (n, d) - aggregated messages m_i
        """
        return scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )

    def update(self, aggr_out, h):
        r"""
        Step (3) Update

        The update() function computes the final node features by combining the
        aggregated messages with the initial node features.

        update() takes the first argument aggr_out, the result of aggregate(),
        as well as any optional arguments that were initially passed to
        propagate(). E.g. in this case, we additionally pass h.

        Args:
            aggr_out: (n, d) - aggregated messages m_i
            h: (n, d) - initial node features

        Returns:
            upd_out: (n, d) - updated node features passed through MLP \phi
        """
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
class MPNNModel(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        r"""Graph-level predictor built from a stack of MPNN layers.

        Args:
            num_layers: (int) - number of message passing layers L
            emb_dim: (int) - hidden dimension d
            in_dim: (int) - initial node feature dimension d_n
            edge_dim: (int) - edge feature dimension d_e
            out_dim: (int) - output dimension of the prediction head
        """
        super().__init__()

        # Input projection, d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # L message passing layers with sum aggregation
        self.convs = torch.nn.ModuleList(
            [MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)]
        )

        # Readout: mean over the nodes of each graph
        self.pool = global_mean_pool

        # Prediction head, d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        r"""
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: flattened predictions, one value per graph when out_dim is 1
        """
        node_emb = self.lin_in(data.x)  # (n, d_n) -> (n, d)

        # Every layer's output is added to its input (residual connection).
        for mp_layer in self.convs:
            node_emb = node_emb + mp_layer(node_emb, data.edge_index, data.edge_attr)

        graph_emb = self.pool(node_emb, data.batch)  # (n, d) -> (batch_size, d)
        prediction = self.lin_pred(graph_emb)  # (batch_size, d) -> (batch_size, out_dim)
        return prediction.view(-1)
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, aggr='add'):
        r"""Message Passing Neural Network Layer that also updates coordinates.

        Messages depend on positions only through pairwise distances, and
        positions are updated along the relative vectors pos_i - pos_j.

        Args:
            emb_dim: (int) - hidden dimension d
            aggr: (str) - aggregation function \oplus (sum/mean/max) for the
                feature messages
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        # MLP \psi for messages, dims: (2d + 1) -> d (the extra input is the distance)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU()
        )

        # MLP producing one scalar weight per message for the position update, d -> 1
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, 1)
        )

        # MLP \phi for updated node features, dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index):
        r"""
        The forward pass updates node features h and coordinates pos via one
        round of message passing.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - edge list handed to propagate()

        Returns:
            out: [(n, d), (n, 3)] - updated node features and coordinates
        """
        out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        r"""Build the feature message and the position message of every edge.

        Args:
            h_i: (e, d) - destination node features
            h_j: (e, d) - source node features
            pos_i: (e, 3) - destination node coordinates
            pos_j: (e, 3) - source node coordinates

        Returns:
            msg: (e, d) - feature messages
            pos_diff: (e, 3) - relative vectors scaled by a learned scalar
        """
        pos_diff = pos_i - pos_j
        # Euclidean distance between the two endpoints, (e, 1)
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)

        # Concatenate both endpoint features with the distance, (e, 2d + 1)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)  # (e, d)

        # Scale each relative vector by the scalar predicted from its message, (e, 3)
        pos_diff = pos_diff * self.mlp_pos(msg)
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        r"""Aggregate feature and position messages arriving at each node.

        Feature messages use the configured aggregation ('add' by default);
        position messages are always averaged.

        Args:
            inputs: tuple of (e, d) feature messages and (e, 3) position messages
            index: (e,) - index of the node each message is aggregated into
            dim_size: (int, optional) - number of nodes n. Forwarding it to
                scatter() keeps the outputs at n rows even when the
                highest-indexed nodes receive no message; without it update()
                would fail on the shape mismatch for such graphs.

        Returns:
            msg_aggr: (n, d) - aggregated feature messages
            pos_aggr: (n, 3) - mean position messages
        """
        msgs, pos_diffs = inputs
        msg_aggr = scatter(
            msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr
        )
        pos_aggr = scatter(
            pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean"
        )
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        r"""Combine the aggregated messages with the current node state.

        Args:
            aggr_out: tuple (msg_aggr, pos_aggr) returned by aggregate()
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates

        Returns:
            upd_out: (n, d) - node features passed through MLP \phi
            upd_pos: (n, 3) - coordinates shifted by the aggregated position message
        """
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
class FinalMPNNModel(MPNNModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
        r"""Docking model predicting a 3x3 matrix and a translation for the ligand.

        Receptor and ligand coordinates are embedded, refined by a shared stack
        of EquivariantMPNNLayer modules, fused by cross-attention and mapped to
        nine matrix entries plus three translation components.

        Args:
            num_layers: (int) - number of message passing layers L
            emb_dim: (int) - hidden dimension d
            in_dim: (int) - initial node feature dimension d_n
            num_heads: (int) - number of attention heads
        """
        # The parent constructor runs with its defaults; lin_in and convs are
        # then replaced below, while the parent's pool and lin_pred stay registered.
        super().__init__()

        # Input projection, d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)
        self.equiv_layer = EquivariantMPNNLayer(emb_dim=emb_dim)

        # L equivariant message passing layers
        self.convs = torch.nn.ModuleList(
            [EquivariantMPNNLayer(emb_dim, aggr='add') for _ in range(num_layers)]
        )

        # Receptor embeddings attend over ligand embeddings
        self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)

        # Heads for the 3x3 matrix (flattened) and the translation vector
        self.fc_rotation = nn.Linear(emb_dim, 9)
        self.fc_translation = nn.Linear(emb_dim, 3)

    def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
        r"""
        Processes a single receptor-ligand pair.

        Args:
            receptor: receptor coordinates (presumably (num_receptor_atoms, 3) — TODO confirm)
            ligand: ligand coordinates (presumably (num_ligand_atoms, 3) — TODO confirm)
            receptor_edge_index: edge list of the receptor graph
            ligand_edge_index: edge list of the ligand graph

        Returns:
            rotation_matrix: output of fc_rotation reshaped to (-1, 3, 3)
            translation_vector: output of fc_translation
        """
        # Coordinates double as initial features (after projection) and positions.
        rec_feat, rec_pos = self.lin_in(receptor), receptor
        lig_feat, lig_pos = self.lin_in(ligand), ligand

        # The same layer instance processes the receptor first, then the ligand.
        for mp_layer in self.convs:
            rec_feat, rec_pos = mp_layer(rec_feat, rec_pos, receptor_edge_index)
            lig_feat, lig_pos = mp_layer(lig_feat, lig_pos, ligand_edge_index)

        # Query: receptor embeddings; key and value: ligand embeddings.
        attended, _ = self.cross_attention(rec_feat, lig_feat, lig_feat)

        # Average over the first dimension before the two prediction heads.
        pooled = attended.mean(dim=0)
        rotation_matrix = self.fc_rotation(pooled).view(-1, 3, 3)
        translation_vector = self.fc_translation(pooled)
        return rotation_matrix, translation_vector

    def forward(self, data):
        r"""
        The main forward pass of the model.

        Args:
            data: heterogeneous graph with 'receptor' and 'ligand' stores, each
                providing 'pos' and 'edge_index'.

        Returns:
            ligands: ligand positions multiplied by the predicted matrix and
                shifted by the predicted translation.
        """
        rec_store, lig_store = data['receptor'], data['ligand']
        rotation_matrix, translation_vector = self.naive_single(
            rec_store['pos'],
            lig_store['pos'],
            rec_store['edge_index'],
            lig_store['edge_index'],
        )
        ligands = data['ligand']['pos'] @ rotation_matrix + translation_vector
        return ligands
| class FinalMPNNModelight(pl.LightningModule): | |
| def __init__(self, num_layers=4, emb_dim=32, in_dim=3, num_heads=1, lr=1e-4): | |
| super().__init__() | |
| self.lin_in = nn.Linear(in_dim, emb_dim) | |
| self.convs = nn.ModuleList([EquivariantMPNNLayer(emb_dim, aggr='add') for _ in range(num_layers)]) | |
| self.cross_attention = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True) | |
| self.fc_rotation = nn.Linear(emb_dim, 9) | |
| self.fc_translation = nn.Linear(emb_dim, 3) | |
| self.lr = lr | |
| def naive_single(self, receptor, ligand, receptor_edge_index, ligand_edge_index): | |
| h_receptor = self.lin_in(receptor) | |
| h_ligand = self.lin_in(ligand) | |
| pos_receptor, pos_ligand = receptor, ligand | |
| for layer in self.convs: | |
| h_receptor, pos_receptor = layer(h_receptor, pos_receptor, receptor_edge_index) | |
| h_ligand, pos_ligand = layer(h_ligand, pos_ligand, ligand_edge_index) | |
| attn_output, _ = self.cross_attention(h_receptor, h_ligand, h_ligand) | |
| rotation_matrix = self.fc_rotation(attn_output.mean(dim=0)).view(-1, 3, 3) | |
| translation_vector = self.fc_translation(attn_output.mean(dim=0)) | |
| return rotation_matrix, translation_vector | |
| def forward(self, data): | |
| device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | |
| receptor = data['receptor']['pos'].to(device) | |
| ligand = data['ligand']['pos'].to(device) | |
| receptor_edge_index = data['receptor', 'receptor']['edge_index'].to(device) | |
| ligand_edge_index = data['ligand', 'ligand']['edge_index'].to(device) | |
| rotation_matrix, translation_vector = self.naive_single(receptor, ligand, receptor_edge_index, ligand_edge_index) | |
| # transformed_ligand = torch.matmul(ligand ,rotation_matrix) + translation_vector | |
| return rotation_matrix, translation_vector | |
| def training_step(self, batch, batch_idx): | |
| ligand_pred = self(batch) | |
| ligand_true = batch['ligand']['y'] | |
| loss = F.mse_loss(ligand_pred.squeeze(0), ligand_true) | |
| self.log('train_loss', loss, batch_size=8) | |
| return loss | |
| def validation_step(self, batch, batch_idx): | |
| ligand_pred = self(batch) | |
| ligand_true = batch['ligand']['y'] | |
| loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true) | |
| self.log('val_loss', loss, prog_bar=True, batch_size=8) | |
| return loss | |
| def test_step(self, batch, batch_idx): | |
| ligand_pred = self(batch) | |
| ligand_true = batch['ligand']['y'] | |
| loss = F.l1_loss(ligand_pred.squeeze(0), ligand_true) | |
| self.log('test_loss', loss, prog_bar=True, batch_size=8) | |
| return loss | |
def configure_optimizers(self):
    """Build the Adam optimizer and a plateau-based learning-rate scheduler.

    Returns:
        A dict with the ``optimizer`` and an ``lr_scheduler`` entry whose
        scheduler is driven by the logged ``val_loss`` metric.
    """
    adam = torch.optim.Adam(self.parameters(), lr=self.lr)
    # Cut the learning rate by 10x once the monitored value has stopped
    # decreasing for 5 scheduler steps.
    plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
        adam, mode="min", factor=0.1, patience=5
    )
    lr_scheduler_config = {
        "scheduler": plateau,
        # Validation loss is the metric that drives the learning-rate decay.
        "monitor": "val_loss",
    }
    return {"optimizer": adam, "lr_scheduler": lr_scheduler_config}
# Checkpoint restored once at import time and reused by every inference request.
model_path = "./EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt"
model = FinalMPNNModelight.load_from_checkpoint(model_path)

# Single-device trainer; runs on the GPU when one is available, otherwise on CPU.
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    precision="bf16-mixed",
    fast_dev_run=False,
)

# Put the restored model in evaluation mode before serving predictions.
model.eval()
def predict(input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2):
    """Run the model on two uploaded structures and write the merged result.

    The model's predicted rotation and translation are applied to the
    coordinates read from ``input_protein_1``; the transformed structure is
    written to ``holo_ligand.pdb`` and merged with ``input_protein_2`` into
    ``output.pdb``.

    Args:
        input_seq_1: Sequence of protein 1 (unused).
        input_msa_1: MSA of protein 1 (unused).
        input_protein_1: PDB input for protein 1; this is the structure that
            gets transformed.
        input_seq_2: Sequence of protein 2 (unused).
        input_msa_2: MSA of protein 2 (unused).
        input_protein_2: PDB input for protein 2; merged into the output unchanged.

    Returns:
        A tuple ``(out_pdb, metrics_json, run_time)``: the value returned by
        ``merge_pdb_files``, a JSON string of metrics, and the elapsed
        wall-clock time in seconds.

    Raises:
        gr.Error: If either PDB input is missing.
    """
    # Fail with a readable UI message instead of an obscure downstream error
    # when the button is pressed before both structures were provided.
    if input_protein_1 is None or input_protein_2 is None:
        raise gr.Error("Both input PDB structures are required to run inference.")

    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = create_graph(input_protein_1, input_protein_2, k=10)
    with torch.no_grad():
        mat, vect = model(data)
    mat = mat.to(device)
    vect = vect.to(device)

    ligand1 = torch.tensor(
        extract_coordinates_from_pdb(input_protein_1), dtype=torch.float
    ).to(device)
    # Apply the predicted transform to the coordinates of protein 1.
    transformed_ligand = torch.matmul(ligand1, mat) + vect

    file1 = update_pdb_coordinates_from_tensor(
        input_protein_1, "holo_ligand.pdb", transformed_ligand
    )
    out_pdb = merge_pdb_files(file1, input_protein_2, "output.pdb")

    # NOTE: these metrics are hard-coded placeholder values, not model outputs.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}

    run_time = time.time() - start_time
    return out_pdb, json.dumps(metrics), run_time
# Gradio UI: two columns of inputs (one per protein), a run button, a bundled
# example, and three outputs (3D structure viewer, metrics JSON, runtime).
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN Model")
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            # This widget feeds `input_protein_1`, so it is labelled as protein 1.
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # One example row; its values map positionally onto the input components
    # listed right after it.
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Viewer representations: cartoon plus side-chain sticks for chains A and B,
    # with a distinct color per chain.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(
        predict,
        inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2],
        outputs=[out, metrics, run_time],
    )

app.launch()