from functools import lru_cache

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.utils.context import get_context


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y
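
# Usage sketch for SiluAndMul (hypothetical sizes, not taken from this file): the input is
# expected to be the concatenated [gate, up] projection, so the activation halves the last
# dimension.
# >>> act = SiluAndMul()
# >>> gate_up = torch.randn(4, 2 * 4864)   # [num_tokens, 2 * intermediate_size]
# >>> act(gate_up).shape                   # silu(gate) * up
# torch.Size([4, 4864])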


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32).add_(residual.to(torch.float32))
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)
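
# Note on RMSNorm: both paths compute the statistics in float32 and cast back afterwards.
# The fused add_rms_forward path returns (normalized, updated_residual) so a decoder layer
# can carry one running residual stream instead of re-adding it after every sub-layer.
# Minimal shape sketch (hypothetical sizes):
# >>> norm = RMSNorm(hidden_size=896)
# >>> x, res = torch.randn(4, 896), torch.randn(4, 896)
# >>> y, new_res = norm(x, res)   # y = rmsnorm(x + res), new_res = x + res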


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
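
# Layout assumed by store_kvcache (see the asserts above): each token's key/value is a
# contiguous block of D = num_kv_heads * head_dim elements, and slot_mapping[i] gives the
# flat cache slot of token i, i.e. that token occupies cache elements [slot * D, (slot + 1) * D).
# One Triton program copies one token, so writes from different tokens never overlap.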


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o
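
# Control flow of Attention.forward, summarized:
# - prefill: q/k/v are packed over variable-length sequences and handed to
#   flash_attn_varlen_func with cu_seqlens; if block_tables is set (prefix caching),
#   k/v are read from the paged cache instead of the freshly projected tensors.
# - decode: exactly one new query token per sequence, so q is unsqueezed to [batch, 1, ...]
#   and flash_attn_with_kvcache attends against the paged cache directly.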


class VocabParallelEmbedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            y = mask.unsqueeze(1) * y
            dist.all_reduce(y)
        return y
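
# With tensor parallelism (currently disabled, tp_size is hard-coded to 1 above), each rank
# would hold a contiguous vocab shard: out-of-shard token ids are masked to 0, their
# embeddings zeroed out, and the all_reduce sums the partial results across ranks so every
# rank ends up with the full embedding of every token.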


class ParallelLMHead(VocabParallelEmbedding):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        super().__init__(num_embeddings, embedding_dim)
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight, self.bias)
        if self.tp_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
            dist.gather(logits, all_logits, 0)
            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
        return logits


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


class LinearBase(nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        tp_dim: int | None = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tp_dim = tp_dim
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class ColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 0)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]
        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        param_data = param.data
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        input_size = self.hidden_size
        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,     # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]
        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        param_data = param.data
        assert loaded_shard_id in ["q", "k", "v"]
        if loaded_shard_id == "q":
            shard_size = self.num_heads * self.head_size
            shard_offset = 0
        elif loaded_shard_id == "k":
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size
        else:
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)
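
# Packed QKV weight layout, with hypothetical Qwen2.5-0.5B-like sizes (hidden_size=896,
# 14 query heads, 2 KV heads, head_size=64): the merged weight has
# (14 + 2 * 2) * 64 = 1152 output rows, and weight_loader copies the checkpoint's
# q_proj into rows [0, 896), k_proj into [896, 1024) and v_proj into [1024, 1152).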


class RowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)
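
# apply_rotary_emb uses the non-interleaved ("NeoX-style") convention: the head dimension is
# split into two contiguous halves that are rotated against each other in float32. This
# matches the cos/sin cache built below as torch.cat((cos, sin), dim=-1).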


class RotaryEmbedding(nn.Module):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query = apply_rotary_emb(query, cos, sin).view(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key = apply_rotary_emb(key, cos, sin).view(key_shape)
        return query, key


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    assert rope_scaling is None
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
    return rotary_emb


class Qwen2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = True,
        rope_theta: float = 1000000.0,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o)
        return output
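
# Shape sketch for Qwen2Attention.forward (hypothetical config: 14 heads, 2 KV heads,
# head_dim=64, hidden_size=896): qkv_proj maps [num_tokens, 896] -> [num_tokens, 1152],
# which is split into q (14 * 64 = 896), k (2 * 64 = 128) and v (128) before RoPE,
# paged attention and o_proj.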


class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x
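
# Qwen2MLP is the standard SwiGLU block: gate_up_proj produces [num_tokens,
# 2 * intermediate_size], SiluAndMul halves that back to intermediate_size, and down_proj
# returns to hidden_size. The assert restricts hidden_act to "silu", the only case
# SiluAndMul implements.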


class Qwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "qkv_bias", True),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 1000000.0),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
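
# Qwen2DecoderLayer follows the pre-norm residual pattern: when residual is None (the first
# layer), the raw input becomes the residual and is normalized by input_layernorm; afterwards
# RMSNorm's fused add+norm path both folds the previous sub-layer's output into the residual
# and normalizes it, so the un-normalized residual stream is threaded through every layer
# alongside hidden_states.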