import torch

from ._ops import ops
def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int):
    """Call the ``get_mla_metadata`` op of the compiled extension.

    Thin Python entry point: the three arguments are forwarded positionally
    and unchanged to ``ops.get_mla_metadata``, and the op's result is
    returned as-is.

    Args:
        seqlens_k: Tensor passed as the op's first argument (presumably the
            per-sequence key/value lengths -- TODO confirm against the kernel).
        s_q: Integer passed as the op's second argument (presumably a
            query-length figure -- TODO confirm against the kernel).
        h_kv: Integer passed as the op's third argument (presumably the
            number of KV heads -- TODO confirm against the kernel).

    Returns:
        Whatever ``ops.get_mla_metadata`` returns. The function carries no
        return annotation because the result's structure is defined by the
        op, not by this wrapper.
    """
    return ops.get_mla_metadata(seqlens_k, s_q, h_kv)
def mha_fwd_kvcache_mla(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache_: torch.Tensor,
    head_size_v: int,
    seqlens_k: torch.Tensor,
    block_table: torch.Tensor,
    softmax_scale: float,
    is_causal_: bool,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
) -> torch.Tensor:
    """Call the ``mha_fwd_kvcache_mla`` op of the compiled extension.

    Thin Python entry point: every argument is forwarded positionally, in
    declaration order and without modification, to
    ``ops.mha_fwd_kvcache_mla``. No validation, conversion or defaulting
    happens on the Python side.

    Args:
        q: Tensor passed as the op's 1st argument (presumably the queries).
        kcache: Tensor passed as the 2nd argument (presumably the key cache).
        vcache_: Tensor passed as the 3rd argument (presumably the value
            cache -- TODO confirm whether the op accepts an empty/None-like
            value here).
        head_size_v: Integer passed as the 4th argument (presumably the
            value head dimension).
        seqlens_k: Tensor passed as the 5th argument (presumably the
            per-sequence key/value lengths).
        block_table: Tensor passed as the 6th argument (presumably the
            paged-cache block index table).
        softmax_scale: Float passed as the 7th argument.
        is_causal_: Bool passed as the 8th argument.
        tile_scheduler_metadata: Tensor passed as the 9th argument
            (presumably produced by ``get_mla_metadata`` -- verify against
            callers).
        num_splits: Tensor passed as the 10th argument (presumably also
            produced by ``get_mla_metadata`` -- verify against callers).

    Returns:
        The result of ``ops.mha_fwd_kvcache_mla``, unmodified.
        NOTE(review): the ``torch.Tensor`` annotation is not enforced here;
        confirm against the op schema that a single tensor (rather than a
        tuple of tensors) is returned.
    """
    return ops.mha_fwd_kvcache_mla(
        q,
        kcache,
        vcache_,
        head_size_v,
        seqlens_k,
        block_table,
        softmax_scale,
        is_causal_,
        tile_scheduler_metadata,
        num_splits,
    )