Spaces:
Sleeping
Sleeping
| # Copyright 2021 AlQuraishi Laboratory | |
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from dockformerpp.model.primitives import Linear, LayerNorm | |
| class PairTransition(nn.Module): | |
| """ | |
| Implements Algorithm 15. | |
| """ | |
| def __init__(self, c_z, n): | |
| """ | |
| Args: | |
| c_z: | |
| Pair transition channel dimension | |
| n: | |
| Factor by which c_z is multiplied to obtain hidden channel | |
| dimension | |
| """ | |
| super(PairTransition, self).__init__() | |
| self.c_z = c_z | |
| self.n = n | |
| self.layer_norm = LayerNorm(self.c_z) | |
| self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") | |
| self.relu = nn.ReLU() | |
| self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") | |
| def _transition(self, z, mask): | |
| # [*, N_res, N_res, C_z] | |
| z = self.layer_norm(z) | |
| # [*, N_res, N_res, C_hidden] | |
| z = self.linear_1(z) | |
| z = self.relu(z) | |
| # [*, N_res, N_res, C_z] | |
| z = self.linear_2(z) | |
| z = z * mask | |
| return z | |
| def forward(self, | |
| z: torch.Tensor, | |
| mask: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| z: | |
| [*, N_res, N_res, C_z] pair embedding | |
| Returns: | |
| [*, N_res, N_res, C_z] pair embedding update | |
| """ | |
| # DISCREPANCY: DeepMind forgets to apply the mask in this module. | |
| if mask is None: | |
| mask = z.new_ones(z.shape[:-1]) | |
| # [*, N_res, N_res, 1] | |
| mask = mask.unsqueeze(-1) | |
| z = self._transition(z=z, mask=mask) | |
| return z | |