import torch
import torch.nn as nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
    def __init__(self, config):
        """
        Initializes the SelfAttention module.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.
        """
        super().__init__()
        self._validate_config(config)
        self._initialize_parameters(config)
    def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
        """
        Empties the key-value cache.

        Args:
            batch_size: The batch size.
            kv_cache_maxlen: The maximum length of the key-value cache.
            dtype: The data type of the cache.

        Raises:
            RuntimeError: If trying to empty the KV cache when it is disabled.
        """
        if not self.kv_cache_enabled:
            raise RuntimeError("Trying to empty KV cache when it is disabled")
        # register so that the cache moves devices along with the module
        # TODO: get rid of re-allocation.
        self.register_buffer(
            "kv_cache",
            torch.zeros(
                2,
                batch_size,
                kv_cache_maxlen,
                self.n_head,
                self.n_embd // self.n_head,
                dtype=dtype,
                device=self.c_attn.weight.device,
            ),
            persistent=False,
        )
        self.kv_cache_first_empty_index = 0
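    # Usage sketch (hypothetical numbers, assuming n_embd=64 and n_head=4):
    #   attn.kv_cache_enabled = True
    #   attn.empty_kv_cache(batch_size=2, kv_cache_maxlen=128, dtype=torch.float32)
    #   attn.kv_cache.shape  # (2, 2, 128, 4, 16): (k/v, B, maxlen, n_head, head_size)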
    def _initialize_parameters(self, config):
        """
        Initializes the parameters of the SelfAttention module.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.
        """
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.causal = config.causal
        self.attn_kernel_type = config.attn_kernel_type
        self.attn_dropout = nn.Dropout(config.dropout)
        self.kv_cache_enabled = False
    def _validate_config(self, config):
        """
        Validates the configuration parameters.

        Args:
            config: An object containing the configuration parameters for the SelfAttention module.

        Raises:
            AssertionError: If the embedding dimension is not divisible by the number of heads.
        """
        assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
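    # Note (an assumption read off the attribute accesses in this class, not stated elsewhere):
    # beyond the divisibility check above, the config must expose n_embd, n_head, bias,
    # dropout, causal, and attn_kernel_type for _initialize_parameters to succeed.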
    def _update_kv_cache(self, q, k, v):
        """
        Updates the key-value cache.

        Args:
            q: The query tensor.
            k: The key tensor.
            v: The value tensor.

        Returns:
            The updated key and value tensors, read back from the cache.

        Raises:
            AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
        """
        q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
        if self.kv_cache_first_empty_index == 0:
            # "prefill": the whole prompt is processed at once
            assert q_time == k_time == v_time, "q, k, and v must cover the same time steps during prefill"
        else:
            # incremental decoding: one new token at a time
            assert (
                q_time == 1
            ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
        self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
        self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
        self.kv_cache_first_empty_index += q_time
        k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
        v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
        return k, v
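    # Decoding pattern this supports (a sketch, inferred from the asserts above):
    #   1. prefill: one forward pass over the whole prompt while kv_cache_first_empty_index == 0,
    #      writing every prompt k/v pair into the cache
    #   2. decode: repeated forward passes with a single new token, each appending one k/v pair
    #      and returning the full cached history for attention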
    def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
        """
        Performs attention using the torch.nn.functional.scaled_dot_product_attention function.

        Args:
            c_x: The input tensor.

        Returns:
            The output tensor.
        """
        q, k, v = c_x.split(1, dim=2)  # q, k, v of shape (B, T, 1, nh, hs)
        q = q.squeeze(2)  # (B, T, nh, hs)
        k = k.squeeze(2)  # (B, T, nh, hs)
        v = v.squeeze(2)  # (B, T, nh, hs)
        # If kv-caching and causal, we need a causal mask for the "prefill" stage and no mask
        # for the subsequent one-time-step calls. Compute this before updating the kv cache so
        # we read the pre-update value of kv_cache_first_empty_index.
        is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
        if self.kv_cache_enabled:
            k, v = self._update_kv_cache(q, k, v)
        q = q.transpose(1, 2)  # (B, nh, T, hs)
        k = k.transpose(1, 2)  # (B, nh, T_k, hs), where T_k is the cached length when kv-caching
        v = v.transpose(1, 2)  # (B, nh, T_k, hs)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=is_causal_attn_mask,
        ).transpose(1, 2)  # (B, nh, T, hs) -> (B, T, nh, hs)
        return y
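    # Hedged note on masking: with is_causal=True, scaled_dot_product_attention applies a
    # lower-triangular mask, which matches the prefill case where queries and keys span the
    # same T positions. During single-token decoding the lone query should attend to every
    # cached key, so no mask is passed.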
    def forward(self, x):
        """
        Performs the forward pass of the SelfAttention module.

        Args:
            x: The input tensor.

        Returns:
            The output tensor.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head)  # (B, T, 3, nh, hs)
        # causal self-attention
        if self.attn_kernel_type == "torch_attn":
            y = self._torch_attn(c_x)
        else:
            raise ValueError(f"Unknown attention kernel type: {self.attn_kernel_type}")
        y = y.contiguous().view(B, T, C)  # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, nh * hs)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
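

# Minimal usage sketch. `AttnConfig` is a hypothetical stand-in for the real config
# object; it includes only the attributes this module reads.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class AttnConfig:  # hypothetical config, not part of the original source
        n_embd: int = 64
        n_head: int = 4
        bias: bool = True
        dropout: float = 0.0
        causal: bool = True
        attn_kernel_type: str = "torch_attn"

    attn = SelfAttention(AttnConfig()).eval()

    # Plain (non-cached) causal self-attention over a full sequence.
    x = torch.randn(2, 16, 64)  # (B, T, C)
    with torch.no_grad():
        y = attn(x)
    print(y.shape)  # torch.Size([2, 16, 64])

    # KV-cached decoding: prefill with the prompt, then feed one token at a time.
    attn.kv_cache_enabled = True
    attn.empty_kv_cache(batch_size=2, kv_cache_maxlen=32, dtype=torch.float32)
    with torch.no_grad():
        attn(x)  # prefill: caches k/v for all 16 prompt positions
        y_step = attn(torch.randn(2, 1, 64))  # decode a single new token
    print(y_step.shape)  # torch.Size([2, 1, 64])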