| std::vector<torch::Tensor> | |
| get_mla_metadata( | |
| torch::Tensor &seqlens_k, | |
| const int64_t num_heads_per_head_k, | |
| const int64_t num_heads_k | |
| ); | |
| std::vector<torch::Tensor> | |
| mha_fwd_kvcache_mla( | |
| torch::Tensor &q, | |
| const torch::Tensor &kcache, | |
| const c10::optional<torch::Tensor> &vcache_, | |
| const int64_t head_size_v, | |
| const torch::Tensor &seqlens_k, | |
| const torch::Tensor &block_table, | |
| const double softmax_scale, | |
| bool is_causal, | |
| const torch::Tensor &tile_scheduler_metadata, | |
| const torch::Tensor &num_splits | |
| ); |