# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import importlib.util
import math
import sys
from operator import mul

import torch
import torch.nn as nn
from typing import Optional, Tuple, Sequence, Union

from dockformerpp.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from dockformerpp.utils.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from dockformerpp.utils.geometry.quat_rigid import QuatRigid
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformerpp.utils.geometry.vector import Vec3Array, square_euclidean_distance
from dockformerpp.utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
from dockformerpp.utils.precision_utils import is_fp16_enabled
from dockformerpp.utils.rigid_utils import Rotation, Rigid
from dockformerpp.utils.tensor_utils import (
    dict_multimap,
    permute_final_dims,
    flatten_final_dims,
)

attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
attn_core_inplace_cuda = None
if attn_core_is_installed:
    attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
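# attn_core_inplace_cuda is an optional CUDA extension that performs the IPA
# attention softmax in place to save memory. It is only used when
# InvariantPointAttention.forward is called with inplace_safe=True; otherwise
# the standard nn.Softmax path below is taken and the extension is never
# touched.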
class AngleResnetBlock(nn.Module):
    def __init__(self, c_hidden):
        """
        Args:
            c_hidden:
                Hidden channel dimension
        """
        super(AngleResnetBlock, self).__init__()

        self.c_hidden = c_hidden

        self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s_initial = a

        a = self.relu(a)
        a = self.linear_1(a)
        a = self.relu(a)
        a = self.linear_2(a)

        return a + s_initial

class AngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            no_blocks:
                Number of resnet blocks
            no_angles:
                Number of torsion angles to generate
            epsilon:
                Small constant for normalization
        """
        super(AngleResnet, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_blocks = no_blocks
        self.no_angles = no_angles
        self.eps = epsilon

        self.linear_in = Linear(self.c_in, self.c_hidden)
        self.linear_initial = Linear(self.c_in, self.c_hidden)

        self.layers = nn.ModuleList()
        for _ in range(self.no_blocks):
            layer = AngleResnetBlock(c_hidden=self.c_hidden)
            self.layers.append(layer)

        self.linear_out = Linear(self.c_hidden, self.no_angles * 2)

        self.relu = nn.ReLU()

    def forward(
        self, s: torch.Tensor, s_initial: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_in] single embedding
            s_initial:
                [*, C_in] single embedding as of the start of the
                StructureModule
        Returns:
            [*, no_angles, 2] predicted angles
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. For maximal compatibility with
        # the pretrained weights, I'm going with the source.

        # [*, C_hidden]
        s_initial = self.relu(s_initial)
        s_initial = self.linear_initial(s_initial)
        s = self.relu(s)
        s = self.linear_in(s)
        s = s + s_initial

        for l in self.layers:
            s = l(s)

        s = self.relu(s)

        # [*, no_angles * 2]
        s = self.linear_out(s)

        # [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        s = s / norm_denom

        return unnormalized_s, s

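# Illustrative shapes for AngleResnet; the channel sizes below are assumed for
# the example and are not mandated by this module:
#
#   resnet = AngleResnet(c_in=384, c_hidden=128, no_blocks=2, no_angles=7,
#                        epsilon=1e-8)
#   s = s_initial = torch.randn(1, 256, 384)         # [*, N_res, C_in]
#   unnormalized, normalized = resnet(s, s_initial)  # both [1, 256, 7, 2]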
class PointProjection(nn.Module):
    def __init__(self,
        c_hidden: int,
        num_points: int,
        no_heads: int,
        return_local_points: bool = False,
    ):
        super().__init__()
        self.return_local_points = return_local_points
        self.no_heads = no_heads
        self.num_points = num_points

        # Multimer requires this to be run with fp32 precision during training
        precision = None
        self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)

    def forward(self,
        activations: torch.Tensor,
        rigids: Union[Rigid, Rigid3Array],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # TODO: Needs to run in high precision during training
        points_local = self.linear(activations)
        out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)

        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )

        points_local = torch.stack(points_local, dim=-1).view(out_shape)

        points_global = rigids[..., None, None].apply(points_local)

        if self.return_local_points:
            return points_global, points_local

        return points_global

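# PointProjection predicts no_heads * num_points 3D points per residue in the
# local frame and returns them as [*, N_res, no_heads, num_points, 3] mapped
# into the global frame via rigids.apply, so the points move rigidly with the
# predicted backbone frames.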
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
        """
        super(InvariantPointAttention, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc, bias=True)
        self.linear_q_points = PointProjection(
            self.c_s,
            self.no_qk_points,
            self.no_heads,
        )

        self.linear_kv = Linear(self.c_s, 2 * hc)
        self.linear_kv_points = PointProjection(
            self.c_s,
            self.no_qk_points + self.no_v_points,
            self.no_heads,
        )

        self.linear_b = Linear(self.c_z, self.no_heads)

        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()
    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
        r: Union[Rigid, Rigid3Array],
        mask: torch.Tensor,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
        Returns:
            [*, N_res, C_s] single representation update
        """
        z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, P_q, 3]
        q_pts = self.linear_q_points(s, r)

        # As in the official source, k/v and their point analogues are produced
        # by shared projections and split afterwards.
        # [*, N_res, H * 2 * C_hidden]
        kv = self.linear_kv(s)

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        kv_pts = self.linear_kv_points(s, r)

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        # [*, H, N_res, N_res]
        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        if inplace_safe:
            pt_att *= pt_att
        else:
            pt_att = pt_att ** 2

        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        if inplace_safe:
            pt_att *= head_weights
        else:
            pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)

        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        if inplace_safe:
            a += pt_att
            del pt_att
            a += square_mask.unsqueeze(-3)
            # in-place softmax
            attn_core_inplace_cuda.forward_(
                a,
                reduce(mul, a.shape[:-1]),
                a.shape[-1],
            )
        else:
            a = a + pt_att
            a = a + square_mask.unsqueeze(-3)
            a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3).to(dtype=a.dtype)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        if inplace_safe:
            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
            o_pt = [
                torch.matmul(a, v.to(a.dtype))
                for v in torch.unbind(v_pts, dim=-3)
            ]
            o_pt = torch.stack(o_pt, dim=-3)
        else:
            o_pt = torch.sum(
                (
                    a[..., None, :, :, None]
                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
                ),
                dim=-2,
            )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(
            torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
        )

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
        o_pt = torch.unbind(o_pt, dim=-1)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                (o, *o_pt, o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s

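# Minimal shape sketch for InvariantPointAttention; the hyperparameters below
# are illustrative assumptions, not values fixed by this module:
#
#   ipa = InvariantPointAttention(c_s=384, c_z=128, c_hidden=16, no_heads=12,
#                                 no_qk_points=4, no_v_points=8)
#   s = torch.randn(1, 256, 384)         # [*, N_res, C_s]
#   z = torch.randn(1, 256, 256, 128)    # [*, N_res, N_res, C_z]
#   r = Rigid.identity((1, 256), s.dtype, s.device, fmt="quat")
#   mask = s.new_ones(1, 256)
#   s_update = ipa(s, z, r, mask)        # [1, 256, 384]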
class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, c_s):
        """
        Args:
            c_s:
                Single representation channel dimension
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s

        self.linear = Linear(self.c_s, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # [*, N_res, 6]
        update = self.linear(s)

        return update

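# A note on BackboneUpdate: the 6 output channels parameterize a rigid-body
# update as in Algorithm 23 of the AlphaFold 2 supplement. The first three are
# the non-trivial components (b, c, d) of a quaternion (1, b, c, d) and the
# last three are a translation; Rigid.compose_q_update_vec applies this update
# to the current frames.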
class StructureModuleTransitionLayer(nn.Module):
    def __init__(self, c):
        super(StructureModuleTransitionLayer, self).__init__()

        self.c = c

        self.linear_1 = Linear(self.c, self.c, init="relu")
        self.linear_2 = Linear(self.c, self.c, init="relu")
        self.linear_3 = Linear(self.c, self.c, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        s_initial = s
        s = self.linear_1(s)
        s = self.relu(s)
        s = self.linear_2(s)
        s = self.relu(s)
        s = self.linear_3(s)

        s = s + s_initial

        return s


class StructureModuleTransition(nn.Module):
    def __init__(self, c, num_layers, dropout_rate):
        super(StructureModuleTransition, self).__init__()

        self.c = c
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList()
        for _ in range(self.num_layers):
            l = StructureModuleTransitionLayer(self.c)
            self.layers.append(l)

        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = LayerNorm(self.c)

    def forward(self, s):
        for l in self.layers:
            s = l(s)

        s = self.dropout(s)
        s = self.layer_norm(s)

        return s

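# StructureModuleTransition implements the single-representation transition of
# the structure module (Alg. 20, lines 8-9): a stack of residual MLP layers
# followed by dropout and layer normalization.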
class StructureModule(nn.Module):
    def __init__(
        self,
        c_s,
        c_z,
        c_ipa,
        c_resnet,
        no_heads_ipa,
        no_qk_points,
        no_v_points,
        dropout_rate,
        no_blocks,
        no_transition_layers,
        no_resnet_blocks,
        no_angles,
        trans_scale_factor,
        epsilon,
        inf,
        **kwargs,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_ipa:
                IPA hidden channel dimension
            c_resnet:
                Angle resnet (Alg. 20 lines 11-14) hidden channel dimension
            no_heads_ipa:
                Number of IPA heads
            no_qk_points:
                Number of query/key points to generate during IPA
            no_v_points:
                Number of value points to generate during IPA
            dropout_rate:
                Dropout rate used throughout the layer
            no_blocks:
                Number of structure module blocks
            no_transition_layers:
                Number of layers in the single representation transition
                (Alg. 20 lines 8-9)
            no_resnet_blocks:
                Number of blocks in the angle resnet
            no_angles:
                Number of angles to generate in the angle resnet
            trans_scale_factor:
                Factor by which the predicted backbone translations are scaled
            epsilon:
                Small number used in angle resnet normalization
            inf:
                Large number used for attention masking
        """
        super(StructureModule, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_ipa = c_ipa
        self.c_resnet = c_resnet
        self.no_heads_ipa = no_heads_ipa
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.dropout_rate = dropout_rate
        self.no_blocks = no_blocks
        self.no_transition_layers = no_transition_layers
        self.no_resnet_blocks = no_resnet_blocks
        self.no_angles = no_angles
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)

        self.ipa = InvariantPointAttention(
            self.c_s,
            self.c_z,
            self.c_ipa,
            self.no_heads_ipa,
            self.no_qk_points,
            self.no_v_points,
            inf=self.inf,
            eps=self.epsilon,
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm_ipa = LayerNorm(self.c_s)

        self.transition = StructureModuleTransition(
            self.c_s,
            self.no_transition_layers,
            self.dropout_rate,
        )

        self.bb_update = BackboneUpdate(self.c_s)

        self.angle_resnet = AngleResnet(
            self.c_s,
            self.c_resnet,
            self.no_resnet_blocks,
            self.no_angles,
            self.epsilon,
        )
    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
        Returns:
            A dictionary of outputs with keys "frames", "sidechain_frames",
            "unnormalized_angles", "angles", "positions", and "states", each
            stacked over the structure module blocks, plus "single", the final
            single representation
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )

        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N_res, 6] vector of translations and rotations
            bb_update_output = self.bb_update(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(bb_update_output)

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(
                    rot_mats=rigids.get_rots().get_rot_mats(),
                    quats=None
                ),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(
                self.trans_scale_factor
            )

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            rigids = rigids.stop_rot_gradient()

        del z

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

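    # The residue-constant tensors below are registered lazily, on first use,
    # so that they are created with the same dtype and device as the model's
    # activations; persistent=False keeps them out of the state dict.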
    def _init_residue_constants(self, float_dtype, device):
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
        self, r, f  # [*, N, 8]  # [*, N]
    ):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.dtype, r.device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
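

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. All hyperparameters below are
    # illustrative assumptions chosen for a quick CPU run; they are not the
    # values used by the dockformerpp configs.
    batch, n_res, c_s, c_z = 1, 16, 32, 16

    sm = StructureModule(
        c_s=c_s,
        c_z=c_z,
        c_ipa=8,
        c_resnet=8,
        no_heads_ipa=4,
        no_qk_points=4,
        no_v_points=8,
        dropout_rate=0.1,
        no_blocks=2,
        no_transition_layers=1,
        no_resnet_blocks=2,
        no_angles=7,
        trans_scale_factor=10,
        epsilon=1e-8,
        inf=1e5,
    )
    sm.eval()

    evoformer_output_dict = {
        "single": torch.randn(batch, n_res, c_s),
        "pair": torch.randn(batch, n_res, n_res, c_z),
    }
    # Dummy all-alanine sequence (restype index 0)
    aatype = torch.zeros(batch, n_res, dtype=torch.long)

    with torch.no_grad():
        out = sm(evoformer_output_dict, aatype)

    # Expected shapes: "frames" [no_blocks, batch, n_res, 7],
    # "positions" [no_blocks, batch, n_res, 14, 3], "single" [batch, n_res, c_s]
    print({k: tuple(v.shape) for k, v in out.items()})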