# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import weakref

import torch
import torch.nn as nn

from dockformerpp.utils.tensor_utils import masked_mean
from dockformerpp.model.embedders import (
    StructureInputEmbedder,
    RecyclingEmbedder,
)
from dockformerpp.model.evoformer import EvoformerStack
from dockformerpp.model.heads import AuxiliaryHeads
from dockformerpp.model.structure_module import StructureModule
import dockformerpp.utils.residue_constants as residue_constants
from dockformerpp.utils.feats import (
    pseudo_beta_fn,
    atom14_to_atom37,
)
from dockformerpp.utils.tensor_utils import (
    add,
    tensor_tree_map,
)


class AlphaFold(nn.Module):
    """
    AlphaFold 2.
    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        self.config = config.model

        # Main trunk + structure module
        self.input_embedder = StructureInputEmbedder(
            **self.config["structure_input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **self.config["recycling_embedder"],
        )
        self.evoformer = EvoformerStack(
            **self.config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **self.config["structure_module"],
        )
        self.aux_heads = AuxiliaryHeads(
            self.config["heads"],
        )

    def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
        """
        Early stopping criterion based on the one used in AF2Complex:
        https://www.nature.com/articles/s41467-022-29394-2

        Args:
            prev_pos: Previous atom positions in atom37/14 representation
            next_pos: Current atom positions in atom37/14 representation
            mask: 1-D sequence mask
            eps: Epsilon used in square root calculation
        Returns:
            Whether to stop recycling early based on the desired tolerance.
        """

        def distances(points):
            """Compute all pairwise distances for a set of points."""
            d = points[..., None, :] - points[..., None, :, :]
            return torch.sqrt(torch.sum(d ** 2, dim=-1))

        if self.config.recycle_early_stop_tolerance < 0:
            return False

        ca_idx = residue_constants.atom_order['CA']
        sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
        mask = mask[..., None] * mask[..., None, :]
        sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
        diff = torch.sqrt(sq_diff + eps).item()
        return diff <= self.config.recycle_early_stop_tolerance
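
    # Example (hypothetical threshold): with
    # config.model.recycle_early_stop_tolerance = 0.5, recycling stops once the
    # masked RMS change in pairwise CA-CA distances between consecutive
    # iterations drops to 0.5 or below, e.g.
    #
    #     stop = model.tolerance_reached(prev_pos, next_pos, token_mask)
    #
    # where prev_pos / next_pos hold atom37/14 coordinates and token_mask is
    # the per-token mask.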

    def iteration(self, feats, prevs, _recycle=True):
        # Primary output dictionary
        outputs = {}

        # This needs to be done manually for DeepSpeed's sake
        dtype = next(self.parameters()).dtype
        for k in feats:
            if feats[k].dtype == torch.float32:
                feats[k] = feats[k].to(dtype=dtype)

        # Grab some data about the input
        batch_dims, n_total = feats["token_mask"].shape
        device = feats["token_mask"].device

        print("doing sample of size", feats["token_mask"].shape,
              feats["protein_r_mask"].sum(dim=1), feats["protein_l_mask"].sum(dim=1))

        # Controls whether the model uses in-place operations throughout
        # The dual condition accounts for activation checkpoints
        # inplace_safe = not (self.training or torch.is_grad_enabled())
        inplace_safe = False  # so we don't need attn_core_inplace_cuda

        # Prep some features
        token_mask = feats["token_mask"]
        pair_mask = token_mask[..., None] * token_mask[..., None, :]

        # Initialize the single and pair representations
        # m: [*, 1, n_total, C_m]
        # z: [*, n_total, n_total, C_z]
        m, z = self.input_embedder(
            feats["token_mask"],
            feats["protein_r_mask"],
            feats["protein_l_mask"],
            feats["target_feat"],
            feats["input_positions"],
            feats["residue_index"],
            feats["distogram_mask"],
            inplace_safe=inplace_safe,
        )

        # Unpack the recycling embeddings. Removing them from the list allows
        # them to be freed further down in this function, saving memory
        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

        # Initialize the recycling embeddings, if need be
        if None in [m_1_prev, z_prev, x_prev]:
            # [*, N, C_m]
            m_1_prev = m.new_zeros(
                (batch_dims, n_total, self.config.structure_input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (batch_dims, n_total, n_total, self.config.structure_input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N, 37, 3]
            x_prev = z.new_zeros(
                (batch_dims, n_total, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # shape == [1, n_total, 37, 3]
        pseudo_beta_or_lig_x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype)

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            pseudo_beta_or_lig_x_prev,
            inplace_safe=inplace_safe,
        )

        del pseudo_beta_or_lig_x_prev

        # [*, S_c, N, C_m]
        m += m_1_prev_emb

        # [*, N, N, C_z]
        z = add(z, z_prev_emb, inplace=inplace_safe)

        # Deletions like these become significant for inference with large N,
        # where they free unused tensors and remove references to others such
        # that they can be offloaded later
        del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb

        # Run single + pair embeddings through the trunk of the network
        # m: [*, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            single_mask=token_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            use_lma=self.globals.use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=self.config._mask_trans,
        )
| outputs["pair"] = z | |
| outputs["single"] = s | |
| del z | |
| # Predict 3D structure | |
| outputs["sm"] = self.structure_module( | |
| outputs, | |
| feats["aatype"], | |
| mask=token_mask.to(dtype=s.dtype), | |
| inplace_safe=inplace_safe, | |
| ) | |
| outputs["final_atom_positions"] = atom14_to_atom37( | |
| outputs["sm"]["positions"][-1], feats | |
| ) | |
| outputs["final_atom_mask"] = feats["atom37_atom_exists"] | |
| # Save embeddings for use during the next recycling iteration | |
| # [*, N, C_m] | |
| m_1_prev = m[..., 0, :, :] | |
| # [*, N, N, C_z] | |
| z_prev = outputs["pair"] | |
| # TODO bshor: early stop depends on is_multimer, but I don't think it must | |
| early_stop = False | |
| # if self.globals.is_multimer: | |
| # early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask) | |
| del x_prev | |
| # [*, N, 3] | |
| x_prev = outputs["final_atom_positions"] | |
| return outputs, m_1_prev, z_prev, x_prev, early_stop | |
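
    # Note: `prevs` is consumed positionally above: it must hold
    # [m_1_prev, z_prev, x_prev] in that order, matching the list rebuilt at
    # the end of each recycling cycle in forward() below.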

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "protein_target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "token_mask" ([*, N_token])
                        1-D token mask
                    "pair_mask" ([*, N_token, N_token])
                        2-D pair mask
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None
        prevs = [m_1_prev, z_prev, x_prev]

        is_grad_enabled = torch.is_grad_enabled()

        # Main recycling loop
        num_iters = batch["aatype"].shape[-1]
        early_stop = False
        num_recycles = 0
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == (num_iters - 1) or early_stop

            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    # Sidestep AMP bug (PyTorch issue #65766)
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
                    feats,
                    prevs,
                    _recycle=(num_iters > 1)
                )
                num_recycles += 1

                if not is_final_iter:
                    del outputs
                    prevs = [m_1_prev, z_prev, x_prev]
                    del m_1_prev, z_prev, x_prev
                else:
                    break

        outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)

        # Run the auxiliary heads, stripping the recycling dimension from the
        # batch properties they use
        outputs.update(
            self.aux_heads(
                outputs,
                batch["inter_pair_mask"][..., 0],
                batch["affinity_mask"][..., 0],
            )
        )

        return outputs
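
# Usage sketch (illustrative only): `config` and `single_cycle_feats` are
# assumed to come from dockformerpp's config and data pipeline, with the
# feature names expected by iteration()/forward() above. Each feature carries
# a trailing recycling dimension, e.g. built by stacking per-cycle copies:
#
#     import torch
#
#     num_recycles = 3
#     batch = {
#         k: torch.stack([v] * num_recycles, dim=-1)
#         for k, v in single_cycle_feats.items()
#     }
#     model = AlphaFold(config)
#     outputs = model(batch)
#     outputs["final_atom_positions"]  # [*, N, 37, 3]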