import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from transformers import AutoTokenizer

from vui.fluac import Fluac
from vui.utils import load_what_you_can

from .config import Config
from .patterns import DelayedPatternProvider
from .rope import apply_rotary_emb, precompute_freqs_cis


class KVCache(nn.Module):
    def __init__(
        self,
        batch_size: int,
        max_seqlen: int,
        n_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        cache_shape = (batch_size, n_kv_heads, max_seqlen, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: (T,), k_val: (B, nh, T, d)
        assert input_pos.size(0) == k_val.size(-2)
        k_out = self.k_cache
        v_out = self.v_cache
        input_pos = input_pos.int()
        k_out[:, :, input_pos] = k_val.to(k_out.dtype)
        v_out[:, :, input_pos] = v_val.to(k_out.dtype)
        return k_out, v_out


def repeat_kv(x: torch.Tensor, n_reps: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_reps) for
    x of shape (B, n_kv_heads, T, head_dim)."""
    bs, n_kv_heads, T, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_reps, T, head_dim)
        .reshape(bs, n_kv_heads * n_reps, T, head_dim)
    )
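

# Shape sanity check for repeat_kv (illustrative sizes, not from this codebase):
# a KV tensor of shape (2, 4, 16, 64) repeated with n_reps=2 yields
# (2, 8, 16, 64), i.e. 4 KV heads broadcast to 8 query heads.
#
#     k = torch.randn(2, 4, 16, 64)
#     assert repeat_kv(k, 2).shape == (2, 8, 16, 64)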


class MHA(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        *,
        block_idx: int,
        bias: bool = False,
        dropout: float = 0.0,
        causal: bool = False,
        use_rotary_emb: bool = True,
    ):
        super().__init__()
        head_dim = dim // n_heads
        self.use_rotary_emb = use_rotary_emb
        self.block_idx = block_idx
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.causal = causal
        self.n_reps = n_heads // n_kv_heads  # query heads per KV head (GQA)
        qkv_dim = (n_heads + 2 * n_kv_heads) * head_dim
        self.Wqkv = nn.Linear(dim, qkv_dim, bias=bias)
        self.out_proj = nn.Linear(dim, dim, bias=bias)
        self.kv_cache = None

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor | None = None,
        input_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
    ):
        B, T, d = x.size()
        dropout_p = self.dropout if self.training else 0.0
        qkv = self.Wqkv(x).to(x.dtype)

        if self.n_heads == self.n_kv_heads:
            qkv = rearrange(
                qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
            )
            q, k, v = qkv.unbind(dim=1)  # (B, h, T, d)
        else:
            q, k, v = torch.split(
                qkv,
                [
                    self.head_dim * self.n_heads,
                    self.head_dim * self.n_kv_heads,
                    self.head_dim * self.n_kv_heads,
                ],
                dim=-1,  # split the packed projection along the feature dim
            )
            q, k, v = map(
                lambda t: rearrange(t, "B T (h d) -> B h T d", d=self.head_dim),
                (q, k, v),
            )

        if self.use_rotary_emb:
            q = apply_rotary_emb(freqs_cis, q)
            k = apply_rotary_emb(freqs_cis, k)

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        if self.n_reps > 1:
            k = repeat_kv(k, self.n_reps)
            v = repeat_kv(v, self.n_reps)

        q, k, v = q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)
        is_causal = self.causal and self.kv_cache is None
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=dropout_p,
            is_causal=is_causal,
            attn_mask=attn_mask,
        )
        out = self.out_proj(rearrange(out, "B h T d -> B T (h d)"))
        return out


class MLP(nn.Module):
    def __init__(
        self, *, d_model: int, bias: bool, dropout: float, act=nn.GELU, **kwargs
    ):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model, bias=bias)
        self.act = act()
        self.fc2 = nn.Linear(4 * d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.fc2(self.act(self.fc1(x))))


class LlamaMLP(nn.Module):
    def __init__(
        self, *, d_model: int, multiple_of: int = 256, bias: bool = False, **kwargs
    ) -> None:
        super().__init__()
        # SwiGLU feed-forward: shrink 4*d_model by 2/3, then round the hidden
        # size up to the nearest multiple of `multiple_of`.
        hidden_dim = 4 * d_model
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class Block(nn.Module):
    def __init__(
        self,
        *,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        block_idx: int,
        bias: bool,
        dropout: float,
        norm_eps: float = 1e-5,  # use 1e-6 for rms
        use_rotary_emb: bool = True,
    ):
        super().__init__()
        self.block_idx = block_idx
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.attn_norm = RMSNorm(d_model, eps=norm_eps)
        self.attn = MHA(
            d_model,
            n_heads,
            n_kv_heads,
            block_idx=block_idx,
            bias=bias,
            dropout=dropout,
            causal=True,
            use_rotary_emb=use_rotary_emb,
        )
        self.mlp_norm = RMSNorm(d_model, eps=norm_eps)
        self.mlp = LlamaMLP(d_model=d_model, bias=bias, dropout=dropout)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor | None = None,
        input_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
    ):
        x = x + self.attn(
            self.attn_norm(x),
            freqs_cis=freqs_cis,
            input_pos=input_pos,
            attn_mask=attn_mask,
        )
        x = x + self.mlp(self.mlp_norm(x))
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        n_layers: int,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        bias: bool,
        dropout: float,
        max_seqlen: int = 4096,
        rope_theta: float = 10000.0,
        rope_theta_rescale_factor: float = 1.0,
        norm_eps: float = 1e-5,
        use_rotary_emb: bool = True,
        rope_dim: int | None = None,
    ):
        super().__init__()
        assert d_model % n_heads == 0
        self.use_rotary_emb = use_rotary_emb
        self.max_seqlen = max_seqlen
        self.blocks = nn.ModuleList(
            [
                Block(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    block_idx=block_idx,
                    bias=bias,
                    dropout=dropout,
                    norm_eps=norm_eps,
                    use_rotary_emb=use_rotary_emb,
                )
                for block_idx in range(n_layers)
            ]
        )
        self.norm = RMSNorm(d_model, eps=norm_eps)
        self.attn_mask = None

        head_dim = d_model // n_heads
        rope_dim = rope_dim or head_dim
        assert rope_dim <= head_dim  # apply RoPE to a fraction of embeddings
        freqs_cis = precompute_freqs_cis(
            rope_dim,
            max_seqlen,
            theta=rope_theta,
            theta_rescale_factor=rope_theta_rescale_factor,
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def allocate_inference_cache(
        self, batch_size: int, device: str, dtype=torch.bfloat16
    ):
        for block in self.blocks:
            block.attn.kv_cache = KVCache(
                batch_size, self.max_seqlen, block.n_kv_heads, block.head_dim, dtype
            ).to(device)
        # With a KV cache active, attention runs over the full cache and SDPA is
        # called with is_causal=False, so an explicit causal mask (sliced per
        # input_pos in forward) is needed to stop queries attending to cache
        # slots that have not been written yet.
        self.attn_mask = torch.tril(
            torch.ones(
                self.max_seqlen, self.max_seqlen, dtype=torch.bool, device=device
            )
        )

    def deallocate_kv_cache(self):
        for block in self.blocks:
            block.attn.kv_cache = None
        self.attn_mask = None

    def forward(self, x: Tensor, input_pos: Tensor):
        if self.use_rotary_emb:
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None
        attn_mask = (
            self.attn_mask[None, None, input_pos]
            if self.attn_mask is not None
            else None
        )
        for block in self.blocks:
            x = block(x, freqs_cis=freqs_cis, input_pos=input_pos, attn_mask=attn_mask)
        x = self.norm(x)
        return x
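

# A minimal sketch of how Decoder's KV cache is meant to be driven during
# incremental decoding (the batch size, device, and step loop below are
# illustrative assumptions, not part of this module):
#
#     decoder.allocate_inference_cache(batch_size=1, device="cuda")
#     for t in range(num_steps):
#         input_pos = torch.tensor([t], device="cuda")
#         h = decoder(x_t, input_pos)  # x_t: (1, 1, d_model) for the new position
#     decoder.deallocate_kv_cache()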


class Vui(nn.Module):
    BASE = "vui-100m-base.pt"
    COHOST = "vui-cohost-100m.pt"
    ABRAHAM = "vui-abraham-100m.pt"

    def __init__(self, config: Config = Config()):
        super().__init__()
        self.codec = Fluac.from_pretrained()
        self.config = config
        cfg = config.model
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
        self.use_rotary_emb = cfg.use_rotary_emb
        self.token_emb = nn.Embedding(self.tokenizer.vocab_size, cfg.d_model)
        self.pattern_provider = DelayedPatternProvider(n_q=cfg.n_quantizers)
        self.audio_embeddings = nn.ModuleList(
            [
                nn.Embedding(cfg.codebook_size + 8, cfg.d_model)
                for _ in range(cfg.n_quantizers)
            ]
        )
        n_kv_heads = cfg.n_heads
        max_seqlen = cfg.max_text_tokens + cfg.max_audio_tokens
        self.decoder = Decoder(
            n_layers=cfg.n_layers,
            d_model=cfg.d_model,
            n_heads=cfg.n_heads,
            n_kv_heads=n_kv_heads,
            bias=cfg.bias,
            dropout=cfg.dropout,
            max_seqlen=max_seqlen + cfg.n_quantizers,
            rope_dim=cfg.rope_dim,
            rope_theta=cfg.rope_theta,
            rope_theta_rescale_factor=cfg.rope_theta_rescale_factor,
        )
        self.audio_heads = nn.ModuleList(
            [
                nn.Linear(cfg.d_model, cfg.codebook_size + 8, bias=cfg.bias)
                for _ in range(cfg.n_quantizers)
            ]
        )
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers)
                )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @staticmethod
    def from_pretrained(
        checkpoint_path: str | dict = ABRAHAM,
        **config_kwargs,
    ):
        if isinstance(checkpoint_path, dict):
            checkpoint = checkpoint_path
        else:
            if not os.path.exists(checkpoint_path):
                from huggingface_hub import hf_hub_download

                checkpoint_path = hf_hub_download(
                    "fluxions/vui",
                    checkpoint_path,
                )
            checkpoint = torch.load(
                checkpoint_path, map_location="cpu", weights_only=True
            )
        config = {**checkpoint["config"], **config_kwargs}
        config = Config(**config)
        state_dict = checkpoint["model"]
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        state_dict = {
            k.replace("text_embedding.", "token_emb."): v for k, v in state_dict.items()
        }
        model = Vui(config)
        load_what_you_can(state_dict, model)
        return model

    @staticmethod
    def from_pretrained_inf(
        checkpoint_path: str | dict,
        **config_kwargs,
    ):
        return Vui.from_pretrained(checkpoint_path, **config_kwargs).eval()

    def device(self):
        return next(self.parameters()).device

    def dtype(self):
        return next(self.parameters()).dtype
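

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes the "fluxions/vui" checkpoint and the
    # Fluac codec weights can be downloaded in this environment).
    model = Vui.from_pretrained_inf(Vui.ABRAHAM)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Vui loaded: {n_params / 1e6:.1f}M parameters on {model.device()}")

    # Allocate and free per-layer KV caches for a single-stream decode.
    model.decoder.allocate_inference_cache(batch_size=1, device=str(model.device()))
    model.decoder.deallocate_kv_cache()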