fix inference_app.py
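
Every hunk below makes the same one-character change: the opening `"""` of a docstring gains an `r` prefix, turning it into a raw string literal. A minimal sketch of the effect (hypothetical code, not from inference_app.py), assuming a docstring that contains backslashes such as LaTeX-style notation: recent CPython warns about invalid escape sequences in ordinary string literals (a DeprecationWarning, upgraded to SyntaxWarning in 3.12), whereas a raw docstring keeps the backslashes verbatim and compiles silently.

# Hypothetical illustration only; not code from inference_app.py.
def plain():
    """Distance is \sqrt{\sum_i (x_i - y_i)^2}."""   # \s is an invalid escape -> warning at compile time

def raw():
    r"""Distance is \sqrt{\sum_i (x_i - y_i)^2}."""  # raw string: backslashes kept verbatim, no warning

# Both docstrings end up containing the backslashes, but only the plain
# version triggers the escape-sequence warning when the module is compiled.
print(plain.__doc__)
print(raw.__doc__)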

inference_app.py  CHANGED  +17 -17
@@ -106,7 +106,7 @@ from Bio import PDB
 from Bio.PDB.PDBIO import PDBIO

 def extract_coordinates_from_pdb(filename):
-    """
+    r"""
     Extracts atom coordinates from a PDB file and returns them as a list of tuples.
     Each tuple contains (x, y, z) coordinates of an atom.
     """
@@ -288,7 +288,7 @@ class PairedPDB(HeteroData): # type: ignore

 #create_graph takes inputs apo_ligand, apo_residue and paired holo as pdb3(ground truth).
 def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/test_out.pdb', k=5):
-    """
+    r"""
     Create a heterogeneous graph from two PDB files, with the ligand and receptor
     as separate nodes, and their respective features and edges.

@@ -336,7 +336,7 @@ def create_graph(pdb1, pdb2, pdb3='/home/sukanya/iitm_bisect_pinder_submission/t


 def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordinates_tensor):
-    """
+    r"""
     Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.

     Parameters:
@@ -400,7 +400,7 @@ def update_pdb_coordinates_from_tensor(input_filename, output_filename, coordina
     return output_filename

 def merge_pdb_files(file1, file2, output_file):
-    """
+    r"""
     Merges two PDB files by concatenating them without altering their contents.

     Parameters:
@@ -421,7 +421,7 @@ def merge_pdb_files(file1, file2, output_file):

 class MPNNLayer(MessagePassing):
     def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
-        """Message Passing Neural Network Layer
+        r"""Message Passing Neural Network Layer

         Args:
             emb_dim: (int) - hidden dimension d
@@ -451,7 +451,7 @@ class MPNNLayer(MessagePassing):
         )

     def forward(self, h, edge_index, edge_attr):
-        """
+        r"""
         The forward pass updates node features h via one round of message passing.

         As our MPNNLayer class inherits from the PyG MessagePassing parent class,
@@ -474,7 +474,7 @@ class MPNNLayer(MessagePassing):
         return out

     def message(self, h_i, h_j, edge_attr):
-        """Step (1) Message
+        r"""Step (1) Message

         The message() function constructs messages from source nodes j
         to destination nodes i for each edge (i, j) in edge_index.
@@ -502,7 +502,7 @@ class MPNNLayer(MessagePassing):
         return self.mlp_msg(msg)

     def aggregate(self, inputs, index):
-        """Step (2) Aggregate
+        r"""Step (2) Aggregate

         The aggregate function aggregates the messages from neighboring nodes,
         according to the chosen aggregation function ('sum' by default).
@@ -517,7 +517,7 @@ class MPNNLayer(MessagePassing):
         return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)

     def update(self, aggr_out, h):
-        """
+        r"""
         Step (3) Update

         The update() function computes the final node features by combining the
@@ -541,7 +541,7 @@ class MPNNLayer(MessagePassing):
         return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')
 class MPNNModel(Module):
     def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
-        """Message Passing Neural Network model for graph property prediction
+        r"""Message Passing Neural Network model for graph property prediction

         Args:
             num_layers: (int) - number of message passing layers L
@@ -570,7 +570,7 @@ class MPNNModel(Module):
         self.lin_pred = Linear(emb_dim, out_dim)

     def forward(self, data):
-        """
+        r"""
         Args:
             data: (PyG.Data) - batch of PyG graphs

@@ -592,7 +592,7 @@ class MPNNModel(Module):

 class EquivariantMPNNLayer(MessagePassing):
     def __init__(self, emb_dim=64, aggr='add'):
-        """Message Passing Neural Network Layer
+        r"""Message Passing Neural Network Layer

         This layer is equivariant to 3D rotations and translations.

@@ -630,7 +630,7 @@ class EquivariantMPNNLayer(MessagePassing):
         # ===========================================

     def forward(self, h, pos, edge_index):
-        """
+        r"""
         The forward pass updates node features h via one round of message passing.

         Args:
@@ -667,7 +667,7 @@ class EquivariantMPNNLayer(MessagePassing):
         # ...
         #
     def aggregate(self, inputs, index):
-        """The aggregate function aggregates the messages from neighboring nodes,
+        r"""The aggregate function aggregates the messages from neighboring nodes,
         according to the chosen aggregation function ('sum' by default).

         Args:
@@ -701,7 +701,7 @@ class EquivariantMPNNLayer(MessagePassing):

 class FinalMPNNModel(MPNNModel):
     def __init__(self, num_layers=4, emb_dim=64, in_dim=3, num_heads = 2):
-        """Message Passing Neural Network model for graph property prediction
+        r"""Message Passing Neural Network model for graph property prediction

         This model uses both node features and coordinates as inputs, and
         is invariant to 3D rotations and translations (the constituent MPNN layers
@@ -734,7 +734,7 @@ class FinalMPNNModel(MPNNModel):
         # self.pool = global_mean_pool

     def naive_single(self, receptor, ligand , receptor_edge_index , ligand_edge_index):
-        """
+        r"""
         Processes a single receptor-ligand pair.

         Args:
@@ -775,7 +775,7 @@ class FinalMPNNModel(MPNNModel):


     def forward(self, data):
-        """
+        r"""
         The main forward pass of the model.

         Args: