from functools import lru_cache

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.utils.context import get_context


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    @torch.compile
    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32).add_(residual.to(torch.float32))
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
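
# Illustrative only (not part of the original module): a pure-PyTorch sketch of
# what `store_kvcache_kernel` does, assuming the caches are contiguous and each
# cache slot holds D = num_kv_heads * head_dim contiguous elements (the same
# layout the assertions in `store_kvcache` enforce). Useful as a reference when
# testing the Triton kernel; the hot path should still use the kernel above.
def _store_kvcache_reference(key: torch.Tensor, value: torch.Tensor,
                             k_cache: torch.Tensor, v_cache: torch.Tensor,
                             slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # Scatter each token's flattened K/V row into the cache slot given by slot_mapping.
    k_cache.view(-1, D)[slot_mapping] = key.reshape(N, D)
    v_cache.view(-1, D)[slot_mapping] = value.reshape(N, D)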
class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o


class VocabParallelEmbedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            y = mask.unsqueeze(1) * y
            dist.all_reduce(y)
        return y


class ParallelLMHead(VocabParallelEmbedding):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        super().__init__(num_embeddings, embedding_dim)
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight, self.bias)
        if self.tp_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
            dist.gather(logits, all_logits, 0)
            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
        return logits


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


class LinearBase(nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        tp_dim: int | None = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tp_dim = tp_dim
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)
class ColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 0)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]
        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        param_data = param.data
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        input_size = self.hidden_size
        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]
        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        param_data = param.data
        assert loaded_shard_id in ["q", "k", "v"]
        if loaded_shard_id == "q":
            shard_size = self.num_heads * self.head_size
            shard_offset = 0
        elif loaded_shard_id == "k":
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size
        else:
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)
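
# Illustrative sketch (hypothetical, not the package's real checkpoint loader):
# how the shard-aware `weight_loader` hooks above are typically driven when
# mapping separate q/k/v checkpoint tensors onto the fused qkv parameter. The
# sizes in the comment assume a Qwen2-0.5B-like setup (hidden_size=896,
# head_size=64, 14 query heads, 2 KV heads), so the fused weight is (1152, 896):
# "q" fills rows [0, 896), "k" rows [896, 1024), "v" rows [1024, 1152).
# MergedColumnParallelLinear works the same way for gate_proj/up_proj, with
# integer shard ids 0 and 1 instead of "q"/"k"/"v".
def _load_fused_qkv_example(qkv: "QKVParallelLinear",
                            q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor):
    qkv.weight.weight_loader(qkv.weight, q_w, "q")
    qkv.weight.weight_loader(qkv.weight, k_w, "k")
    qkv.weight.weight_loader(qkv.weight, v_w, "v")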
class RowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)


class RotaryEmbedding(nn.Module):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    @torch.compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query = apply_rotary_emb(query, cos, sin).view(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key = apply_rotary_emb(key, cos, sin).view(key_shape)
        return query, key


@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    assert rope_scaling is None
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
    return rotary_emb


class Qwen2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = True,
        rope_theta: float = 1000000.0,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o)
        return output
class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


class Qwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "qkv_bias", True),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 1000000.0),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
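
# Illustrative sketch only (not part of this module): how the building blocks
# above are typically assembled into a decoder-only stack. The real model class
# in flashcosyvoice may differ; the config attributes used here (vocab_size,
# num_hidden_layers) are assumed to follow the usual Qwen2 naming. An actual
# forward pass also requires the inference context (flashcosyvoice.utils.context)
# and engine-allocated KV caches to be set up beforehand.
class _ExampleQwen2Model(nn.Module):

    def __init__(self, config: CosyVoice2LLMConfig) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            # Each layer returns (hidden_states, residual); the residual is
            # folded back in by the next layer's fused add + RMSNorm.
            hidden_states, residual = layer(positions, hidden_states, residual)
        # The final norm also consumes the pending residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states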