# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation and/or other
# materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
from dataclasses import dataclass
from functools import reduce
from math import gcd
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F

from fam.llm.utils import get_default_dtype

# Adjust the logging level: only surface errors from torch's logger.
logger = logging.getLogger("torch")
logger.setLevel(logging.ERROR)


def find_multiple(n: int, *args: int) -> int:
    # Round n up to the nearest multiple of the least common multiple of args.
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)
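
# Illustration (derived from the definition above): round up to a multiple of the
# least common multiple of the extra arguments.
#   find_multiple(30, 8)     -> 32
#   find_multiple(5461, 256) -> 5632   # arises below when deriving intermediate_size for dim=2048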


@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    speaker_emb_dim: int = 256
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    norm_eps: float = 1e-5
    dtype: torch.dtype = torch.bfloat16

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head
        self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()]

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
        assert len(config) == 1, name
        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "metavoice-1B": dict(
        n_layer=24,
        n_head=16,
        dim=2048,
        vocab_size=2562,
    ),
}
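
# Illustration: ModelArgs.from_name("metavoice-1B") resolves to dim=2048,
# n_head=16, n_local_heads=16, head_dim=128 and intermediate_size=5632
# (4 * 2048 scaled by 2/3 and rounded up to a multiple of 256 in __post_init__).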


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val
        return k_out, v_out
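
# Illustration of the cache contract: with a [B, H, max_seq_length, D] buffer and
# input_pos = torch.arange(S), update() writes the new keys/values into rows
# 0..S-1 along the sequence dimension and returns the full buffers, so attention
# can read previously cached positions without recomputing them.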


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
        self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_spk_cond_mask(self):
        self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
        self.spk_cond_mask[0] = 1

    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype
            )

        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

    def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
        mask = self.causal_mask[None, None, input_pos]
        x = (
            self.tok_embeddings(idx)
            + self.pos_embeddings(input_pos)
            # zero out the speaker embedding on the unconditional rows (speaker condition-free guidance)
            + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
        )

        for layer in self.layers:
            x = layer(x, input_pos, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))
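
# Note on the speaker mask: setup_spk_cond_mask() assumes a batch of two rows laid
# out as [speaker-conditional, speaker-free]; row 0 keeps the projected speaker
# embedding and row 1 zeroes it, which is the usual setup for classifier-free
# guidance on the speaker condition.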


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
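
# Note: repeat_interleave materialises n_head // n_local_heads copies of each
# cached K/V head before scaled_dot_product_attention, so grouped-query attention
# shrinks the KV cache rather than the attention compute here. For the
# metavoice-1B config the ratio is 1 and the call is a no-op.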


class SwiGLU(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: silu(W1 x) gated elementwise by W3 x
        return F.silu(self.w1(x)) * self.w3(x)


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.swiglu = SwiGLU(config)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(self.swiglu(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
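

# A minimal smoke-test sketch, assuming the `fam` package is importable so that
# get_default_dtype() resolves. The dummy tensors only illustrate the expected
# shapes: a batch of 2 (speaker-conditional row + speaker-free row for guidance),
# token ids [2, S], a speaker embedding [2, 1, speaker_emb_dim] and positions [S].
if __name__ == "__main__":
    config = ModelArgs.from_name("metavoice-1B")
    model = Transformer(config).to(dtype=config.dtype)
    model.setup_spk_cond_mask()
    model.setup_caches(max_batch_size=2, max_seq_length=128)

    seq_len = 16
    idx = torch.randint(0, config.vocab_size, (2, seq_len))
    spk_emb = torch.randn(2, 1, config.speaker_emb_dim, dtype=config.dtype)
    input_pos = torch.arange(seq_len)

    with torch.no_grad():
        logits = model(idx, spk_emb, input_pos)
    print(logits.shape)  # torch.Size([2, 16, 2562]) for the metavoice-1B config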