Amir Hallaji committed
Commit 233e25e · 1 Parent(s): e2ba292

version 0.1.0

Files changed (1)
  1. models.py +95 -0
models.py ADDED
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x). Defined here but unused below."""
    def forward(self, x):
        return x * torch.sigmoid(x)


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)); used in the regressor head."""
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class ResidualInceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes=(1, 3), dropout=0.05):
        super().__init__()

        self.out_channels = out_channels
        num_branches = len(kernel_sizes)
        # Each branch gets an equal share of the output channels, so
        # out_channels should be divisible by len(kernel_sizes).
        branch_out_channels = out_channels // num_branches

        # Inception-style parallel branches: a 1x1 bottleneck followed by a
        # conv with a branch-specific kernel size (padding=k//2 preserves the
        # sequence length for odd k).
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels, in_channels, kernel_size=1),
                nn.BatchNorm1d(in_channels),
                nn.ReLU(),
                nn.Conv1d(in_channels, branch_out_channels, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(branch_out_channels),
                nn.ReLU(),
                nn.Dropout(dropout)
            ) for k in kernel_sizes
        ])

        # A 1x1 conv aligns channel counts for the residual connection.
        self.residual_adjust = nn.Conv1d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        self.relu = nn.ReLU()

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.branches]
        concatenated = torch.cat(branch_outputs, dim=1)  # (batch, out_channels, length)
        residual = self.residual_adjust(x)
        output = self.relu(concatenated + residual)
        return output


class AffinityPredictor(nn.Module):
    def __init__(self,
                 protein_model_name="facebook/esm2_t6_8M_UR50D",
                 molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
                 hidden_sizes=(1024, 768, 512, 256, 1),
                 inception_out_channels=256,  # currently unused; the blocks keep combined_dim channels
                 dropout=0.01):
        super().__init__()

        self.protein_model = AutoModel.from_pretrained(protein_model_name)
        self.molecule_model = AutoModel.from_pretrained(molecule_model_name)

        # Trade compute for memory in both encoders (the supported API;
        # setting config.gradient_checkpointing directly is deprecated).
        self.protein_model.gradient_checkpointing_enable()
        self.molecule_model.gradient_checkpointing_enable()

        prot_embedding_dim = self.protein_model.config.hidden_size
        mol_embedding_dim = self.molecule_model.config.hidden_size
        combined_dim = prot_embedding_dim + mol_embedding_dim

        self.inc1 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)
        self.inc2 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)

        # MLP regression head; no activation after the final 1-unit layer.
        layers = []
        input_dim = combined_dim  # after the inception blocks
        for output_dim in hidden_sizes:
            layers.append(nn.Linear(input_dim, output_dim))
            if output_dim != 1:
                layers.append(Mish())
            input_dim = output_dim
        self.regressor = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch):
        protein_input = {
            "input_ids": batch["protein_input_ids"],
            "attention_mask": batch["protein_attention_mask"]
        }
        molecule_input = {
            "input_ids": batch["molecule_input_ids"],
            "attention_mask": batch["molecule_attention_mask"]
        }
        # Mean-pool token embeddings (note: this mean also averages over padding positions).
        protein_embedding = self.protein_model(**protein_input).last_hidden_state.mean(dim=1)     # (batch_size, prot_embedding_dim)
        molecule_embedding = self.molecule_model(**molecule_input).last_hidden_state.mean(dim=1)  # (batch_size, mol_embedding_dim)
        combined_features = torch.cat((protein_embedding, molecule_embedding), dim=1).unsqueeze(2)  # (batch_size, combined_dim, 1)
        combined_features = self.inc1(combined_features)  # (batch_size, combined_dim, 1)
        combined_features = self.inc2(combined_features)
        combined_features = combined_features.squeeze(2)  # (batch_size, combined_dim)
        output = self.regressor(self.dropout(combined_features))  # (batch_size, 1)
        return output
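
For reference, a minimal usage sketch (not part of this commit). It assumes the two default checkpoints hard-coded in AffinityPredictor.__init__ and shows the batch keys that forward() expects; the protein sequence and SMILES string are arbitrary placeholders.

import torch
from transformers import AutoTokenizer
from models import AffinityPredictor

protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")

# Placeholder inputs: any amino-acid sequence and any SMILES string.
protein = protein_tokenizer(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],
                            return_tensors="pt", padding=True)
molecule = molecule_tokenizer(["CC(=O)Oc1ccccc1C(=O)O"],
                              return_tensors="pt", padding=True)

batch = {
    "protein_input_ids": protein["input_ids"],
    "protein_attention_mask": protein["attention_mask"],
    "molecule_input_ids": molecule["input_ids"],
    "molecule_attention_mask": molecule["attention_mask"],
}

# eval() matters here: BatchNorm1d inside the inception blocks sees a
# (batch, channels, 1) tensor and cannot compute batch statistics in
# training mode with batch size 1.
model = AffinityPredictor().eval()
with torch.no_grad():
    affinity = model(batch)
print(affinity.shape)  # torch.Size([1, 1])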