# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
from typing import Optional, Callable, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from scipy.stats import truncnorm

from dockformerpp.utils.kernel.attention_core import attention_core
from dockformerpp.utils.precision_utils import is_fp16_enabled
from dockformerpp.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)

# Suited for 40gb GPU
# DEFAULT_LMA_Q_CHUNK_SIZE = 1024
# DEFAULT_LMA_KV_CHUNK_SIZE = 4096

# Suited for 10gb GPU
DEFAULT_LMA_Q_CHUNK_SIZE = 64
DEFAULT_LMA_KV_CHUNK_SIZE = 256

def _prod(nums):
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    a = -2
    b = 2
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    size = _prod(shape)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))


def lecun_normal_init_(weights):
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
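

# Note: 0.541324854612918 == ln(e - 1), the inverse of softplus at 1, so weights
# set by ipa_point_weights_init_ evaluate to exactly 1 once a softplus is applied
# to them downstream.
#
# Sanity-check sketch for the fan-scaled truncated-normal initializers above;
# the 256 x 512 weight shape is an arbitrary assumption for illustration.
def _init_scaling_sanity_check():
    w = torch.empty(256, 512)       # (fan_out, fan_in) = (256, 512)
    lecun_normal_init_(w)
    lecun_std = w.std().item()      # ~= sqrt(1 / 512) ~= 0.044
    he_normal_init_(w)
    he_std = w.std().item()         # ~= sqrt(2 / 512) ~= 0.063
    return lecun_std, he_std
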

class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        precision=None
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    lecun_normal_init_(self.weight)
                elif init == "relu":
                    he_normal_init_(self.weight)
                elif init == "glorot":
                    glorot_uniform_init_(self.weight)
                elif init == "gating":
                    gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    normal_init_(self.weight)
                elif init == "final":
                    final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

        self.precision = precision

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.dtype
        if self.precision is not None:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
                return nn.functional.linear(input.to(dtype=self.precision),
                                            self.weight.to(dtype=self.precision),
                                            bias).to(dtype=d)

        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

        return nn.functional.linear(input, self.weight, self.bias)
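

# Usage sketch for the initialization options above (the dimensions and the
# gated wiring below are illustrative assumptions, not fixed by this file):
# "gating" starts the gate at a constant pre-sigmoid value of 1, and "final"
# zeroes the output projection so a residual update begins as a no-op.
def _example_linear_inits():
    proj = Linear(64, 128, init="relu")     # He init, for a ReLU-activated layer
    gate = Linear(64, 128, init="gating")   # weights = 0, bias = 1
    out = Linear(128, 64, init="final")     # weights = 0, bias = 0
    x = torch.randn(2, 10, 64)
    return out(torch.relu(proj(x)) * torch.sigmoid(gate(x)))
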

class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out


def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s


#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax_no_cast(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    return a
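

# Reference-check sketch for _attention (shapes are arbitrary assumptions): with
# queries pre-scaled by 1/sqrt(C_hidden), as Attention._prep_qkv does below, the
# result should match an explicit softmax(QK^T + bias)V.
def _attention_reference_check():
    q = torch.randn(1, 4, 8, 16) / math.sqrt(16)   # [*, H, Q, C_hidden], pre-scaled
    k = torch.randn(1, 4, 8, 16)                   # [*, H, K, C_hidden]
    v = torch.randn(1, 4, 8, 16)
    bias = torch.randn(1, 1, 8, 8)                 # broadcasts over the head dimension
    ref = torch.softmax(q @ k.transpose(-1, -2) + bias, dim=-1) @ v
    return torch.allclose(_attention(q, k, v, [bias]), ref, atol=1e-6)
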

class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        apply_scale: bool = True
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        if apply_scale:
            q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self,
        o: torch.Tensor,
        q_x: torch.Tensor
    ) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel.
                This should be the default choice for most. If none of the
                "use_<...>" flags are True, a stock PyTorch implementation
                is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If
                none of the "use_<...>" flags are True, a stock PyTorch
                implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError(
                "If use_lma is specified, lma_q_chunk_size and "
                "lma_kv_chunk_size must be provided"
            )

        attn_options = [use_memory_efficient_kernel, use_lma]
        if sum(attn_options) > 1:
            raise ValueError(
                "Choose at most one alternative attention algorithm"
            )

        if biases is None:
            biases = []

        q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True)

        if is_fp16_enabled():
            use_memory_efficient_kernel = False

        if use_memory_efficient_kernel:
            if len(biases) > 2:
                raise ValueError(
                    "If use_memory_efficient_kernel is True, you may only "
                    "provide up to two bias terms"
                )
            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
            o = o.transpose(-2, -3)
        elif use_lma:
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]
            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
            o = o.transpose(-2, -3)
        else:
            o = _attention(q, k, v, biases)
            o = o.transpose(-2, -3)

        o = self._wrap_up(o, q_x)

        return o
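

# Usage sketch for Attention (all sizes below are arbitrary assumptions); a mask
# is passed as an additive bias of 0 / -1e9 that broadcasts to [*, H, Q, K].
def _example_attention_module():
    attn = Attention(c_q=32, c_k=32, c_v=32, c_hidden=16, no_heads=4)
    q_x = torch.randn(1, 10, 32)              # [*, Q, C_q]
    kv_x = torch.randn(1, 12, 32)             # [*, K, C_k]
    mask = torch.ones(1, 10, 12)              # 1 = attend, 0 = mask out
    bias = 1e9 * (mask - 1)[..., None, :, :]  # [*, 1, Q, K]
    return attn(q_x, kv_x, biases=[bias])     # [1, 10, 32]
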

class GlobalAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        use_lma: bool = False,
    ) -> torch.Tensor:
        # [*, N_res, C_in]
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + self.eps
        )

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden ** (-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        bias = (self.inf * (mask - 1))[..., :, None, :]
        if not use_lma:
            # [*, N_res, H, N_seq]
            a = torch.matmul(
                q,
                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
            )
            a += bias
            a = softmax_no_cast(a)

            # [*, N_res, H, C_hidden]
            o = torch.matmul(
                a,
                v,
            )
        else:
            o = _lma(
                q,
                k,
                v,
                [bias],
                DEFAULT_LMA_Q_CHUNK_SIZE,
                DEFAULT_LMA_KV_CHUNK_SIZE
            )

        # [*, N_res, C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, N_res, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, C_in]
        m = self.linear_o(o)

        return m
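

# Usage sketch for GlobalAttention; the [batch, n_res, n_seq, c_in] input layout
# and all sizes are assumptions for illustration. Attention is taken over the
# second-to-last axis of m, using a single mask-averaged query per position.
def _example_global_attention():
    ga = GlobalAttention(c_in=32, c_hidden=16, no_heads=4, inf=1e9, eps=1e-10)
    m = torch.randn(1, 10, 6, 32)             # [batch, n_res, n_seq, c_in]
    mask = torch.ones(1, 10, 6)               # 1 = keep, 0 = mask out
    return ga(m, mask)                        # same shape as m
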

def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    no_q, no_kv = q.shape[-2], k.shape[-2]

    # [*, H, Q, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
        chunk_weights = chunk_weights * max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

    return o
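

# Consistency sketch (arbitrary shapes): the chunked low-memory path should agree
# with the direct _attention path up to floating-point error, provided the biases
# are already expanded over the last two dims to [*, H, Q, K], as Attention.forward
# does before calling _lma.
def _lma_consistency_check():
    q = torch.randn(1, 4, 128, 16) / math.sqrt(16)
    k = torch.randn(1, 4, 256, 16)
    v = torch.randn(1, 4, 256, 16)
    bias = torch.randn(1, 4, 128, 256)
    direct = _attention(q, k, v, [bias])
    chunked = _lma(q, k, v, [bias], DEFAULT_LMA_Q_CHUNK_SIZE, DEFAULT_LMA_KV_CHUNK_SIZE)
    return torch.allclose(direct, chunked, atol=1e-5)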