drbh
committed on
Commit
·
d76b04d
1
Parent(s):
5cb0596
fix: remove unused trailing param
Browse files
flash_mla/flash_mla_api.cu
CHANGED
|
@@ -70,10 +70,10 @@ mha_fwd_kvcache_mla(
|
|
| 70 |
const double softmax_scale,
|
| 71 |
const bool is_causal_,
|
| 72 |
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
| 73 |
-
const at::Tensor &num_splits
|
| 74 |
|
| 75 |
// TODO: remove this once determined why build is adding this parameter
|
| 76 |
-
const int64_t unknown_param
|
| 77 |
) {
|
| 78 |
auto dprops = at::cuda::getCurrentDeviceProperties();
|
| 79 |
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
|
|
|
| 70 |
const double softmax_scale,
|
| 71 |
const bool is_causal_,
|
| 72 |
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
| 73 |
+
const at::Tensor &num_splits // batch_size + 1
|
| 74 |
|
| 75 |
// TODO: remove this once determined why build is adding this parameter
|
| 76 |
+
// const int64_t unknown_param
|
| 77 |
) {
|
| 78 |
auto dprops = at::cuda::getCurrentDeviceProperties();
|
| 79 |
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
torch-ext/flash_mla/__init__.py
CHANGED
|
@@ -19,8 +19,6 @@ def mha_fwd_kvcache_mla(
|
|
| 19 |
tile_scheduler_metadata: torch.Tensor,
|
| 20 |
num_splits: torch.Tensor,
|
| 21 |
) -> torch.Tensor:
|
| 22 |
-
# TODO: remove when resolved
|
| 23 |
-
unknown_param = 0
|
| 24 |
return ops.mha_fwd_kvcache_mla(
|
| 25 |
q,
|
| 26 |
kcache,
|
|
@@ -31,6 +29,5 @@ def mha_fwd_kvcache_mla(
|
|
| 31 |
softmax_scale,
|
| 32 |
is_causal_,
|
| 33 |
tile_scheduler_metadata,
|
| 34 |
-
num_splits
|
| 35 |
-
unknown_param,
|
| 36 |
)
|
|
|
|
| 19 |
tile_scheduler_metadata: torch.Tensor,
|
| 20 |
num_splits: torch.Tensor,
|
| 21 |
) -> torch.Tensor:
|
|
|
|
|
|
|
| 22 |
return ops.mha_fwd_kvcache_mla(
|
| 23 |
q,
|
| 24 |
kcache,
|
|
|
|
| 29 |
softmax_scale,
|
| 30 |
is_causal_,
|
| 31 |
tile_scheduler_metadata,
|
| 32 |
+
num_splits
|
|
|
|
| 33 |
)
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 8 |
ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
|
| 9 |
|
| 10 |
// TODO: remove last unknown_param when resolved
|
| 11 |
-
ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits
|
| 12 |
ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
|
| 13 |
}
|
| 14 |
|
|
|
|
| 8 |
ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
|
| 9 |
|
| 10 |
// TODO: remove last unknown_param when resolved
|
| 11 |
+
ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
|
| 12 |
ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
|
| 13 |
}
|
| 14 |
|
torch-ext/torch_binding.h
CHANGED
|
@@ -29,8 +29,5 @@ mha_fwd_kvcache_mla(
|
|
| 29 |
const bool is_causal_,
|
| 30 |
|
| 31 |
const torch::Tensor &tile_scheduler_metadata,
|
| 32 |
-
const torch::Tensor &num_splits
|
| 33 |
-
|
| 34 |
-
// TODO: remove when resolved
|
| 35 |
-
const int64_t unknown_param = 0
|
| 36 |
);
|
|
|
|
| 29 |
const bool is_causal_,
|
| 30 |
|
| 31 |
const torch::Tensor &tile_scheduler_metadata,
|
| 32 |
+
const torch::Tensor &num_splits
|
|
|
|
|
|
|
|
|
|
| 33 |
);
|