import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


def norm(x: torch.Tensor) -> torch.Tensor:
    # RMSNorm with no learnable parameters
    return F.rms_norm(x, (x.size(-1),))


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # rotary position embedding: rotate the two halves of the head dimension
    assert x.ndim == 4  # (B, T, n_head, head_dim)
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    out = out.to(x.dtype)
    return out


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # expand KV heads n_rep times so they line up with the query heads (grouped-query attention)
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )
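
# Shape sanity check for the helpers above (illustrative sizes, not part of the
# model; assumes 6 query heads sharing 3 KV heads with head_dim 64):
#
#   kv = torch.randn(2, 3, 8, 64)             # (B, n_kv_head, T, head_dim)
#   repeat_kv(kv, 2).shape                    # -> torch.Size([2, 6, 8, 64])
#
#   q = torch.randn(2, 8, 6, 64)              # (B, T, n_head, head_dim)
#   cos = torch.ones(1, 8, 1, 32)             # broadcast over batch and heads
#   sin = torch.zeros(1, 8, 1, 32)
#   apply_rotary_emb(q, cos, sin).shape       # -> torch.Size([2, 8, 6, 64])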


@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768


class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None):
        B, T, C = x.size()
        # project to queries, keys and values; KV may use fewer heads (GQA)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        # rotary position embedding followed by QK norm
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # append the new keys/values to the cache and attend over the full history
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
        if kv_cache is None or Tq == Tk:
            # training, or prefill from an empty cache: standard causal attention
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            # single-token decode: the new query may attend to everything cached
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # chunked prefill on top of an existing cache: full attention to the
            # cached prefix, causal attention within the new chunk
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
            prefix_len = Tk - Tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
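

class KVCache:
    """A minimal sketch of the cache interface the attention layer above expects
    (the class name and implementation here are assumptions; the real cache used
    with this model may differ, e.g. by preallocating storage). It keeps one
    (k, v) pair per layer with shape (B, n_kv_head, T, head_dim) and grows along
    the time dimension."""

    def __init__(self, n_layer: int):
        self.kv = [None] * n_layer  # per-layer cached (k, v) tensors
        self.pos = 0                # number of tokens already cached

    def get_pos(self) -> int:
        # position offset of the incoming chunk (queried before insertion)
        return self.pos

    def insert_kv(self, layer_idx: int, k: torch.Tensor, v: torch.Tensor):
        # append this layer's new keys/values and return the full history
        if self.kv[layer_idx] is None:
            ck, cv = k, v
        else:
            ck = torch.cat([self.kv[layer_idx][0], k], dim=2)
            cv = torch.cat([self.kv[layer_idx][1], v], dim=2)
        self.kv[layer_idx] = (ck, cv)
        # advance the shared position once the last layer has seen this chunk
        if layer_idx == len(self.kv) - 1:
            self.pos += k.size(2)
        return ck, cv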


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU-squared activation
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config: GPTConfig, layer_idx: int):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        # pre-norm residual connections
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'h': nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # over-provision the rotary table well past the training sequence length
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer('cos', cos, persistent=False)
        self.register_buffer('sin', sin, persistent=False)
        # keep the token embedding table in bfloat16
        self.transformer.wte.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        # shape (1, seq_len, 1, head_dim // 2) so they broadcast over batch and heads
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache=None, loss_reduction: str = 'mean'):
        B, T = idx.size()
        assert T <= self.cos.size(1)
        assert idx.device == self.cos.device
        assert self.cos.dtype == torch.bfloat16
        # slice the rotary table at the current cache position
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0 + T], self.sin[:, T0:T0 + T]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
        softcap = 15  # tanh soft-capping keeps logits in (-15, 15)
        if targets is not None:
            # training: return the cross-entropy loss
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)
            logits = logits.float()
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference: return the logits
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)
            return logits

    @torch.inference_mode()
    def generate(self, tokens: list[int], max_tokens: int, temperature: float = 1.0, top_k: int | None = None, seed: int = 42):
        # simple sampling loop; note it re-runs the full forward pass each step
        # rather than using a KV cache
        assert isinstance(tokens, list)
        device = self.transformer.wte.weight.device
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        for _ in range(max_tokens):
            logits = self.forward(ids)
            logits = logits[:, -1, :]  # logits at the last position only
            if top_k is not None:
                # mask everything below the k-th largest logit
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / max(temperature, 1e-6)
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                # temperature 0: greedy decoding
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            yield next_ids.item()
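

if __name__ == "__main__":
    # Minimal CPU smoke test: a sketch with assumed tiny hyperparameters and
    # random weights; real usage would load trained weights and a tokenizer.
    # bfloat16 autocast is needed because the embedding table and rotary
    # buffers are bf16 while the Linear weights stay in fp32.
    config = GPTConfig(sequence_len=64, vocab_size=256, n_layer=2, n_head=4, n_kv_head=2, n_embd=64)
    model = GPT(config)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        idx = torch.randint(0, config.vocab_size, (2, 16))
        targets = torch.randint(0, config.vocab_size, (2, 16))
        loss = model(idx, targets=targets)
        print(f"loss on random tokens: {loss.item():.4f}")
        # greedy decoding (temperature=0) from a short prompt
        sampled = list(model.generate([1, 2, 3], max_tokens=8, temperature=0.0))
        print(f"sampled token ids: {sampled}")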