Kernels
danieldk (HF Staff) committed
Commit e61c028 · 1 parent: ff82004

Remove source

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +5 -2
  2. build.toml +0 -92
  3. cutlass_extensions/include/cutlass_extensions/arch/mma.h +0 -46
  4. cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +0 -51
  5. cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h +0 -48
  6. cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h +0 -148
  7. cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +0 -390
  8. cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +0 -285
  9. cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +0 -82
  10. cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h +0 -58
  11. cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +0 -123
  12. cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +0 -492
  13. cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h +0 -447
  14. cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +0 -89
  15. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +0 -106
  16. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +0 -346
  17. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +0 -315
  18. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +0 -426
  19. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +0 -527
  20. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +0 -236
  21. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +0 -599
  22. cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +0 -385
  23. cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +0 -127
  24. cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +0 -313
  25. cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +0 -469
  26. cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +0 -429
  27. cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +0 -61
  28. cutlass_kernels/cutlass_heuristic.cu +0 -208
  29. cutlass_kernels/cutlass_heuristic.h +0 -39
  30. cutlass_kernels/cutlass_preprocessors.cc +0 -703
  31. cutlass_kernels/cutlass_preprocessors.h +0 -33
  32. cutlass_kernels/fpA_intB_gemm.cu +0 -99
  33. cutlass_kernels/fpA_intB_gemm.h +0 -36
  34. cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +0 -118
  35. cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +0 -858
  36. cutlass_kernels/fpA_intB_gemm_wrapper.cu +0 -201
  37. cutlass_kernels/fpA_intB_gemm_wrapper.h +0 -23
  38. flake.lock +0 -169
  39. flake.nix +0 -17
  40. torch-ext/quantization_eetq/__init__.py +0 -3
  41. torch-ext/quantization_eetq/custom_ops.py +0 -36
  42. torch-ext/torch_binding.cpp +0 -19
  43. torch-ext/torch_binding.h +0 -25
  44. utils/activation_types.h +0 -40
  45. utils/cuda_utils.cc +0 -55
  46. utils/cuda_utils.h +0 -76
  47. utils/logger.cc +0 -59
  48. utils/logger.h +0 -121
  49. utils/string_utils.h +0 -54
  50. utils/torch_utils.h +0 -68
README.md CHANGED
@@ -1,11 +1,14 @@
   ---
   license: apache-2.0
   tags:
- - kernel
+ - kernel
   ---

   ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/quantization-eetq)

   ## eetq

- EETQ kernels from [NetEase-FuXi/EETQ](https://github.com/NetEase-FuXi/EETQ).
+ EETQ kernels from [NetEase-FuXi/EETQ](https://github.com/NetEase-FuXi/EETQ).
+
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/quantization-eetq
+
build.toml DELETED
@@ -1,92 +0,0 @@
1
- [general]
2
- name = "quantization_eetq"
3
- universal = false
4
-
5
- [torch]
6
- src = [
7
- "torch-ext/torch_binding.cpp",
8
- "torch-ext/torch_binding.h",
9
- ]
10
-
11
- [kernel.weight_only_batched_gemv]
12
- backend = "cuda"
13
- depends = [
14
- "cutlass_2_10",
15
- "torch",
16
- ]
17
- include = ["cutlass_extensions/include"]
18
- src = [
19
- "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
20
- "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
21
- "weightOnlyBatchedGemv/common.h",
22
- "weightOnlyBatchedGemv/enabled.h",
23
- "weightOnlyBatchedGemv/kernel.h",
24
- "weightOnlyBatchedGemv/kernelLauncher.cu",
25
- "weightOnlyBatchedGemv/kernelLauncher.h",
26
- "weightOnlyBatchedGemv/utility.h",
27
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
28
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
29
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
30
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
31
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
32
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
33
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
34
- "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
35
- ]
36
-
37
- [kernel.cutlass_kernels]
38
- backend = "cuda"
39
- depends = [
40
- "cutlass_2_10",
41
- "torch",
42
- ]
43
- include = [
44
- ".",
45
- "utils",
46
- "cutlass_extensions/include",
47
- ]
48
- src = [
49
- "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
50
- "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
51
- "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
52
- "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
53
- "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
54
- "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
55
- "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
56
- "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
57
- "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
58
- "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
59
- "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
60
- "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
61
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
62
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
63
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
64
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
65
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
66
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
67
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
68
- "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
69
- "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
70
- "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
71
- "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
72
- "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
73
- "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
74
- "cutlass_kernels/cutlass_heuristic.cu",
75
- "cutlass_kernels/cutlass_heuristic.h",
76
- "cutlass_kernels/cutlass_preprocessors.cc",
77
- "cutlass_kernels/cutlass_preprocessors.h",
78
- "cutlass_kernels/fpA_intB_gemm.cu",
79
- "cutlass_kernels/fpA_intB_gemm.h",
80
- "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
81
- "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
82
- "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
83
- "cutlass_kernels/fpA_intB_gemm_wrapper.h",
84
- "weightOnlyBatchedGemv/common.h",
85
- "weightOnlyBatchedGemv/enabled.h",
86
- "utils/activation_types.h",
87
- "utils/cuda_utils.h",
88
- "utils/logger.cc",
89
- "utils/logger.h",
90
- "utils/string_utils.h",
91
- "utils/torch_utils.h",
92
- ]
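The removed build.toml wired the CUDA sources into two kernels (`weight_only_batched_gemv` and `cutlass_kernels`, both depending on CUTLASS 2.10) and pulled in the Torch bindings from `torch-ext/torch_binding.cpp`. As a rough orientation only, the sketch below shows the general shape of such a Torch custom-op registration; the namespace, op name, schema, tensor layouts, and the pure-ATen reference implementation are illustrative assumptions, not the removed code.

```cpp
// Hypothetical sketch of a torch_binding.cpp-style registration (assumptions:
// weight is int8 with shape [in_features, out_features], scale holds one value
// per output channel). The real, removed binding defines its own schema.
#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor w8_a16_gemm_reference(const at::Tensor& input,
                                 const at::Tensor& weight,
                                 const at::Tensor& scale) {
  // Plain ATen fallback: dequantize the int8 weight with per-output-channel
  // scales, then do an ordinary matmul (no custom kernel involved).
  auto dequantized = weight.to(at::kHalf) * scale.unsqueeze(0);
  return at::matmul(input, dequantized);
}

TORCH_LIBRARY(quantization_eetq_example, m) {
  m.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor");
}

TORCH_LIBRARY_IMPL(quantization_eetq_example, CompositeExplicitAutograd, m) {
  m.impl("w8_a16_gemm", &w8_a16_gemm_reference);
}
```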
cutlass_extensions/include/cutlass_extensions/arch/mma.h DELETED
@@ -1,46 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Templates exposing architecture support for multiply-add operations
33
- */
34
-
35
- #pragma once
36
-
37
- /////////////////////////////////////////////////////////////////////////////////////////////////
38
-
39
- namespace cutlass {
40
- namespace arch {
41
-
42
- // Tag which triggers MMA which will trigger
43
- struct OpMultiplyAddDequantizeInterleavedBToA;
44
-
45
- } // namespace arch
46
- } // namespace cutlass
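`arch/mma.h` contributed a single empty tag type, `OpMultiplyAddDequantizeInterleavedBToA`, which downstream templates specialize on to select the dequantize-then-MMA path. The standalone sketch below (own names apart from the tag; not CUTLASS code) illustrates why an empty struct is enough: the tag is resolved entirely at compile time.

```cpp
// Minimal illustration of tag dispatch: templates specialize on an empty tag
// struct to pick a different multiply-add policy at zero runtime cost.
#include <cstdio>

struct OpMultiplyAdd {};                           // plain FMA path
struct OpMultiplyAddDequantizeInterleavedBToA {};  // dequantize-B-then-FMA path

template <typename OperatorTag>
struct MmaPolicy;  // primary template intentionally undefined

template <>
struct MmaPolicy<OpMultiplyAdd> {
  static constexpr const char* name() { return "direct fp16 x fp16 MMA"; }
};

template <>
struct MmaPolicy<OpMultiplyAddDequantizeInterleavedBToA> {
  static constexpr const char* name() { return "dequantize interleaved int B, then MMA"; }
};

int main() {
  std::printf("%s\n", MmaPolicy<OpMultiplyAddDequantizeInterleavedBToA>::name());
}
```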
cutlass_extensions/include/cutlass_extensions/compute_occupancy.h DELETED
@@ -1,51 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
- #pragma once
17
-
18
- #include <cuda_runtime_api.h>
19
-
20
- #include "cutlass/device_kernel.h"
21
- #include "utils/cuda_utils.h"
22
-
23
- namespace fastertransformer {
24
-
25
- template<typename GemmKernel>
26
- inline int compute_occupancy_for_kernel()
27
- {
28
-
29
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
30
-
31
- if (smem_size > (48 << 10)) {
32
- cudaError_t status =
33
- cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
34
- if (status == cudaError::cudaErrorInvalidValue) {
35
- // Clear the error bit since we can ignore this.
36
- // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
37
- // occupancy of 0. This will cause the heuristic to ignore this configuration.
38
- status = cudaGetLastError();
39
- return 0;
40
- }
41
- check_cuda_error(status);
42
- }
43
-
44
- int max_active_blocks = -1;
45
- check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
46
- &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
47
-
48
- return max_active_blocks;
49
- }
50
-
51
- } // namespace fastertransformer
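`compute_occupancy.h` wrapped the standard CUDA occupancy query: opt the kernel into more than 48 KB of dynamic shared memory when needed, treat `cudaErrorInvalidValue` as "occupancy 0" so the heuristic skips that configuration, and otherwise ask `cudaOccupancyMaxActiveBlocksPerMultiprocessor`. A minimal standalone version of the same pattern, with a placeholder kernel and shared-memory size, is sketched below.

```cuda
// Standalone sketch of the occupancy query pattern used by
// compute_occupancy_for_kernel(). The kernel and smem size are placeholders.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
  constexpr int kThreads = 256;
  int smem_size = 64 * 1024;  // 64 KB: above the default 48 KB limit

  if (smem_size > (48 << 10)) {
    cudaError_t status = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (status == cudaErrorInvalidValue) {
      // Device cannot provide this much shared memory; report occupancy 0,
      // mirroring how the removed helper lets the heuristic skip the config.
      (void)cudaGetLastError();
      std::printf("occupancy: 0\n");
      return 0;
    }
  }

  int max_active_blocks = -1;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, dummy_kernel, kThreads, smem_size);
  std::printf("occupancy: %d blocks/SM\n", max_active_blocks);
}
```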
cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h DELETED
@@ -1,48 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- #pragma once
33
-
34
- /////////////////////////////////////////////////////////////////////////////////////////////////
35
-
36
- namespace cutlass {
37
- namespace epilogue {
38
-
39
- // define scaling mode
40
- enum class QuantMode {
41
- PerTensorQuant,
42
- PerTokenQuant,
43
- PerChannelQuant,
44
- PerTokenChannelQuant
45
- };
46
-
47
- } // namespace epilogue
48
- } // namespace cutlass
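`epilogue_quant_helper.h` only defined the `QuantMode` enum. In the epilogue that consumes it (see `epilogue_per_row_per_col_scale.h` below), the mode decides whether the row (token) scale, the column (channel) scale, or both vary per element. A host-side reference of that mapping, with assumed scale-array shapes, is sketched below.

```cpp
// Reference mapping (not CUTLASS code) from QuantMode to the scale applied to
// accumulator element (row, col). Assumed shapes: alpha_row has one entry per
// row/token, alpha_col one entry per column/channel; per-tensor modes use
// only the first entry of each array.
#include <vector>

enum class QuantMode { PerTensorQuant, PerTokenQuant, PerChannelQuant, PerTokenChannelQuant };

float dequant_scale(QuantMode mode, int row, int col,
                    const std::vector<float>& alpha_row,
                    const std::vector<float>& alpha_col) {
  switch (mode) {
    case QuantMode::PerTensorQuant:       return alpha_row[0] * alpha_col[0];
    case QuantMode::PerTokenQuant:        return alpha_row[row] * alpha_col[0];
    case QuantMode::PerChannelQuant:      return alpha_row[0] * alpha_col[col];
    case QuantMode::PerTokenChannelQuant: return alpha_row[row] * alpha_col[col];
  }
  return 1.0f;
}
```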
cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h DELETED
@@ -1,148 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Functor performing linear combination with a maximum operation used by epilogues.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/array.h"
38
- #include "cutlass/cutlass.h"
39
- #include "cutlass/epilogue/thread/activation.h"
40
- #include "cutlass/epilogue/thread/scale_type.h"
41
- #include "cutlass/functional.h"
42
- #include "cutlass/half.h"
43
- #include "cutlass/numeric_conversion.h"
44
- #include "cutlass/numeric_types.h"
45
-
46
- /////////////////////////////////////////////////////////////////////////////////////////////////
47
-
48
- namespace cutlass {
49
- namespace epilogue {
50
- namespace thread {
51
-
52
- /////////////////////////////////////////////////////////////////////////////////////////////////
53
-
54
- __forceinline__ __device__ float copysignf_pos(float a, float b)
55
- {
56
- float r;
57
- r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
58
- return r;
59
- }
60
-
61
- __forceinline__ __device__ float tanh_opt(float x)
62
- {
63
- #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
64
- const float exp_val = -1.f * fabs(2 * x);
65
- return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
66
- #else
67
- return fast_tanh(x);
68
- #endif
69
- }
70
-
71
- /////////////////////////////////////////////////////////////////////////////////////////////////
72
-
73
- // DdK: GELU_taylor ir incomplete in 2.10. Vendored fixes here.
74
-
75
- // GELU operator implemented using the Taylor series approximation
76
- template <typename T>
77
- struct GELU_taylor_fixed {
78
- static const bool kIsHeavy=true;
79
- CUTLASS_HOST_DEVICE
80
- T operator()(T const &z) const {
81
-
82
- T k0 = T(0.7978845608028654);
83
- T k1 = T(0.044715);
84
-
85
- return T(cutlass::constants::half<T>() * z *
86
- (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
87
- }
88
-
89
- using Params = LinearCombinationGenericParams<T>;
90
-
91
- CUTLASS_HOST_DEVICE
92
- T operator()(T const &scalar, Params const &params_) const {
93
- return this->operator()(scalar);
94
- }
95
- };
96
-
97
- template<>
98
- struct GELU_taylor_fixed<float> {
99
- static const bool kIsHeavy = true;
100
- CUTLASS_DEVICE
101
- float operator()(float const& z) const
102
- {
103
-
104
- float k0 = float(0.7978845608028654);
105
- float k1 = float(0.044715);
106
-
107
- return float(
108
- cutlass::constants::half<float>() * z
109
- * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
110
- }
111
-
112
- using Params = LinearCombinationGenericParams<float>;
113
-
114
- CUTLASS_DEVICE
115
- float operator()(float const& scalar, Params const& params_) const
116
- {
117
- return this->operator()(scalar);
118
- }
119
- };
120
-
121
- template <typename T, int N>
122
- struct GELU_taylor_fixed<Array<T, N> > {
123
- static const bool kIsHeavy=true;
124
- CUTLASS_HOST_DEVICE
125
- Array<T, N> operator()(Array<T, N> const &rhs) const {
126
- Array<T, N> y;
127
- GELU_taylor<T> gelu_op;
128
-
129
- CUTLASS_PRAGMA_UNROLL
130
- for (int i = 0; i < N; ++i) {
131
- y[i] = gelu_op(rhs[i]);
132
- }
133
-
134
- return y;
135
- }
136
-
137
- using Params = LinearCombinationGenericParams<T>;
138
- CUTLASS_HOST_DEVICE
139
- Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
140
- return this->operator()(rhs);
141
- }
142
- };
143
-
144
- } // namespace thread
145
- } // namespace epilogue
146
- } // namespace cutlass
147
-
148
- /////////////////////////////////////////////////////////////////////////////////////////////////
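`ft_fused_activations.h` vendored a fixed `GELU_taylor` functor computing 0.5·z·(1 + tanh(k0·z·(1 + k1·z²))) with k0 = √(2/π) ≈ 0.7978845608 and k1 = 0.044715. The small host program below checks that formula against the exact erf-based GELU; it is a standalone illustration, not part of the removed code.

```cpp
// Host-side comparison of the tanh-based GELU approximation used by
// GELU_taylor_fixed against the exact erf-based definition.
#include <cmath>
#include <cstdio>

float gelu_tanh(float z) {
  const float k0 = 0.7978845608028654f;  // sqrt(2 / pi)
  const float k1 = 0.044715f;
  return 0.5f * z * (1.0f + std::tanh(k0 * z * (1.0f + k1 * z * z)));
}

float gelu_erf(float z) {
  return 0.5f * z * (1.0f + std::erf(z / std::sqrt(2.0f)));
}

int main() {
  for (float z : {-3.0f, -1.0f, -0.1f, 0.0f, 0.5f, 2.0f}) {
    std::printf("z=% .2f  tanh-approx=% .6f  erf=% .6f\n", z, gelu_tanh(z), gelu_erf(z));
  }
}
```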
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h DELETED
@@ -1,390 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
33
-
34
- original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
35
-
36
- */
37
-
38
- #pragma once
39
-
40
- /////////////////////////////////////////////////////////////////////////////////////////////////
41
-
42
- #include "../epilogue_quant_helper.h"
43
- #include "cutlass/arch/memory.h"
44
- #include "cutlass/arch/memory_sm75.h"
45
- #include "cutlass/cutlass.h"
46
- #include "cutlass/fast_math.h"
47
- #include "cutlass/numeric_conversion.h"
48
-
49
- namespace cutlass {
50
- namespace epilogue {
51
- namespace threadblock {
52
-
53
- template<typename ThreadblockShape_,
54
- int ThreadCount,
55
- typename ScaleTileIterator_,
56
- typename OutputTileIterator_,
57
- typename ElementAccumulator_,
58
- typename ElementCompute_,
59
- typename ElementwiseFunctor_,
60
- bool UseMasking_ = false>
61
- class EpilogueVisitorPerRowPerCol {
62
- public:
63
- using ThreadblockShape = ThreadblockShape_;
64
- static int const kThreadCount = ThreadCount;
65
-
66
- using ScaleTileIterator = ScaleTileIterator_;
67
- using OutputTileIterator = OutputTileIterator_;
68
- using ElementwiseFunctor = ElementwiseFunctor_;
69
-
70
- static int const kIterations = OutputTileIterator::kIterations;
71
- static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
72
-
73
- using ElementOutput = typename OutputTileIterator::Element;
74
- using LayoutOutput = cutlass::layout::RowMajor;
75
- using ElementAccumulator = ElementAccumulator_;
76
-
77
- using AlphaScaleElementType = typename ScaleTileIterator::Element;
78
-
79
- using ElementCompute = ElementCompute_;
80
- using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
81
- using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
82
- using OutputVector = Array<ElementOutput, kElementsPerAccess>;
83
-
84
- static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
85
- static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
86
-
87
- /// Argument structure
88
- struct Arguments {
89
-
90
- typename ElementwiseFunctor::Params elementwise;
91
- int64_t batch_stride_alpha;
92
- int64_t batch_stride_C;
93
- int64_t batch_stride_D;
94
-
95
- //
96
- // Methods
97
- //
98
- Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
99
-
100
- Arguments(typename ElementwiseFunctor::Params elementwise_):
101
- elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0)
102
- {
103
- }
104
-
105
- Arguments(typename ElementwiseFunctor::Params elementwise_,
106
- int64_t batch_stride_alpha_,
107
- int64_t batch_stride_C_,
108
- int64_t batch_stride_D_):
109
- elementwise(elementwise_),
110
- batch_stride_alpha(batch_stride_alpha_),
111
- batch_stride_C(batch_stride_C_),
112
- batch_stride_D(batch_stride_D_)
113
- {
114
- }
115
- };
116
-
117
- struct Params {
118
-
119
- typename ElementwiseFunctor::Params elementwise;
120
- int64_t batch_stride_alpha;
121
- int64_t batch_stride_C;
122
- int64_t batch_stride_D;
123
- //
124
- // Methods
125
- //
126
- CUTLASS_HOST_DEVICE
127
- Params() {}
128
-
129
- CUTLASS_HOST_DEVICE
130
- Params(Arguments const& args):
131
- elementwise(args.elementwise),
132
- batch_stride_alpha(args.batch_stride_alpha),
133
- batch_stride_C(args.batch_stride_C),
134
- batch_stride_D(args.batch_stride_D)
135
- {
136
- }
137
- };
138
-
139
- /// Shared storage
140
- struct SharedStorage {};
141
-
142
- private:
143
- Params const& params_;
144
- SharedStorage& shared_storage_;
145
- MatrixCoord extent_;
146
- MatrixCoord extent_real_;
147
- ElementwiseFunctor elementwise_;
148
-
149
- const bool per_token_quant_;
150
- const bool per_channel_quant_;
151
-
152
- AlphaScaleElementType* ptr_alpha_row_;
153
- AlphaScaleElementType* ptr_alpha_col_;
154
- ScaleTileIterator iterator_alpha_col_;
155
- OutputTileIterator iterator_C_;
156
- OutputTileIterator iterator_D_;
157
-
158
- AlphaScaleElementType element_alpha_row_ = 1.0f;
159
- AlphaScaleElementType element_alpha_col_ = 1.0f;
160
- typename ScaleTileIterator::Fragment fragment_alpha_col_;
161
- typename OutputTileIterator::Fragment fragment_C_;
162
- typename OutputTileIterator::Fragment fragment_D_;
163
-
164
- ElementAccumulator beta_;
165
-
166
- int column_offset_;
167
-
168
- MatrixCoord thread_offset_;
169
-
170
- public:
171
- CUTLASS_DEVICE
172
- EpilogueVisitorPerRowPerCol(Params const& params,
173
- SharedStorage& shared_storage,
174
- cutlass::MatrixCoord const& problem_size,
175
- int thread_idx,
176
- int warp_idx,
177
- int lane_idx,
178
- typename ScaleTileIterator::Params params_alpha_col,
179
- typename OutputTileIterator::Params params_C,
180
- typename OutputTileIterator::Params params_D,
181
- QuantMode quant_mode,
182
- AlphaScaleElementType* ptr_alpha_row,
183
- AlphaScaleElementType* ptr_alpha_col,
184
- typename OutputTileIterator::Element* ptr_C,
185
- typename OutputTileIterator::Element* ptr_D,
186
- cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
187
- int column_offset = 0,
188
- cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)):
189
- params_(params),
190
- shared_storage_(shared_storage),
191
- extent_(problem_size),
192
- elementwise_(params.elementwise),
193
- per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant),
194
- per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant),
195
- ptr_alpha_row_(ptr_alpha_row),
196
- ptr_alpha_col_(ptr_alpha_col),
197
- iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
198
- iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
199
- iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
200
- extent_real_(problem_size_real)
201
- {
202
- beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
203
-
204
- if (beta_ == ElementAccumulator()) {
205
- iterator_C_.clear_mask();
206
- }
207
- }
208
-
209
- /// Helper to indicate split-K behavior
210
- CUTLASS_DEVICE
211
- void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
212
- int split_k_slices)
213
- { ///< Total number of split-K slices
214
- }
215
-
216
- /// Called to set the batch index
217
- CUTLASS_DEVICE
218
- void set_batch_index(int batch_idx)
219
- {
220
- iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
221
- iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
222
- iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
223
- }
224
-
225
- /// Called at the start of the epilogue just before iterating over accumulator slices
226
- CUTLASS_DEVICE
227
- void begin_epilogue()
228
- {
229
- if (per_channel_quant_) {
230
- iterator_alpha_col_.load(fragment_alpha_col_);
231
- }
232
- else if (ptr_alpha_col_ != nullptr) {
233
- arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
234
- element_alpha_col_, ptr_alpha_col_, true);
235
- }
236
-
237
- if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
238
- arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
239
- element_alpha_row_, ptr_alpha_row_, true);
240
- }
241
- }
242
-
243
- /// Called at the start of one step before starting accumulator exchange
244
- CUTLASS_DEVICE
245
- void begin_step(int step_idx)
246
- {
247
- fragment_D_.clear();
248
- fragment_C_.clear();
249
-
250
- if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
251
- iterator_C_.load(fragment_C_);
252
- ++iterator_C_;
253
- }
254
-
255
- // load alpha_row in begin_step only when per token(row) scaling is used
256
- if (per_token_quant_) {
257
- int thread_offset_row =
258
- iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row();
259
-
260
- // element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
261
- arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
262
- element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
263
- }
264
- }
265
-
266
- /// Called at the start of a row
267
- CUTLASS_DEVICE
268
- void begin_row(int row_idx)
269
- {
270
- // Clear accumulators for max and sum when starting a whole row
271
- }
272
-
273
- /// Called after accumulators have been exchanged for each accumulator vector
274
- CUTLASS_DEVICE
275
- void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
276
- {
277
-
278
- NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
279
-
280
- ComputeFragment result = source_converter(accum);
281
- if (per_channel_quant_) {
282
- ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
283
- result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
284
- }
285
- else {
286
- result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
287
- }
288
-
289
- /* printf("%d %e\n", accum[0], result[0]); */
290
- /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */
291
-
292
- /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */
293
- /* result = source_converter(elementwise_(result)); */
294
- /* } else { */
295
- /* result = source_converter(elementwise_(result, source_vector)); */
296
- /* } */
297
-
298
- /* // Convert to the output */
299
- NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
300
- OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
301
- output = output_converter(result);
302
- }
303
-
304
- /// Called at the end of a row
305
- CUTLASS_DEVICE
306
- void end_row(int row_idx)
307
- {
308
-
309
- /* using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>; */
310
- /* using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>; */
311
-
312
- /* ConvertSumOutput convert_sum_output; */
313
- /* ConvertNormOutput convert_norm_output; */
314
-
315
- /* // Compute accumulate sum only in the last step */
316
- /* accum_sum_ = warp_reduce_sum_(accum_sum_); */
317
-
318
- /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */
319
- /* bool row_guard = thread_offset_.row() < extent_.row(); */
320
- /* bool is_write_thread = row_guard && is_first_thread_in_tile; */
321
-
322
- /* int block_batch = blockIdx.z; */
323
-
324
- /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch *
325
- * params_.batch_stride_Max; */
326
- /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch *
327
- * params_.batch_stride_Sum; */
328
-
329
- /* arch::global_store<ElementNorm, sizeof(ElementNorm)>( */
330
- /* convert_norm_output(accum_max_), */
331
- /* (void *)curr_ptr_max, */
332
- /* is_write_thread); */
333
-
334
- /* arch::global_store<ElementSum, sizeof(ElementSum)>( */
335
- /* convert_sum_output(accum_sum_), */
336
- /* (void *)curr_ptr_sum, */
337
- /* is_write_thread); */
338
-
339
- /* // Clear accumulators for max and sum when finishing a whole row */
340
- /* clear_accum_(); */
341
- }
342
-
343
- /// Called after all accumulator elements have been visited
344
- CUTLASS_DEVICE
345
- void end_step(int step_idx)
346
- {
347
-
348
- iterator_D_.store(fragment_D_);
349
- ++iterator_D_;
350
- }
351
-
352
- /// Called after all steps have been completed
353
- CUTLASS_DEVICE
354
- void end_epilogue() {}
355
-
356
- private:
357
- CUTLASS_DEVICE
358
- ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum,
359
- ComputeFragment const& scale_col,
360
- AlphaScaleElementType const& scale_row)
361
- {
362
-
363
- ComputeFragment result;
364
- CUTLASS_PRAGMA_UNROLL
365
- for (int i = 0; i < ComputeFragment::kElements; ++i) {
366
- result[i] = accum[i] * (scale_col[i] * scale_row);
367
- }
368
-
369
- return result;
370
- }
371
-
372
- CUTLASS_DEVICE
373
- ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum,
374
- AlphaScaleElementType const& scale_col,
375
- AlphaScaleElementType const& scale_row)
376
- {
377
-
378
- ComputeFragment result;
379
- CUTLASS_PRAGMA_UNROLL
380
- for (int i = 0; i < ComputeFragment::kElements; ++i) {
381
- result[i] = accum[i] * (scale_col * scale_row);
382
- }
383
-
384
- return result;
385
- }
386
- };
387
-
388
- } // namespace threadblock
389
- } // namespace epilogue
390
- } // namespace cutlass
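`epilogue_per_row_per_col_scale.h` implemented an epilogue visitor that converts INT32 accumulators to floating point and multiplies each element by its row (token) and column (channel) scale. A plain reference loop for that computation, with assumed row-major layouts, is sketched below; the real epilogue operates on register fragments with predicated global loads and stores.

```cpp
// Reference for the per-token + per-channel case of the removed epilogue:
// D[i][j] = float(accum_int32[i][j]) * alpha_row[i] * alpha_col[j].
#include <cstdint>
#include <vector>

std::vector<float> dequant_epilogue(const std::vector<int32_t>& accum,     // M*N, row-major
                                    const std::vector<float>& alpha_row,   // M entries
                                    const std::vector<float>& alpha_col,   // N entries
                                    int M, int N) {
  std::vector<float> D(static_cast<size_t>(M) * N);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      const size_t idx = static_cast<size_t>(i) * N + j;
      D[idx] = static_cast<float>(accum[idx]) * alpha_row[i] * alpha_col[j];
    }
  }
  return D;
}
```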
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h DELETED
@@ -1,285 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
33
-
34
- The epilogue rearranges the result of a matrix product through shared memory to match canonical
35
- tensor layouts in global memory. Epilogues support conversion and reduction operations.
36
-
37
- original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
38
-
39
- */
40
-
41
- #pragma once
42
-
43
- #include "cutlass/array.h"
44
- #include "cutlass/cutlass.h"
45
- #include "cutlass/numeric_types.h"
46
-
47
- #include "cutlass/platform/platform.h"
48
-
49
- #include "cutlass/gemm/gemm.h"
50
-
51
- #include "cutlass/epilogue/thread/linear_combination.h"
52
- #include "cutlass/epilogue/thread/linear_combination_clamp.h"
53
- #include "cutlass/epilogue/thread/linear_combination_gelu.h"
54
- #include "cutlass/epilogue/thread/linear_combination_hardswish.h"
55
- #include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
56
- #include "cutlass/epilogue/thread/linear_combination_relu.h"
57
- #include "cutlass/epilogue/thread/linear_combination_relu0.h"
58
- #include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
59
-
60
- #include "cutlass/epilogue/thread/conversion_op.h"
61
- #include "cutlass/epilogue/thread/reduction_op.h"
62
-
63
- #include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
64
-
65
- #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
66
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
67
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
68
- #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
69
- #include "cutlass/epilogue/threadblock/shared_load_iterator.h"
70
- #include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
71
- #include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
72
- #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
73
- #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
74
- #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
75
-
76
- #include "cutlass/epilogue/threadblock/epilogue.h"
77
- #include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
78
-
79
- #include "cutlass/layout/permute.h"
80
-
81
- ////////////////////////////////////////////////////////////////////////////////
82
-
83
- namespace cutlass {
84
- namespace epilogue {
85
- namespace threadblock {
86
-
87
- ////////////////////////////////////////////////////////////////////////////////
88
-
89
- namespace detail {
90
-
91
- /// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
92
- template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
93
- struct DefaultIteratorsTensorOp<cutlass::half_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
94
-
95
- using WarpTileIterator =
96
- cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
97
-
98
- using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
99
-
100
- static int const kFragmentsPerIteration = 1;
101
- };
102
-
103
- /// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
104
- template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
105
- struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
106
- int32_t,
107
- 8,
108
- ThreadblockShape,
109
- WarpShape,
110
- InstructionShape,
111
- ThreadMap> {
112
-
113
- using WarpTileIterator =
114
- cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
115
-
116
- using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
117
-
118
- static int const kFragmentsPerIteration = 1;
119
- };
120
-
121
- /////////////////////////////////////////////////////////////////////////////////////////////////
122
-
123
- } // namespace detail
124
-
125
- /////////////////////////////////////////////////////////////////////////////////////////////////
126
-
127
- /// Tile iterator used to load output tile from shared memory in epilogue.
128
- ///
129
- /// Satisfies: ReadableTileIterator
130
- ///
131
- template<typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
132
- >
133
- class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
134
- public:
135
- using ThreadMap = ThreadMap_;
136
- using Shape = typename ThreadMap::Shape;
137
-
138
- using Element = int32_t;
139
-
140
- using Layout = layout::RowMajor;
141
- using TensorRef = TensorRef<Element, Layout>;
142
- using ConstTensorRef = typename TensorRef::ConstTensorRef;
143
-
144
- using Index = typename Layout::Index;
145
- using LongIndex = typename Layout::LongIndex;
146
- using TensorCoord = MatrixCoord;
147
-
148
- static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
149
-
150
- static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
151
-
152
- static int const kThreads = ThreadMap::kThreads;
153
-
154
- /// Fragment object
155
- using Fragment = Array<Element,
156
- ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
157
- * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
158
-
159
- /// Memory access size
160
- using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
161
-
162
- /// Vector type used for SMEM loads
163
- using LoadType = AlignedArray<Element,
164
- const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
165
- const_min(16, kAlignment)>;
166
-
167
- static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
168
-
169
- private:
170
- //
171
- // Data members
172
- //
173
-
174
- /// Byte-level pointer
175
- LoadType const* pointers_[kLoadsPerAccess];
176
-
177
- /// Stride along adjacent rows in units of LoadType
178
- int stride_;
179
-
180
- public:
181
- //
182
- // Methods
183
- //
184
-
185
- /// Constructor
186
- CUTLASS_DEVICE
187
- SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements))
188
- {
189
-
190
- TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
191
-
192
- // Initialize pointers
193
- CUTLASS_PRAGMA_UNROLL
194
- for (int i = 0; i < kLoadsPerAccess; ++i) {
195
- pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
196
-
197
- int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
198
- int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
199
-
200
- col_idx += (bank_offset + i) % kLoadsPerAccess;
201
-
202
- pointers_[i] += thread_offset.row() * stride_ + col_idx;
203
- }
204
- }
205
-
206
- /// Adds a pointer offset in units of Element
207
- CUTLASS_HOST_DEVICE
208
- void add_pointer_offset(LongIndex pointer_offset)
209
- {
210
- CUTLASS_PRAGMA_UNROLL
211
- for (int i = 0; i < kLoadsPerAccess; ++i) {
212
- pointers_[i] += pointer_offset / LoadType::kElements;
213
- }
214
- }
215
-
216
- CUTLASS_DEVICE
217
- void add_tile_offset(TensorCoord const& offset)
218
- {
219
- CUTLASS_PRAGMA_UNROLL
220
- for (int i = 0; i < kLoadsPerAccess; ++i) {
221
- pointers_[i] +=
222
- offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
223
- }
224
- }
225
-
226
- /// Loads a fragment from memory
227
- CUTLASS_DEVICE
228
- void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
229
- {
230
-
231
- CUTLASS_PRAGMA_UNROLL
232
- for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
233
-
234
- CUTLASS_PRAGMA_UNROLL
235
- for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
236
-
237
- CUTLASS_PRAGMA_UNROLL
238
- for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
239
-
240
- int row_ptr_offset =
241
- row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_
242
- + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;
243
-
244
- int frag_row_idx =
245
- (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
246
-
247
- LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
248
-
249
- CUTLASS_PRAGMA_UNROLL
250
- for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
251
-
252
- int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
253
-
254
- CUTLASS_PRAGMA_UNROLL
255
- for (int v = 0; v < kLoadsPerAccess; ++v) {
256
-
257
- int vector_idx =
258
- (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
259
-
260
- LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
261
-
262
- frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
263
- }
264
- }
265
- }
266
- }
267
- }
268
- }
269
-
270
- /// Loads a fragment
271
- CUTLASS_DEVICE
272
- void load(Fragment& frag) const
273
- {
274
-
275
- load_with_pointer_offset(frag, 0);
276
- }
277
- };
278
-
279
- /////////////////////////////////////////////////////////////////////////////////////////////////
280
-
281
- } // namespace threadblock
282
- } // namespace epilogue
283
- } // namespace cutlass
284
-
285
- ////////////////////////////////////////////////////////////////////////////////
cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h DELETED
@@ -1,82 +0,0 @@
1
- /**
2
- * @file epilogue_helpers.h
3
- *
4
- * This file includes types for the epilogues. The empty structs exist so we can signal to template
5
- * code the type of epilogue we want to run, and let the underlying code specify the details such as
6
- * element types, accumulator type and elements per vector access.
7
- *
8
- */
9
-
10
- #pragma once
11
-
12
- #include "cutlass/epilogue/thread/linear_combination.h"
13
- #include "cutlass/epilogue/thread/linear_combination_generic.h"
14
- #include "cutlass/epilogue/thread/linear_combination_relu.h"
15
- #include "cutlass/epilogue/thread/linear_combination_silu.h"
16
- #include "cutlass_extensions/epilogue/thread/ft_fused_activations.h"
17
-
18
- namespace fastertransformer {
19
-
20
- struct EpilogueOpBiasSilu {};
21
-
22
- struct EpilogueOpBiasReLU {};
23
-
24
- struct EpilogueOpBiasFtGelu {};
25
-
26
- struct EpilogueOpBias {};
27
-
28
- struct EpilogueOpNoBias {};
29
-
30
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
31
- struct Epilogue {
32
- };
33
-
34
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
35
- struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
36
- using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType,
37
- ElementsPerVectorAccess,
38
- ElementAccumulator,
39
- ElementAccumulator,
40
- cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
41
- };
42
-
43
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
44
- struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
45
- using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType,
46
- ElementsPerVectorAccess,
47
- ElementAccumulator,
48
- ElementAccumulator,
49
- cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
50
- };
51
-
52
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
53
- struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
54
- using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor_fixed,
55
- ElementType,
56
- ElementsPerVectorAccess,
57
- ElementAccumulator,
58
- ElementAccumulator,
59
- cutlass::epilogue::thread::ScaleType::NoBetaScaling,
60
- cutlass::FloatRoundStyle::round_to_nearest,
61
- true>;
62
- };
63
-
64
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
65
- struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
66
- using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
67
- ElementsPerVectorAccess,
68
- ElementAccumulator,
69
- ElementAccumulator,
70
- cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
71
- };
72
-
73
- template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
74
- struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
75
- using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
76
- ElementsPerVectorAccess,
77
- ElementAccumulator,
78
- ElementAccumulator,
79
- cutlass::epilogue::thread::ScaleType::Default>;
80
- };
81
-
82
- } // namespace fastertransformer
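`epilogue_helpers.h` used the same tag-selection idea as `arch/mma.h`: empty structs such as `EpilogueOpBiasReLU` pick a concrete CUTLASS `LinearCombination*` functor through the `Epilogue<...>` trait. The dependency-free sketch below mirrors that pattern with trivial stand-in functors; it is illustrative only and does not reproduce the CUTLASS types.

```cpp
// Tag structs select a concrete epilogue functor through a trait, so kernel
// templates only need to carry the tag. The functors here are toy stand-ins.
#include <algorithm>
#include <cstdio>

struct EpilogueOpBiasReLU {};
struct EpilogueOpNoBias {};

template <typename ElementType, typename OpTag>
struct Epilogue;  // primary template: unsupported tags fail to compile

template <typename ElementType>
struct Epilogue<ElementType, EpilogueOpBiasReLU> {
  struct Op {
    ElementType operator()(ElementType accum, ElementType bias) const {
      return std::max<ElementType>(accum + bias, ElementType(0));
    }
  };
};

template <typename ElementType>
struct Epilogue<ElementType, EpilogueOpNoBias> {
  struct Op {
    ElementType operator()(ElementType accum, ElementType) const { return accum; }
  };
};

int main() {
  Epilogue<float, EpilogueOpBiasReLU>::Op relu_op;
  std::printf("%f\n", relu_op(-1.5f, 0.5f));  // negative sum clamps to 0
}
```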
cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h DELETED
@@ -1,58 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- namespace fastertransformer {
20
- // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
21
- // in the kernel layout details when doing weight only quantization.
22
- enum class CutlassTileConfig {
23
- // Signals that we should run heuristics do choose a config
24
- Undefined,
25
-
26
- // Signals that we should run heuristics do choose a config
27
- ChooseWithHeuristic,
28
-
29
- // SiMT config
30
- CtaShape128x128x8_WarpShape64x64x8,
31
-
32
- // TensorCore configs CTA_N = 128, CTA_K = 64
33
- // Warp configs for M=32
34
- CtaShape32x128x64_WarpShape32x32x64,
35
-
36
- // Warp configs for M=64
37
- CtaShape64x128x64_WarpShape32x64x64,
38
- CtaShape64x128x64_WarpShape64x32x64,
39
-
40
- // Warp configs for M=128
41
- CtaShape128x128x64_WarpShape64x32x64,
42
- CtaShape128x128x64_WarpShape128x32x64
43
- };
44
-
45
- enum class SplitKStyle {
46
- NO_SPLIT_K,
47
- SPLIT_K_SERIAL,
48
- // SPLIT_K_PARALLEL // Not supported yet
49
- };
50
-
51
- struct CutlassGemmConfig {
52
- CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
53
- SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
54
- int split_k_factor = -1;
55
- int stages = -1;
56
- };
57
-
58
- } // namespace fastertransformer
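`ft_gemm_configs.h` defined the `CutlassTileConfig`/`SplitKStyle` enums and the `CutlassGemmConfig` struct that the (also removed) `cutlass_heuristic.cu` filled in. The sketch below shows one plausible way such a config might be chosen from the problem's M dimension, matching the "Warp configs for M=..." grouping in the enum; the selection rule is an assumption for illustration, not the removed heuristic.

```cpp
// Toy config selection: pick a CTA tile whose M dimension roughly matches the
// number of rows. Enum values are a subset of the removed CutlassTileConfig.
#include <cstdio>

enum class CutlassTileConfig {
  ChooseWithHeuristic,
  CtaShape32x128x64_WarpShape32x32x64,
  CtaShape64x128x64_WarpShape64x32x64,
  CtaShape128x128x64_WarpShape128x32x64,
};

struct CutlassGemmConfig {
  CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
  int split_k_factor = -1;
  int stages = -1;
};

CutlassGemmConfig pick_config(int m_rows) {
  CutlassGemmConfig config;
  config.tile_config =
      m_rows <= 32 ? CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64
    : m_rows <= 64 ? CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64
                   : CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64;
  config.split_k_factor = 1;  // no split-K in this toy rule
  config.stages = 3;
  return config;
}

int main() {
  CutlassGemmConfig c = pick_config(48);
  std::printf("stages=%d split_k=%d\n", c.stages, c.split_k_factor);
}
```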
cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h DELETED
@@ -1,123 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass/arch/arch.h"
4
- #include "cutlass/arch/mma.h"
5
- #include "cutlass/bfloat16.h"
6
- #include "cutlass/cutlass.h"
7
- #include "cutlass/gemm/gemm.h"
8
- #include "cutlass/layout/matrix.h"
9
-
10
- #include "cutlass_extensions/arch/mma.h"
11
- #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
12
-
13
- namespace cutlass {
14
- namespace gemm {
15
- namespace kernel {
16
-
17
- template<typename TypeA, typename TypeB, typename arch, typename Enable = void>
18
- struct MixedGemmArchTraits {
19
- };
20
-
21
- template<typename arch>
22
- struct MixedGemmArchTraits<float, float, arch> {
23
- static constexpr int Stages = 2;
24
- using OperatorClass = cutlass::arch::OpClassSimt;
25
- using AccType = float;
26
- using LayoutB = cutlass::layout::RowMajor;
27
-
28
- static constexpr int ElementsPerAccessA = 1;
29
- static constexpr int ElementsPerAccessB = 1;
30
- static constexpr int ElementsPerAccessC = 1;
31
- static constexpr int ThreadblockK = 8;
32
- using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
33
-
34
- using Operator = cutlass::arch::OpMultiplyAdd;
35
- };
36
-
37
- // ========================= Volta Traits ===========================
38
- // Volta will always dequantize after the global memory load.
39
- // This will instantiate any HMMA tensorcore kernels for Volta.
40
- // Note that volta does not have native bfloat16 support, so weights and activations will be cast to fp16
41
- // and compute will happen in fp16, then the result will be converted for bf16 output.
42
- template<typename TypeA, typename TypeB>
43
- struct MixedGemmArchTraits<
44
- TypeA,
45
- TypeB,
46
- cutlass::arch::Sm70,
47
- typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
48
- || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
49
- private:
50
- using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
51
-
52
- public:
53
- static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
54
-
55
- using OperatorClass = cutlass::arch::OpClassTensorOp;
56
- using AccType = float;
57
- using LayoutB = typename LayoutDetails::Layout;
58
-
59
- static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
60
- static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
61
- static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
62
- using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
63
-
64
- using Operator = typename LayoutDetails::Operator;
65
- };
66
-
67
- // ======================= Turing Traits ==============================
68
- // Note that turing does not have native bfloat16 support, so weights and activations will be cast to fp16
69
- // and compute will happen in fp16, then the result will be converted for bf16 output.
70
- template<typename TypeA, typename TypeB>
71
- struct MixedGemmArchTraits<
72
- TypeA,
73
- TypeB,
74
- cutlass::arch::Sm75,
75
- typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
76
- || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
77
- private:
78
- using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
79
-
80
- public:
81
- static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
82
-
83
- using OperatorClass = cutlass::arch::OpClassTensorOp;
84
- using AccType = float;
85
- using LayoutB = typename LayoutDetails::Layout;
86
-
87
- static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
88
- static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
89
- static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
90
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
91
-
92
- using Operator = typename LayoutDetails::Operator;
93
- };
94
-
95
- // ======================= Ampere Traits ==============================
96
- template<typename TypeA, typename TypeB>
97
- struct MixedGemmArchTraits<
98
- TypeA,
99
- TypeB,
100
- cutlass::arch::Sm80,
101
- typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
102
- || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
103
- private:
104
- using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
105
-
106
- public:
107
- static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
108
-
109
- using OperatorClass = cutlass::arch::OpClassTensorOp;
110
- using AccType = float;
111
- using LayoutB = typename LayoutDetails::Layout;
112
-
113
- static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
114
- static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
115
- static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
116
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
117
-
118
- using Operator = typename LayoutDetails::Operator;
119
- };
120
-
121
- } // namespace kernel
122
- } // namespace gemm
123
- } // namespace cutlass
 
 
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h DELETED
@@ -1,492 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
-
32
- /*! \file
33
- \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/cutlass.h"
39
-
40
- #include "cutlass/arch/arch.h"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/matrix_coord.h"
43
- #include "cutlass/semaphore.h"
44
-
45
- /////////////////////////////////////////////////////////////////////////////////////////////////
46
-
47
- namespace cutlass {
48
- namespace gemm {
49
- namespace kernel {
50
-
51
- /////////////////////////////////////////////////////////////////////////////////////////////////
52
-
53
- template<typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
54
- typename Epilogue_, ///! Epilogue
55
- typename ThreadblockSwizzle_, ///! Threadblock swizzling function
56
- typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
57
- /// arch.
58
- bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
59
- >
60
- struct GemmFpAIntB {
61
-
62
- using Mma = Mma_;
63
- using Epilogue = Epilogue_;
64
- using EpilogueOutputOp = typename Epilogue::OutputOp;
65
- using ThreadblockSwizzle = ThreadblockSwizzle_;
66
- static bool const kSplitKSerial = SplitKSerial;
67
-
68
- using ElementA = typename Mma::IteratorA::Element;
69
- using LayoutA = typename Mma::IteratorA::Layout;
70
- using ElementB = typename Mma::IteratorB::Element;
71
- using LayoutB = typename Mma::IteratorB::Layout;
72
- using ElementC = typename Epilogue::OutputTileIterator::Element;
73
- using LayoutC = typename Mma::LayoutC;
74
- using ElementScale = ElementC;
75
-
76
- static ComplexTransform const kTransformA = Mma::kTransformA;
77
- static ComplexTransform const kTransformB = Mma::kTransformA;
78
-
79
- // Type definitions about the mainloop.
80
- using Operator = typename Mma::Operator;
81
- using OperatorClass = typename Mma::Operator::OperatorClass;
82
- using ThreadblockShape = typename Mma::Shape;
83
- using WarpShape = typename Mma::Operator::Shape;
84
- using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
85
- using ArchTag = typename Mma::ArchTag;
86
-
87
- static int const kStages = Mma::kStages;
88
- static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
89
- static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
90
- static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
91
-
92
- /// Warp count (concept: GemmShape)
93
- using WarpCount = typename Mma::WarpCount;
94
- static int const kThreadCount = 32 * WarpCount::kCount;
95
-
96
- static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
97
-
98
- /// Parameters structure
99
- struct Arguments {
100
- GemmUniversalMode mode = GemmUniversalMode::kGemm;
101
-
102
- cutlass::gemm::GemmCoord problem_size;
103
- typename Mma::IteratorA::TensorRef ref_A;
104
- typename Mma::IteratorB::TensorRef ref_B;
105
- typename Mma::IteratorScale::TensorRef ref_scale;
106
- typename Epilogue::OutputTileIterator::TensorRef ref_C;
107
- typename Epilogue::OutputTileIterator::TensorRef ref_D;
108
-
109
- // Control serial split-k
110
- int batch_count;
111
-
112
- typename EpilogueOutputOp::Params output_op;
113
-
114
- // For gather+scatter operations
115
- int const* gather_A_indices;
116
- int const* gather_B_indices;
117
- int const* scatter_D_indices;
118
-
119
- // Included so we can use Gemm Universal
120
- int batch_stride_D = 0;
121
-
122
- //
123
- // Methods
124
- //
125
-
126
- CUTLASS_HOST_DEVICE
127
- Arguments() {}
128
-
129
- CUTLASS_HOST_DEVICE
130
- Arguments(cutlass::gemm::GemmCoord const& problem_size,
131
- typename Mma::IteratorA::TensorRef ref_A,
132
- typename Mma::IteratorB::TensorRef ref_B,
133
- typename Mma::IteratorScale::TensorRef ref_scale,
134
- typename Epilogue::OutputTileIterator::TensorRef ref_C,
135
- typename Epilogue::OutputTileIterator::TensorRef ref_D,
136
- int serial_split_k_factor,
137
- typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
138
- int const* gather_A_indices = nullptr,
139
- int const* gather_B_indices = nullptr,
140
- int const* scatter_D_indices = nullptr):
141
- problem_size(problem_size),
142
- ref_A(ref_A),
143
- ref_B(ref_B),
144
- ref_scale(ref_scale),
145
- ref_C(ref_C),
146
- ref_D(ref_D),
147
- batch_count(serial_split_k_factor),
148
- output_op(output_op),
149
- gather_A_indices(gather_A_indices),
150
- gather_B_indices(gather_B_indices),
151
- scatter_D_indices(scatter_D_indices)
152
- {
153
- }
154
- };
155
-
156
- /// Parameters structure
157
- struct Params {
158
- cutlass::gemm::GemmCoord problem_size;
159
- cutlass::gemm::GemmCoord grid_tiled_shape;
160
- int swizzle_log_tile;
161
- typename Mma::IteratorA::Params params_A;
162
- typename Mma::IteratorA::TensorRef ref_A;
163
- typename Mma::IteratorB::Params params_B;
164
- typename Mma::IteratorB::TensorRef ref_B;
165
- typename Mma::IteratorScale::Params params_scale;
166
- typename Mma::IteratorScale::TensorRef ref_scale;
167
- typename Epilogue::OutputTileIterator::Params params_C;
168
- typename Epilogue::OutputTileIterator::TensorRef ref_C;
169
- typename Epilogue::OutputTileIterator::Params params_D;
170
- typename Epilogue::OutputTileIterator::TensorRef ref_D;
171
- typename EpilogueOutputOp::Params output_op;
172
- int* semaphore;
173
- int gemm_k_size;
174
- // For gather+scatter operations
175
- int const* gather_A_indices;
176
- int const* gather_B_indices;
177
- int const* scatter_D_indices;
178
-
179
- //
180
- // Methods
181
- //
182
-
183
- CUTLASS_HOST_DEVICE
184
- Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {}
185
-
186
- CUTLASS_HOST_DEVICE
187
- Params(Arguments const& args,
188
- cutlass::gemm::GemmCoord const& grid_tiled_shape,
189
- const int gemm_k_size,
190
- void* workspace = nullptr):
191
- problem_size(args.problem_size),
192
- grid_tiled_shape(grid_tiled_shape),
193
- swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
194
- params_A(args.ref_A.layout()),
195
- ref_A(args.ref_A),
196
- params_B(args.ref_B.layout()),
197
- ref_B(args.ref_B),
198
- params_scale(args.ref_scale.layout()),
199
- ref_scale(args.ref_scale),
200
- params_C(args.ref_C.layout()),
201
- ref_C(args.ref_C),
202
- params_D(args.ref_D.layout()),
203
- ref_D(args.ref_D),
204
- output_op(args.output_op),
205
- semaphore(static_cast<int*>(workspace)),
206
- gemm_k_size(gemm_k_size),
207
- gather_A_indices(args.gather_A_indices),
208
- gather_B_indices(args.gather_B_indices),
209
- scatter_D_indices(args.scatter_D_indices)
210
- {
211
- }
212
- };
213
-
214
- /// Shared memory storage structure
215
- union SharedStorage {
216
- typename Mma::SharedStorage main_loop;
217
- typename Epilogue::SharedStorage epilogue;
218
- };
219
-
220
- //
221
- // Methods
222
- //
223
-
224
- CUTLASS_HOST_DEVICE
225
- GemmFpAIntB() {}
226
-
227
- /// Determines whether kernel satisfies alignment
228
- CUTLASS_HOST_DEVICE
229
- static Status can_implement(Arguments const& args)
230
- {
231
-
232
- static int const kAlignmentA =
233
- (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ?
234
- 32 :
235
- (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ?
236
- 64 :
237
- Mma::IteratorA::AccessType::kElements;
238
- static int const kAlignmentB =
239
- (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ?
240
- 32 :
241
- (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ?
242
- 64 :
243
- Mma::IteratorB::AccessType::kElements;
244
-
245
- static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
246
-
247
- static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
248
- layout::ColumnMajorInterleaved<32>>::value) ?
249
- 32 :
250
- (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
251
- layout::ColumnMajorInterleaved<64>>::value) ?
252
- 64 :
253
- Epilogue::OutputTileIterator::kElementsPerAccess;
254
-
255
- if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
256
- return Status::kErrorMisalignedOperand;
257
- }
258
-
259
- if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
260
- return Status::kErrorMisalignedOperand;
261
- }
262
-
263
- if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
264
- return Status::kErrorMisalignedOperand;
265
- }
266
-
267
- if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
268
- return Status::kErrorMisalignedOperand;
269
- }
270
-
271
- if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
272
- return Status::kErrorMisalignedOperand;
273
- }
274
-
275
- return Status::kSuccess;
276
- }
277
-
278
- static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
279
- {
280
-
281
- return 0;
282
- }
283
-
284
- // The dummy template parameter is not used and exists so that we can compile this code using
285
- // a standard earlier than C++17. Prior to C++17, fully specialized templates had to exist in
286
- // a namespace
287
- template<bool B, typename dummy = void>
288
- struct KernelRunner {
289
- CUTLASS_DEVICE
290
- static void run_kernel(Params const& params, SharedStorage& shared_storage)
291
- {
292
- CUTLASS_NOT_IMPLEMENTED();
293
- }
294
- };
295
-
296
- template<typename dummy>
297
- struct KernelRunner<true, dummy> {
298
- CUTLASS_DEVICE
299
- static void run_kernel(Params const& params, SharedStorage& shared_storage)
300
- {
301
- using LayoutB = typename Mma::IteratorB::Layout;
302
- static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
303
- || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
304
- "B must be row major/col major OR col major interleaved.");
305
-
306
- // Compute threadblock location
307
- ThreadblockSwizzle threadblock_swizzle;
308
-
309
- cutlass::gemm::GemmCoord threadblock_tile_offset =
310
- threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
311
-
312
- // Early exit if CTA is out of range
313
- if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
314
- || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
315
-
316
- return;
317
- }
318
-
319
- // Compute initial location in logical coordinates
320
- cutlass::MatrixCoord tb_offset_A{
321
- threadblock_tile_offset.m() * Mma::Shape::kM,
322
- threadblock_tile_offset.k() * params.gemm_k_size,
323
- };
324
-
325
- cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
326
- threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
327
-
328
- cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};
329
-
330
- // Problem size is a function of threadblock index in the K dimension
331
- int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
332
-
333
- // Compute threadblock-scoped matrix multiply-add
334
- int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
335
-
336
- // Compute position within threadblock
337
- int thread_idx = threadIdx.x;
338
-
339
- // Construct iterators to A and B operands
340
- typename Mma::IteratorA iterator_A(params.params_A,
341
- params.ref_A.data(),
342
- {params.problem_size.m(), problem_size_k},
343
- thread_idx,
344
- tb_offset_A,
345
- params.gather_A_indices);
346
-
347
- typename Mma::IteratorB iterator_B(params.params_B,
348
- params.ref_B.data(),
349
- {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
350
- thread_idx,
351
- tb_offset_B,
352
- params.gather_B_indices);
353
-
354
- typename Mma::IteratorScale iterator_scale(params.params_scale,
355
- params.ref_scale.data(),
356
- {1, params.problem_size.n()},
357
- thread_idx,
358
- tb_offset_scale);
359
-
360
- // Broadcast the warp_id computed by lane 0 to ensure dependent code
361
- // is compiled as warp-uniform.
362
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
363
- int lane_idx = threadIdx.x % 32;
364
-
365
- //
366
- // Main loop
367
- //
368
- // Construct thread-scoped matrix multiply
369
- Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
370
-
371
- typename Mma::FragmentC accumulators;
372
-
373
- accumulators.clear();
374
-
375
- if (!kSplitKSerial || gemm_k_iterations > 0) {
376
- // Compute threadblock-scoped matrix multiply-add
377
- mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
378
- }
379
-
380
- //
381
- // Epilogue
382
- //
383
-
384
- EpilogueOutputOp output_op(params.output_op);
385
-
386
- //
387
- // Masked tile iterators constructed from members
388
- //
389
-
390
- threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
391
-
392
- // assume identity swizzle
393
- MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
394
- threadblock_tile_offset.n() * Mma::Shape::kN);
395
-
396
- int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
397
-
398
- // Construct the semaphore.
399
- Semaphore semaphore(params.semaphore + block_idx, thread_idx);
400
-
401
- // If performing a reduction via split-K, fetch the initial synchronization
402
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
403
-
404
- // Fetch the synchronization lock initially but do not block.
405
- semaphore.fetch();
406
-
407
- // Indicate which position in a serial reduction the output operator is currently updating
408
- output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
409
- }
410
-
411
- // Tile iterator loading from source tensor.
412
- typename Epilogue::OutputTileIterator iterator_C(params.params_C,
413
- params.ref_C.data(),
414
- params.problem_size.mn(),
415
- thread_idx,
416
- threadblock_offset,
417
- params.scatter_D_indices);
418
-
419
- // Tile iterator writing to destination tensor.
420
- typename Epilogue::OutputTileIterator iterator_D(params.params_D,
421
- params.ref_D.data(),
422
- params.problem_size.mn(),
423
- thread_idx,
424
- threadblock_offset,
425
- params.scatter_D_indices);
426
-
427
- Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
428
-
429
- // Wait on the semaphore - this latency may have been covered by iterator construction
430
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
431
-
432
- // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
433
- if (threadblock_tile_offset.k()) {
434
- iterator_C = iterator_D;
435
- }
436
-
437
- semaphore.wait(threadblock_tile_offset.k());
438
- }
439
-
440
- // Execute the epilogue operator to update the destination tensor.
441
- epilogue(output_op, iterator_D, accumulators, iterator_C);
442
-
443
- //
444
- // Release the semaphore
445
- //
446
-
447
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
448
-
449
- int lock = 0;
450
- if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
451
-
452
- // The final threadblock resets the semaphore for subsequent grids.
453
- lock = 0;
454
- }
455
- else {
456
- // Otherwise, the semaphore is incremented
457
- lock = threadblock_tile_offset.k() + 1;
458
- }
459
-
460
- semaphore.release(lock);
461
- }
462
- }
463
- };
464
-
465
- /*
466
- To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
467
- to the ArchTag of the cutlass kernel operator.
468
- */
469
- /// Executes one GEMM
470
- CUTLASS_DEVICE
471
- void operator()(Params const& params, SharedStorage& shared_storage)
472
- {
473
- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
474
- static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
475
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
476
- #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
477
- static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
478
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
479
- #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
480
- static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
481
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
482
- #else
483
- CUTLASS_NOT_IMPLEMENTED();
484
- #endif
485
- }
486
- };
487
-
488
- /////////////////////////////////////////////////////////////////////////////////////////////////
489
-
490
- } // namespace kernel
491
- } // namespace gemm
492
- } // namespace cutlass
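
The operator() above only emits a real kernel body when the compiling __CUDA_ARCH__ range matches the kernel's KernelArch; every other combination collapses to CUTLASS_NOT_IMPLEMENTED(). A stripped-down, standalone sketch of the same guard idiom; Runner, ArchSm80, and dispatch are stand-in names, not the CUTLASS types used above:

    #include <type_traits>

    struct ArchSm80 {};  // stand-in tag; the real kernel uses cutlass::arch::Sm80

    // Primary template: architecture mismatch, body compiles to a no-op.
    template <bool Enabled, typename Dummy = void>
    struct Runner {
        __device__ static void run() {}
    };

    // Specialization selected only when the kernel's arch tag matches the arch being compiled.
    template <typename Dummy>
    struct Runner<true, Dummy> {
        __device__ static void run() { /* mainloop + epilogue would live here */ }
    };

    template <typename KernelArch>
    __device__ void dispatch() {
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        Runner<std::is_same<KernelArch, ArchSm80>::value>::run();
    #endif
    }
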
 
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h DELETED
@@ -1,447 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
3
- *reserved. SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice,
9
- *this list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22
- *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23
- *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
- *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
- *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
- *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
- *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
- *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
- *POSSIBILITY OF SUCH DAMAGE.
30
- *
31
- **************************************************************************************************/
32
-
33
- /*! \file
34
- \brief Template for a pipelined GEMM kernel. Does not compute batching or
35
- support split-K.
36
- */
37
-
38
- #pragma once
39
-
40
- #include "cutlass/cutlass.h"
41
-
42
- #include "cutlass/arch/arch.h"
43
- #include "cutlass/gemm/gemm.h"
44
- #include "cutlass/matrix_coord.h"
45
- #include "cutlass/semaphore.h"
46
-
47
- /////////////////////////////////////////////////////////////////////////////////////////////////
48
-
49
- namespace cutlass {
50
- namespace gemm {
51
- namespace kernel {
52
-
53
- /////////////////////////////////////////////////////////////////////////////////////////////////
54
-
55
- template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
56
- typename Epilogue_, ///! Epilogue
57
- typename ThreadblockSwizzle_, ///! Threadblock swizzling function
58
- typename KernelArch ///! The Architecture this kernel is compiled for.
59
- /// Used since SIMT kernels lose top-level arch.
60
- //////
61
- >
62
- struct GemmFpAIntBWithBroadcast {
63
-
64
- using Mma = Mma_;
65
- using Epilogue = Epilogue_;
66
- using EpilogueOutputOp = typename Epilogue::OutputOp;
67
- using ThreadblockSwizzle = ThreadblockSwizzle_;
68
-
69
- using ElementA = typename Mma::IteratorA::Element;
70
- using LayoutA = typename Mma::IteratorA::Layout;
71
- using ElementB = typename Mma::IteratorB::Element;
72
- using LayoutB = typename Mma::IteratorB::Element;
73
- using ElementC = typename Epilogue::OutputTileIterator::Element;
74
- using LayoutC = typename Mma::LayoutC;
75
- using ElementScale = ElementC;
76
-
77
- static ComplexTransform const kTransformA = Mma::kTransformA;
78
- static ComplexTransform const kTransformB = Mma::kTransformA;
79
-
80
- // Type definitions about the mainloop.
81
- using Operator = typename Mma::Operator;
82
- using OperatorClass = typename Mma::Operator::OperatorClass;
83
- using ThreadblockShape = typename Mma::Shape;
84
- using WarpShape = typename Mma::Operator::Shape;
85
- using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
86
- using ArchTag = typename Mma::ArchTag;
87
-
88
- static int const kStages = Mma::kStages;
89
- static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
90
- static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
91
- static int const kAlignmentC =
92
- Epilogue::OutputTileIterator::kElementsPerAccess;
93
-
94
- /// Warp count (concept: GemmShape)
95
- using WarpCount = typename Mma::WarpCount;
96
- static int const kThreadCount = 32 * WarpCount::kCount;
97
-
98
- static constexpr int kInterleave =
99
- Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
100
-
101
- /// Parameters structure
102
- struct Arguments {
103
- GemmUniversalMode mode = GemmUniversalMode::kGemm;
104
-
105
- cutlass::gemm::GemmCoord problem_size;
106
- int batch_count;
107
- typename EpilogueOutputOp::Params epilogue;
108
-
109
- void const *ptr_A;
110
- void const *ptr_B;
111
- void const *ptr_scales;
112
- void const *ptr_C;
113
- void *ptr_D;
114
-
115
- void const *ptr_Vector;
116
- void const *ptr_Tensor;
117
-
118
- int64_t batch_stride_A;
119
- int64_t batch_stride_B;
120
- int64_t batch_stride_C;
121
- int64_t batch_stride_D;
122
- int64_t batch_stride_Vector;
123
- int64_t batch_stride_Tensor;
124
-
125
- int lda, ldb, ldc, ldd, ldr, ldt;
126
-
127
- typename EpilogueOutputOp::Params output_op;
128
-
129
- // For gather+scatter operations
130
- int const *gather_A_indices;
131
- int const *gather_B_indices;
132
- int const *scatter_D_indices;
133
-
134
- CUTLASS_HOST_DEVICE
135
- Arguments() {}
136
-
137
- CUTLASS_HOST_DEVICE
138
- Arguments(cutlass::gemm::GemmCoord const &problem_size, int batch_count,
139
- typename EpilogueOutputOp::Params epilogue, void const *ptr_A,
140
- void const *ptr_B, void const *ptr_scales, void const *ptr_C,
141
- void *ptr_D, const void *ptr_Vector, const void *ptr_Tensor,
142
- int64_t batch_stride_A, int64_t batch_stride_B,
143
- int64_t batch_stride_C, int64_t batch_stride_D,
144
- int64_t batch_stride_Vector, int64_t batch_stride_Tensor,
145
- int lda, int ldb, int ldc, int ldd, int ldr, int ldt,
146
- typename EpilogueOutputOp::Params output_op =
147
- typename EpilogueOutputOp::Params())
148
- : problem_size(problem_size), batch_count(batch_count),
149
- epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B),
150
- ptr_scales(ptr_scales), ptr_C(ptr_C), ptr_D(ptr_D),
151
- ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor),
152
- batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B),
153
- batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
154
- batch_stride_Vector(batch_stride_Vector),
155
- batch_stride_Tensor(batch_stride_Tensor), lda(lda), ldb(ldb),
156
- ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op),
157
- gather_A_indices(nullptr), gather_B_indices(nullptr),
158
- scatter_D_indices(nullptr) {}
159
- };
160
-
161
- /// Parameters structure
162
- struct Params {
163
- cutlass::gemm::GemmCoord problem_size;
164
- cutlass::gemm::GemmCoord grid_tiled_shape;
165
- int swizzle_log_tile;
166
-
167
- typename Mma::IteratorA::Params params_A;
168
- typename Mma::IteratorB::Params params_B;
169
- typename Mma::IteratorScale::Params params_scale;
170
- typename Epilogue::OutputTileIterator::Params params_C;
171
- typename Epilogue::OutputTileIterator::Params params_D;
172
- typename Epilogue::TensorTileIterator::Params params_Tensor;
173
-
174
- typename EpilogueOutputOp::Params output_op;
175
-
176
- // GemmUniversalMode mode; todo
177
- int batch_count;
178
- int gemm_k_size;
179
- void *ptr_A;
180
- void *ptr_B;
181
- void *ptr_C;
182
- void *ptr_scales;
183
- void *ptr_D;
184
-
185
- void *ptr_Vector;
186
- typename LayoutC::Stride::Index ldr;
187
-
188
- void *ptr_Tensor;
189
-
190
- int64_t batch_stride_A;
191
- int64_t batch_stride_B;
192
- int64_t batch_stride_C;
193
- int64_t batch_stride_D;
194
- int64_t batch_stride_Vector;
195
- int64_t batch_stride_Tensor;
196
-
197
- // For gather+scatter operations
198
- int const *gather_A_indices;
199
- int const *gather_B_indices;
200
- int const *scatter_D_indices;
201
-
202
- //
203
- // Methods
204
- //
205
-
206
- CUTLASS_HOST_DEVICE
207
- Params() : swizzle_log_tile(0), gemm_k_size(0) {}
208
-
209
- CUTLASS_HOST_DEVICE
210
- Params(Arguments const &args,
211
- cutlass::gemm::GemmCoord const &grid_tiled_shape,
212
- const int gemm_k_size, void *workspace = nullptr)
213
- : problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape),
214
- swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
215
- params_A(args.lda), params_B(args.ldb), params_C(args.ldc),
216
- params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue),
217
- batch_count(args.batch_count), gemm_k_size(gemm_k_size),
218
- ptr_A(const_cast<void *>(args.ptr_A)),
219
- ptr_B(const_cast<void *>(args.ptr_B)),
220
- ptr_scales(const_cast<void *>(args.ptr_scales)),
221
- ptr_C(const_cast<void *>(args.ptr_C)), ptr_D(args.ptr_D),
222
- ptr_Vector(const_cast<void *>(args.ptr_Vector)), ldr(args.ldr),
223
- ptr_Tensor(const_cast<void *>(args.ptr_Tensor)), batch_stride_A(args.batch_stride_A),
224
- batch_stride_B(args.batch_stride_B),
225
- batch_stride_C(args.batch_stride_C),
226
- batch_stride_D(args.batch_stride_D),
227
- batch_stride_Vector(args.batch_stride_Vector),
228
- batch_stride_Tensor(args.batch_stride_Tensor),
229
- gather_A_indices(args.gather_A_indices),
230
- gather_B_indices(args.gather_B_indices),
231
- scatter_D_indices(args.scatter_D_indices) {}
232
- };
233
-
234
- /// Shared memory storage structure
235
- union SharedStorage {
236
- typename Mma::SharedStorage main_loop;
237
- typename Epilogue::SharedStorage epilogue;
238
- };
239
-
240
- //
241
- // Methods
242
- //
243
-
244
- CUTLASS_HOST_DEVICE
245
- GemmFpAIntBWithBroadcast() {}
246
-
247
- CUTLASS_HOST_DEVICE
248
- static Status can_implement(Arguments const &args) {
249
- // todo
250
- return Status::kSuccess;
251
- }
252
-
253
- static size_t
254
- get_extra_workspace_size(Arguments const &args,
255
- cutlass::gemm::GemmCoord const &grid_tiled_shape) {
256
-
257
- return 0;
258
- }
259
-
260
- // The dummy template parameter is not used and exists so that we can compile
261
- // this code using a standard earlier than C++17. Prior to C++17, fully
262
- // specialized templates HAD to exists in a namespace
263
- template <bool B, typename dummy = void> struct KernelRunner {
264
- CUTLASS_DEVICE
265
- static void run_kernel(Params const &params,
266
- SharedStorage &shared_storage) {
267
- CUTLASS_NOT_IMPLEMENTED();
268
- }
269
- };
270
-
271
- template <typename dummy> struct KernelRunner<true, dummy> {
272
- CUTLASS_DEVICE
273
- static void run_kernel(Params const &params,
274
- SharedStorage &shared_storage) {
275
- using LayoutB = typename Mma::IteratorB::Layout;
276
- static_assert(
277
- platform::is_same<LayoutB, layout::RowMajor>::value &&
278
- kInterleave == 1 ||
279
- platform::is_same<LayoutB, layout::ColumnMajor>::value &&
280
- kInterleave >= 1,
281
- "B must be row major/col major OR col major interleaved.");
282
-
283
- // Compute threadblock location
284
- ThreadblockSwizzle threadblock_swizzle;
285
-
286
- cutlass::gemm::GemmCoord threadblock_tile_offset =
287
- threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
288
-
289
- // Early exit if CTA is out of range
290
- if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
291
- params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
292
-
293
- return;
294
- }
295
-
296
- // Compute initial location in logical coordinates
297
- cutlass::MatrixCoord tb_offset_A{
298
- threadblock_tile_offset.m() * Mma::Shape::kM,
299
- threadblock_tile_offset.k() * params.gemm_k_size,
300
- };
301
-
302
- cutlass::MatrixCoord tb_offset_B{
303
- threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
304
- threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
305
-
306
- cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() *
307
- Mma::Shape::kN};
308
-
309
- // Problem size is a function of threadblock index in the K dimension
310
- int problem_size_k =
311
- min(params.problem_size.k(),
312
- (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
313
-
314
- // Compute threadblock-scoped matrix multiply-add
315
- int gemm_k_iterations =
316
- (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) /
317
- Mma::Shape::kK;
318
-
319
- // Compute position within threadblock
320
- int thread_idx = threadIdx.x;
321
-
322
- // Construct iterators to A and B operands
323
- typename Mma::IteratorA iterator_A(
324
- params.params_A, static_cast<ElementA *>(params.ptr_A),
325
- {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A,
326
- params.gather_A_indices);
327
-
328
- typename Mma::IteratorB iterator_B(
329
- params.params_B, static_cast<ElementB *>(params.ptr_B),
330
- {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
331
- thread_idx, tb_offset_B, params.gather_B_indices);
332
-
333
- typename Mma::IteratorScale iterator_scale(
334
- params.params_scale, static_cast<ElementScale *>(params.ptr_scales),
335
- {1, params.problem_size.n()}, thread_idx, tb_offset_scale);
336
-
337
- // Broadcast the warp_id computed by lane 0 to ensure dependent code is
338
- // compiled as warp-uniform.
339
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
340
- int lane_idx = threadIdx.x % 32;
341
-
342
- //
343
- // Main loop
344
- //
345
- // Construct thread-scoped matrix multiply
346
- Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
347
-
348
- typename Mma::FragmentC accumulators;
349
-
350
- accumulators.clear();
351
-
352
- if (gemm_k_iterations > 0) {
353
- // Compute threadblock-scoped matrix multiply-add
354
- mma(gemm_k_iterations, accumulators, iterator_A, iterator_B,
355
- iterator_scale, accumulators);
356
- }
357
-
358
- //
359
- // Epilogue
360
- //
361
-
362
- EpilogueOutputOp output_op(params.output_op);
363
-
364
- //
365
- // Masked tile iterators constructed from members
366
- //
367
-
368
- threadblock_tile_offset =
369
- threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
370
-
371
- // assume identity swizzle
372
- MatrixCoord threadblock_offset(
373
- threadblock_tile_offset.m() * Mma::Shape::kM,
374
- threadblock_tile_offset.n() * Mma::Shape::kN);
375
-
376
- int block_idx = threadblock_tile_offset.m() +
377
- threadblock_tile_offset.n() * params.grid_tiled_shape.m();
378
-
379
- ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
380
- ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
381
-
382
- // Tile iterator loading from source tensor.
383
- typename Epilogue::OutputTileIterator iterator_C(
384
- params.params_C, ptr_C, params.problem_size.mn(),
385
- thread_idx, threadblock_offset, params.scatter_D_indices);
386
-
387
- // Tile iterator writing to destination tensor.
388
- typename Epilogue::OutputTileIterator iterator_D(
389
- params.params_D, ptr_D, params.problem_size.mn(),
390
- thread_idx, threadblock_offset, params.scatter_D_indices);
391
-
392
- typename Epilogue::ElementTensor *ptr_Tensor =
393
- static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
394
-
395
- // Define the reduction output pointer and move to the appropriate place
396
- typename Epilogue::ElementVector *ptr_Vector =
397
- static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
398
-
399
- typename Epilogue::TensorTileIterator tensor_iterator(
400
- params.params_Tensor,
401
- // Only the final block outputs Tensor
402
- ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset);
403
-
404
- Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx,
405
- lane_idx);
406
-
407
- if (ptr_Vector) {
408
- ptr_Vector += threadblock_offset.column() +
409
- threadblock_tile_offset.m() * params.ldr;
410
- }
411
-
412
- epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C,
413
- tensor_iterator, params.problem_size.mn(), threadblock_offset);
414
- }
415
- };
416
-
417
- /*
418
- To improve compilation speed, we do not compile the device operator if the
419
- CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel
420
- operator.
421
- */
422
- /// Executes one GEMM
423
- CUTLASS_DEVICE
424
- void operator()(Params const &params, SharedStorage &shared_storage) {
425
- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
426
- static constexpr bool compile_needed =
427
- platform::is_same<KernelArch, arch::Sm70>::value;
428
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
429
- #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
430
- static constexpr bool compile_needed =
431
- platform::is_same<KernelArch, arch::Sm75>::value;
432
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
433
- #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
434
- static constexpr bool compile_needed =
435
- platform::is_same<KernelArch, arch::Sm80>::value;
436
- KernelRunner<compile_needed>::run_kernel(params, shared_storage);
437
- #else
438
- CUTLASS_NOT_IMPLEMENTED();
439
- #endif
440
- }
441
- };
442
-
443
- /////////////////////////////////////////////////////////////////////////////////////////////////
444
-
445
- } // namespace kernel
446
- } // namespace gemm
447
- } // namespace cutlass
 
cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h DELETED
@@ -1,89 +0,0 @@
1
- /*
2
- This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
3
- quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
4
- to be consumed by CUTLASS.
5
-
6
- Note that for int4, ThreadBlockK MUST be 64.
7
-
8
- */
9
-
10
- #pragma once
11
-
12
- #include "cutlass/layout/matrix.h"
13
- #include "cutlass/numeric_types.h"
14
-
15
- #include "cutlass/arch/arch.h"
16
- #include "cutlass/arch/mma.h"
17
- #include "cutlass/platform/platform.h"
18
-
19
- #include "cutlass_extensions/arch/mma.h"
20
- #include "cutlass_extensions/tile_interleaved_layout.h"
21
-
22
- namespace cutlass {
23
- namespace gemm {
24
- namespace kernel {
25
-
26
- template<typename TypeB, typename Arch, typename Enable = void>
27
- struct LayoutDetailsB {
28
- };
29
-
30
- // Volta specialiations. Volta will dequantize before STS, so we need a different operator
31
- template<typename TypeB>
32
- struct LayoutDetailsB<TypeB, arch::Sm70> {
33
- static constexpr int ThreadblockK = 64;
34
- using Layout = layout::RowMajor;
35
- static constexpr int ElementsPerAccess = 8;
36
- using Operator = cutlass::arch::OpMultiplyAdd;
37
- };
38
-
39
- // Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
40
- // TODO - Switch this to column major for weights since gemms should be more performant.
41
- template<typename Arch>
42
- struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
43
- static constexpr int ThreadblockK = 64;
44
- using Layout = layout::RowMajor;
45
- static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
46
- using Operator = cutlass::arch::OpMultiplyAdd;
47
- };
48
-
49
- template<typename Arch>
50
- struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
51
- static constexpr int ThreadblockK = 64;
52
- using Layout = layout::RowMajor;
53
- static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
54
- using Operator = cutlass::arch::OpMultiplyAdd;
55
- };
56
-
57
- // Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
58
- // which signals that we want to dequantize after loading from smem.
59
- template<typename Arch>
60
- struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
61
- static constexpr int ThreadblockK = 64;
62
-
63
- private:
64
- static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
65
- static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
66
-
67
- public:
68
- using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
69
- static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
70
- using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
71
- };
72
-
73
- template<typename Arch>
74
- struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
75
- static constexpr int ThreadblockK = 64;
76
-
77
- private:
78
- static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
79
- static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
80
-
81
- public:
82
- using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
83
- static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
84
- using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
85
- };
86
-
87
- } // namespace kernel
88
- } // namespace gemm
89
- } // namespace cutlass
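
The interleave factor above falls out of a 128-byte cache line and the fixed 64-wide K tile. A quick arithmetic check that mirrors the ElementsPerCacheLine / ColumnsInterleaved definitions in this header:

    // ElementsPerCacheLine = 128 bytes * 8 / bits-per-element; ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK (64).
    static_assert(128 * 8 / 8 / 64 == 2, "uint8_t weights: two columns interleaved per tile");
    static_assert(128 * 8 / 4 / 64 == 4, "uint4b_t weights: four columns interleaved per tile");
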
 
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h DELETED
@@ -1,106 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass_extensions/arch/mma.h"
4
- #include "cutlass_extensions/interleaved_numeric_conversion.h"
5
-
6
- namespace cutlass {
7
- namespace gemm {
8
- namespace threadblock {
9
- ////////////////////////////////////////////////////////////////////////////////
10
-
11
- // We need to distinguish here, since we want volta support. It is too much effort
12
- // to write shared memory iterators that are probably needed for volta to function
13
- // properly. As a result, we allow converters both after the LDG (for volta) and after
14
- // the LDS for Turing+.
15
- template<
16
- /// Iterator for B matrix in global memory
17
- typename IteratorB,
18
- /// Warp level Mma
19
- typename MmaOperator,
20
- /// Math operation perform by warp level operator
21
- typename MathOperator>
22
- struct SetConverters {
23
- };
24
-
25
- // Dequantize after LDG, so set transforms accordingly
26
- template<
27
- /// Iterator for B matrix in global memory
28
- typename IteratorB,
29
- /// Mma Policy
30
- typename MmaOperator>
31
- struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
32
- using TransformAfterLDG =
33
- FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
34
- typename IteratorB::Element,
35
- IteratorB::Fragment::kElements>;
36
-
37
- using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
38
- typename MmaOperator::ArchMmaOperator::ElementB,
39
- MmaOperator::FragmentB::kElements>;
40
- };
41
-
42
- // Dequantize after LDS, so set transforms accordingly
43
-
44
- template<
45
- /// Iterator for B matrix in global memory
46
- typename IteratorB,
47
- /// Mma Policy
48
- typename MmaOperator>
49
- struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> {
50
- using TransformAfterLDG =
51
- NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;
52
-
53
- using TransformAfterLDS =
54
- FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
55
- typename TransformAfterLDG::result_type::Element,
56
- MmaOperator::FragmentB::kElements>;
57
- };
58
-
59
- ////////////////////////////////////////////////////////////////////////////////
60
-
61
- template<
62
- /// Element type for A matrix operand
63
- typename ElementA_,
64
- /// Layout type for A matrix operand
65
- typename LayoutA_,
66
- /// Access granularity of A matrix in units of elements
67
- int kAlignmentA,
68
- /// Element type for B matrix operand
69
- typename ElementB_,
70
- /// Layout type for B matrix operand
71
- typename LayoutB_,
72
- /// Access granularity of B matrix in units of elements
73
- int kAlignmentB,
74
- /// Element type for the input scale
75
- typename ElementScale_,
76
- /// Layout for the scale operand
77
- typename LayoutScale_,
78
- /// Access granularity of Scales in unit of elements
79
- int kAlignmentScale,
80
- /// Element type for internal accumulation
81
- typename ElementAccumulator_,
82
- /// Layout type for C and D matrix operands
83
- typename LayoutC_,
84
- /// Operator class tag
85
- typename OperatorClass_,
86
- /// Tag indicating architecture to tune for
87
- typename ArchTag_,
88
- /// Threadblock-level tile size (concept: GemmShape)
89
- typename ThreadblockShape_,
90
- /// Warp-level tile size (concept: GemmShape)
91
- typename WarpShape_,
92
- /// Instruction-level tile size (concept: GemmShape)
93
- typename InstructionShape_,
94
- /// Number of stages used in the pipelined mainloop
95
- int Stages,
96
- /// Operation performed by GEMM
97
- typename Operator_,
98
- /// Use zfill or predicate for out-of-bound cp.async
99
- SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
100
- ///
101
- typename Enable = void>
102
- struct DqMma;
103
-
104
- } // namespace threadblock
105
- } // namespace gemm
106
- } // namespace cutlass
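
Downstream, the selected TransformAfterLDS is what turns a fragment of packed integer weights into the fp16/bf16 fragment the tensor-core MMA consumes. A hedged sketch of that conversion in isolation, assuming the converter exposes the usual NumericArrayConverter-style call operator and taking a fragment width of 16 purely as an example:

    #include <cstdint>

    #include "cutlass/array.h"
    #include "cutlass_extensions/interleaved_numeric_conversion.h"

    // Sketch only: convert 16 packed uint8 weights to half precision in registers,
    // the way the after-LDS transform would on Turing+.
    __device__ void convert_weights_example(cutlass::Array<uint8_t, 16> const& src,
                                            cutlass::Array<cutlass::half_t, 16>& dst) {
        cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t, uint8_t, 16> converter;
        dst = converter(src);
    }
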
 
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h DELETED
@@ -1,346 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass/gemm/threadblock/default_mma.h"
4
- #include "cutlass_extensions/arch/mma.h"
5
-
6
- #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
7
- #include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
8
- #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
9
- #include "cutlass_extensions/tile_interleaved_layout.h"
10
-
11
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
12
-
13
- namespace cutlass {
14
- namespace gemm {
15
- namespace threadblock {
16
-
17
- ////////////////////////////////////////////////////////////////////////////////
18
-
19
- template<
20
- /// Type for elementA
21
- typename ElementA,
22
- /// Layout type for A matrix operand
23
- typename LayoutA,
24
- /// Access granularity of A matrix in units of elements
25
- int kAlignmentA,
26
- /// Type for element B
27
- typename ElementB,
28
- /// Layout type for B matrix operand
29
- typename LayoutB,
30
- /// Access granularity of B matrix in units of elements
31
- int kAlignmentB,
32
- /// Element type for the input scale
33
- typename ElementScale,
34
- /// Layout for the scale operand
35
- typename LayoutScale,
36
- /// Access granularity of Scales in unit of elements
37
- int kAlignmentScale,
38
- /// Element type for internal accumulation
39
- typename ElementAccumulator,
40
- /// Operator class tag
41
- typename OperatorClass,
42
- /// Tag indicating architecture to tune for
43
- typename ArchTag,
44
- /// Threadblock-level tile size (concept: GemmShape)
45
- typename ThreadblockShape,
46
- /// Warp-level tile size (concept: GemmShape)
47
- typename WarpShape,
48
- /// Instruction-level tile size (concept: GemmShape)
49
- typename InstructionShape,
50
- /// Stages in GEMM
51
- int kStages,
52
- ///
53
- typename Operator,
54
- ///
55
- SharedMemoryClearOption SharedMemoryClear>
56
- struct DqMma<ElementA,
57
- LayoutA,
58
- kAlignmentA,
59
- ElementB,
60
- LayoutB,
61
- kAlignmentB,
62
- ElementScale,
63
- LayoutScale,
64
- kAlignmentScale,
65
- ElementAccumulator,
66
- layout::RowMajor,
67
- OperatorClass,
68
- ArchTag,
69
- ThreadblockShape,
70
- WarpShape,
71
- InstructionShape,
72
- kStages,
73
- Operator,
74
- SharedMemoryClear,
75
- typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
76
-
77
- static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
78
- "Element A must be fp16 or bf16");
79
-
80
- static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
81
- "Mma multistage must dequantize after ldsm");
82
-
83
- static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
84
- "Element B must be uint8 or uint4");
85
-
86
- static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
87
- cutlass::arch::CacheOperation::Global :
88
- cutlass::arch::CacheOperation::Always;
89
-
90
- static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
91
- cutlass::arch::CacheOperation::Global :
92
- cutlass::arch::CacheOperation::Always;
93
-
94
- // Define the MmaCore components
95
- // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
96
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
97
- WarpShape,
98
- InstructionShape,
99
- ElementA,
100
- LayoutA,
101
- ElementB,
102
- LayoutB,
103
- ElementAccumulator,
104
- layout::RowMajor,
105
- OperatorClass,
106
- std::max(kStages, 3),
107
- Operator,
108
- false,
109
- CacheOpA,
110
- CacheOpB>;
111
-
112
- // Define iterators over tiles from the A operand
113
- using ThreadMapA = typename MmaCore::IteratorThreadMapA;
114
- using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
115
- using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
116
- cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
117
- ElementA,
118
- LayoutA,
119
- 1,
120
- ThreadMapA,
121
- AccessTypeA>;
122
-
123
- // Define iterators over tiles from the B operand
124
- using ThreadMapB = typename MmaCore::IteratorThreadMapB;
125
- using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
126
- using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
127
- cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
128
- ElementB,
129
- LayoutB,
130
- 0,
131
- ThreadMapB,
132
- AccessTypeB>;
133
-
134
- // ThreadMap for scale iterator
135
- static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
136
- using IteratorScaleThreadMap =
137
- transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
138
- MmaCore::Shape::kN / kAlignmentScale,
139
- kAlignmentScale>;
140
-
141
- // Define iterators over tiles from the scale operand
142
- using IteratorScale =
143
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
144
- ElementScale,
145
- LayoutScale,
146
- 0,
147
- IteratorScaleThreadMap,
148
- kAlignmentScale>;
149
-
150
- using SmemIteratorScale = IteratorScale;
151
-
152
- using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
153
- ElementB,
154
- MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
155
-
156
- // Define the threadblock-scoped pipelined matrix multiply
157
- using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
158
- IteratorA,
159
- typename MmaCore::SmemIteratorA,
160
- MmaCore::kCacheOpA,
161
- IteratorB,
162
- typename MmaCore::SmemIteratorB,
163
- MmaCore::kCacheOpB,
164
- IteratorScale,
165
- SmemIteratorScale,
166
- ElementAccumulator,
167
- layout::RowMajor,
168
- typename MmaCore::MmaPolicy,
169
- kStages,
170
- Converter,
171
- SharedMemoryClear>;
172
- };
173
-
174
- template<
175
- /// Type for element A
176
- typename ElementA,
177
- /// Layout type for A matrix operand
178
- typename LayoutA,
179
- /// Access granularity of A matrix in units of elements
180
- int kAlignmentA,
181
- /// Type for element B
182
- typename ElementB,
183
- /// Access granularity of B matrix in units of elements
184
- int kAlignmentB,
185
- /// Element type for the input scale
186
- typename ElementScale,
187
- /// Layout for the scale operand
188
- typename LayoutScale,
189
- /// Access granularity of Scales in unit of elements
190
- int kAlignmentScale,
191
- /// Element type for internal accumulation
192
- typename ElementAccumulator,
193
- /// Operator class tag
194
- typename OperatorClass,
195
- /// Tag indicating architecture to tune for
196
- typename ArchTag,
197
- /// Threadblock-level tile size (concept: GemmShape)
198
- typename ThreadblockShape,
199
- /// Warp-level tile size (concept: GemmShape)
200
- typename WarpShape,
201
- /// Instruction-level tile size (concept: GemmShape)
202
- typename InstructionShape,
203
- /// Stages in GEMM
204
- int kStages,
205
- ///
206
- typename Operator,
207
- ///
208
- SharedMemoryClearOption SharedMemoryClear,
209
- ///
210
- int RowsPerTile,
211
- ///
212
- int ColumnsInterleaved>
213
- struct DqMma<ElementA,
214
- LayoutA,
215
- kAlignmentA,
216
- ElementB,
217
- layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
218
- kAlignmentB,
219
- ElementScale,
220
- LayoutScale,
221
- kAlignmentScale,
222
- ElementAccumulator,
223
- layout::RowMajor,
224
- OperatorClass,
225
- ArchTag,
226
- ThreadblockShape,
227
- WarpShape,
228
- InstructionShape,
229
- kStages,
230
- Operator,
231
- SharedMemoryClear,
232
- typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
233
-
234
- static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
235
- "Element A must be fp16 or bf16");
236
-
237
- static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
238
- "Mma multistage must dequantize after ldsm");
239
-
240
- static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
241
- "Element B must be uint8 or uint4");
242
-
243
- static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
244
- cutlass::arch::CacheOperation::Global :
245
- cutlass::arch::CacheOperation::Always;
246
-
247
- static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
248
- cutlass::arch::CacheOperation::Global :
249
- cutlass::arch::CacheOperation::Always;
250
-
251
- // Define the MmaCore components
252
- // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
253
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
254
- WarpShape,
255
- InstructionShape,
256
- ElementA,
257
- LayoutA,
258
- ElementB,
259
- layout::ColumnMajor,
260
- ElementAccumulator,
261
- layout::RowMajor,
262
- OperatorClass,
263
- std::max(kStages, 3),
264
- Operator,
265
- false,
266
- CacheOpA,
267
- CacheOpB>;
268
-
269
- // Define iterators over tiles from the A operand
270
- using ThreadMapA = typename MmaCore::IteratorThreadMapA;
271
- using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
272
- using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
273
- cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
274
- ElementA,
275
- LayoutA,
276
- 1,
277
- ThreadMapA,
278
- AccessTypeA>;
279
-
280
- private:
281
- static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
282
- static_assert(RowsPerTile == MmaCore::Shape::kK, "");
283
-
284
- using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
285
- using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
286
- static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
287
-
288
- using GmemIteratorShape =
289
- MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
290
- using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
291
- layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
292
- OriginalThreadMap::kThreads,
293
- layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
294
- OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
295
- MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
296
-
297
- public:
298
- // Define iterators over tiles from the B operand
299
- using ThreadMapB = typename MmaCore::IteratorThreadMapB;
300
- using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
301
- using IteratorB = cutlass::transform::threadblock::
302
- PredicatedTileAccessIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
303
-
304
- // ThreadMap for scale iterator
305
- static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
306
- using IteratorScaleThreadMap =
307
- transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
308
- MmaCore::Shape::kN / kAlignmentScale,
309
- kAlignmentScale>;
310
-
311
- // Define iterators over tiles from the scale operand
312
- using IteratorScale =
313
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
314
- ElementScale,
315
- LayoutScale,
316
- 0,
317
- IteratorScaleThreadMap,
318
- kAlignmentScale>;
319
-
320
- using SmemIteratorScale = IteratorScale;
321
-
322
- using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
323
- ElementB,
324
- MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
325
-
326
- // Define the threadblock-scoped pipelined matrix multiply
327
- using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
328
- IteratorA,
329
- typename MmaCore::SmemIteratorA,
330
- MmaCore::kCacheOpA,
331
- IteratorB,
332
- typename MmaCore::SmemIteratorB,
333
- MmaCore::kCacheOpB,
334
- IteratorScale,
335
- SmemIteratorScale,
336
- ElementAccumulator,
337
- layout::RowMajor,
338
- typename MmaCore::MmaPolicy,
339
- kStages,
340
- Converter,
341
- SharedMemoryClear>;
342
- };
343
-
344
- } // namespace threadblock
345
- } // namespace gemm
346
- } // namespace cutlass
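
For context on what the removed multistage traits above were doing: DqMma fuses weight dequantization into the GEMM mainloop, so uint8/uint4 weights are expanded to the fp16/bf16 activation type and scaled per output column just before the tensor-core MMA. The standalone sketch below restates that reference math on the host; the function name and shapes are hypothetical, and it ignores the zero-point/bias handling done by FastInterleavedAndBiasedNumericArrayConverter.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of the fused dequantize-then-multiply performed per tile by the
// removed DqMma mainloop: B_fp(k, n) = scale[n] * float(B_q(k, n)), accumulated in fp32.
std::vector<float> dequant_gemm_ref(const std::vector<float>&        A,      // M x K activations
                                    const std::vector<std::uint8_t>& B_q,    // K x N quantized weights
                                    const std::vector<float>&        scale,  // per-column (N) scales
                                    int M, int N, int K)
{
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                float b = scale[n] * static_cast<float>(B_q[static_cast<std::size_t>(k) * N + n]);
                acc += A[static_cast<std::size_t>(m) * K + k] * b;
            }
            C[static_cast<std::size_t>(m) * N + n] = acc;
        }
    }
    return C;
}
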
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h DELETED
@@ -1,315 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass/gemm/threadblock/default_mma.h"
4
- #include "cutlass_extensions/arch/mma.h"
5
-
6
- #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
7
- #include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
8
- #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
9
- #include "cutlass_extensions/tile_interleaved_layout.h"
10
-
11
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
12
-
13
- namespace cutlass {
14
- namespace gemm {
15
- namespace threadblock {
16
-
17
- ////////////////////////////////////////////////////////////////////////////////
18
-
19
- template<
20
- /// Type for element A
21
- typename ElementA,
22
- /// Layout type for A matrix operand
23
- typename LayoutA,
24
- /// Access granularity of A matrix in units of elements
25
- int kAlignmentA,
26
- /// Type for element B
27
- typename ElementB,
28
- /// Layout type for B matrix operand
29
- typename LayoutB,
30
- /// Access granularity of B matrix in units of elements
31
- int kAlignmentB,
32
- /// Element type for the input scale
33
- typename ElementScale,
34
- /// Layout for the scale operand
35
- typename LayoutScale,
36
- /// Access granularity of Scales in unit of elements
37
- int kAlignmentScale,
38
- /// Element type for internal accumulation
39
- typename ElementAccumulator,
40
- /// Operator class tag
41
- typename OperatorClass,
42
- /// Tag indicating architecture to tune for
43
- typename ArchTag,
44
- /// Threadblock-level tile size (concept: GemmShape)
45
- typename ThreadblockShape,
46
- /// Warp-level tile size (concept: GemmShape)
47
- typename WarpShape,
48
- /// Instruction-level tile size (concept: GemmShape)
49
- typename InstructionShape,
50
- /// Operation performed by GEMM
51
- typename Operator>
52
- struct DqMma<ElementA,
53
- LayoutA,
54
- kAlignmentA,
55
- ElementB,
56
- LayoutB,
57
- kAlignmentB,
58
- ElementScale,
59
- LayoutScale,
60
- kAlignmentScale,
61
- ElementAccumulator,
62
- layout::RowMajor,
63
- OperatorClass,
64
- ArchTag,
65
- ThreadblockShape,
66
- WarpShape,
67
- InstructionShape,
68
- 2,
69
- Operator,
70
- SharedMemoryClearOption::kNone,
71
- typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
72
-
73
- static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
74
- "Element A must be fp16 or bf16");
75
-
76
- static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
77
- "Element B must be uint8 or uint4");
78
-
79
- static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
80
- static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
81
- using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
82
- using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
83
-
84
- // Define the MmaCore components
85
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
86
- WarpShape,
87
- InstructionShape,
88
- MmaCoreElementA,
89
- LayoutA,
90
- MmaCoreElementB,
91
- LayoutB,
92
- ElementAccumulator,
93
- layout::RowMajor,
94
- OperatorClass,
95
- 2,
96
- Operator>;
97
-
98
- // Define iterators over tiles from the A operand
99
- using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
100
- cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
101
- ElementA,
102
- LayoutA,
103
- 1,
104
- typename MmaCore::IteratorThreadMapA,
105
- kAlignmentA>;
106
-
107
- // Define iterators over tiles from the B operand
108
- using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
109
- cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
110
- ElementB,
111
- LayoutB,
112
- 0,
113
- typename MmaCore::IteratorThreadMapB,
114
- kAlignmentB>;
115
-
116
- // ThreadMap for scale iterator
117
- static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
118
- using IteratorScaleThreadMap =
119
- transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
120
- MmaCore::Shape::kN / kAlignmentScale,
121
- kAlignmentScale>;
122
-
123
- // Define iterators over tiles from the scale operand
124
- using IteratorScale =
125
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
126
- ElementScale,
127
- LayoutScale,
128
- 0,
129
- IteratorScaleThreadMap,
130
- kAlignmentScale>;
131
-
132
- using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
133
- using SmemIteratorScale =
134
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
135
- SmemScaleType,
136
- LayoutScale,
137
- 0,
138
- IteratorScaleThreadMap,
139
- kAlignmentScale>;
140
-
141
- using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
142
-
143
- // Define the threadblock-scoped pipelined matrix multiply
144
- using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
145
- IteratorA,
146
- typename MmaCore::SmemIteratorA,
147
- IteratorB,
148
- typename MmaCore::SmemIteratorB,
149
- IteratorScale,
150
- SmemIteratorScale,
151
- ElementAccumulator,
152
- layout::RowMajor,
153
- typename MmaCore::MmaPolicy,
154
- typename Converters::TransformAfterLDG,
155
- typename Converters::TransformAfterLDS>;
156
- };
157
-
158
- // Specialization to handle column major interleave B
159
- template<
160
- /// Type for element A
161
- typename ElementA,
162
- /// Layout type for A matrix operand
163
- typename LayoutA,
164
- /// Access granularity of A matrix in units of elements
165
- int kAlignmentA,
166
- /// Type for element B
167
- typename ElementB,
168
- /// Access granularity of B matrix in units of elements
169
- int kAlignmentB,
170
- /// Element type for the input scale
171
- typename ElementScale,
172
- /// Layout for the scale operand
173
- typename LayoutScale,
174
- /// Access granularity of Scales in unit of elements
175
- int kAlignmentScale,
176
- /// Element type for internal accumulation
177
- typename ElementAccumulator,
178
- /// Operator class tag
179
- typename OperatorClass,
180
- /// Tag indicating architecture to tune for
181
- typename ArchTag,
182
- /// Threadblock-level tile size (concept: GemmShape)
183
- typename ThreadblockShape,
184
- /// Warp-level tile size (concept: GemmShape)
185
- typename WarpShape,
186
- /// Instruction-level tile size (concept: GemmShape)
187
- typename InstructionShape,
188
- /// Operation performed by GEMM
189
- typename Operator,
190
- ///
191
- int RowsPerTile,
192
- ///
193
- int ColumnsInterleaved>
194
- struct DqMma<ElementA,
195
- LayoutA,
196
- kAlignmentA,
197
- ElementB,
198
- layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
199
- kAlignmentB,
200
- ElementScale,
201
- LayoutScale,
202
- kAlignmentScale,
203
- ElementAccumulator,
204
- layout::RowMajor,
205
- OperatorClass,
206
- ArchTag,
207
- ThreadblockShape,
208
- WarpShape,
209
- InstructionShape,
210
- 2,
211
- Operator,
212
- SharedMemoryClearOption::kNone,
213
- typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
214
-
215
- static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
216
- "Element A must be fp16 or bf16");
217
-
218
- static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
219
- "Element B must be uint8 or uint4");
220
-
221
- static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
222
- static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
223
- using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
224
- using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
225
-
226
- // Define the MmaCore components
227
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
228
- WarpShape,
229
- InstructionShape,
230
- MmaCoreElementA,
231
- LayoutA,
232
- MmaCoreElementB,
233
- layout::ColumnMajor,
234
- ElementAccumulator,
235
- layout::RowMajor,
236
- OperatorClass,
237
- 2,
238
- Operator>;
239
-
240
- // Define iterators over tiles from the A operand
241
- using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
242
- cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
243
- ElementA,
244
- LayoutA,
245
- 1,
246
- typename MmaCore::IteratorThreadMapA,
247
- kAlignmentA>;
248
-
249
- private:
250
- static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
251
- static_assert(RowsPerTile == MmaCore::Shape::kK, "");
252
-
253
- using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
254
- using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
255
- static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
256
-
257
- using GmemIteratorShape =
258
- MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
259
- using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
260
- layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
261
- OriginalThreadMap::kThreads,
262
- layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
263
- OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
264
- MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
265
-
266
- public:
267
- // Define iterators over tiles from the B operand
268
- using IteratorB = cutlass::transform::threadblock::
269
- PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
270
-
271
- // ThreadMap for scale iterator
272
- static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
273
- using IteratorScaleThreadMap =
274
- transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
275
- MmaCore::Shape::kN / kAlignmentScale,
276
- kAlignmentScale>;
277
-
278
- // Define iterators over tiles from the scale operand
279
- using IteratorScale =
280
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
281
- ElementScale,
282
- LayoutScale,
283
- 0,
284
- IteratorScaleThreadMap,
285
- kAlignmentScale>;
286
-
287
- using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
288
- using SmemIteratorScale =
289
- cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
290
- SmemScaleType,
291
- LayoutScale,
292
- 0,
293
- IteratorScaleThreadMap,
294
- kAlignmentScale>;
295
-
296
- using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
297
-
298
- // Define the threadblock-scoped pipelined matrix multiply
299
- using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
300
- IteratorA,
301
- typename MmaCore::SmemIteratorA,
302
- IteratorB,
303
- typename MmaCore::SmemIteratorB,
304
- IteratorScale,
305
- SmemIteratorScale,
306
- ElementAccumulator,
307
- layout::RowMajor,
308
- typename MmaCore::MmaPolicy,
309
- typename Converters::TransformAfterLDG,
310
- typename Converters::TransformAfterLDS>;
311
- };
312
-
313
- } // namespace threadblock
314
- } // namespace gemm
315
- } // namespace cutlass
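
The pipelined (2-stage, pre-SM80) specializations above differ from the multistage ones mainly in how they pick the element types staged in shared memory: without bf16 tensor-core support, A is stored as half, and B is stored either already converted (when the plain multiply-add operator is used, i.e. dequantize after LDG) or still packed (dequantize after LDS). A minimal restatement of that selection, with hypothetical stand-in types instead of the real CUTLASS tags:

#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for the CUTLASS tags the removed code refers to.
struct half_t {};
struct bfloat16_t {};
struct OpMultiplyAdd {};                           // plain FMA: dequantize right after the global load
struct OpMultiplyAddDequantizeInterleavedBToA {};  // dequantize after the shared-memory load

// Mirrors the operand-type selection of the removed 2-stage (pre-SM80) DqMma specializations:
// pre-Ampere tensor cores have no bf16 MMA, so A is staged in shared memory as half_t, and B is
// staged either as the (already converted) A type or as the raw quantized type.
template <typename ElementA, typename ElementB, typename Operator, int kMinComputeCapability>
struct PipelinedOperandTypes {
    static constexpr bool dq_after_ldg      = std::is_same<OpMultiplyAdd, Operator>::value;
    static constexpr bool arch_has_bf16_mma = kMinComputeCapability >= 80;

    using MmaCoreElementA = typename std::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
    using MmaCoreElementB = typename std::conditional<dq_after_ldg, MmaCoreElementA, ElementB>::type;
};

// Example: SM75, bf16 activations, uint8 weights dequantized after LDS -> A staged as half_t,
// B kept as uint8_t until the warp-level converter runs.
using Sm75Bf16Int8 = PipelinedOperandTypes<bfloat16_t, std::uint8_t, OpMultiplyAddDequantizeInterleavedBToA, 75>;
static_assert(std::is_same<Sm75Bf16Int8::MmaCoreElementA, half_t>::value, "");
static_assert(std::is_same<Sm75Bf16Int8::MmaCoreElementB, std::uint8_t>::value, "");
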
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h DELETED
@@ -1,426 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
4
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
5
- #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
6
-
7
- namespace cutlass {
8
- namespace gemm {
9
- namespace threadblock {
10
-
11
- ////////////////////////////////////////////////////////////////////////////////
12
-
13
- /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
14
- template<
15
- /// Layout type for A matrix operand
16
- typename LayoutA,
17
- /// Access granularity of A matrix in units of elements
18
- int kAlignmentA,
19
- /// Layout type for B matrix operand
20
- typename LayoutB,
21
- /// Access granularity of B matrix in units of elements
22
- int kAlignmentB,
23
- /// Element type for internal accumulation
24
- typename ElementAccumulator,
25
- /// Tag indicating architecture to tune for
26
- typename ArchTag,
27
- /// Threadblock-level tile size (concept: GemmShape)
28
- typename ThreadblockShape,
29
- /// Warp-level tile size (concept: GemmShape)
30
- typename WarpShape,
31
- /// Instruction-level tile size (concept: GemmShape)
32
- typename InstructionShape,
33
- /// Operation performed by GEMM
34
- typename Operator>
35
- struct DefaultMma<cutlass::half_t,
36
- LayoutA,
37
- kAlignmentA,
38
- uint8_t,
39
- LayoutB,
40
- kAlignmentB,
41
- ElementAccumulator,
42
- layout::RowMajor,
43
- arch::OpClassTensorOp,
44
- ArchTag,
45
- ThreadblockShape,
46
- WarpShape,
47
- InstructionShape,
48
- 2,
49
- Operator> {
50
-
51
- private:
52
- static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
53
-
54
- using Mma = DqMma<half_t,
55
- LayoutA,
56
- kAlignmentA,
57
- uint8_t,
58
- LayoutB,
59
- kAlignmentB,
60
- half_t,
61
- layout::RowMajor,
62
- kAlignmentScale,
63
- ElementAccumulator,
64
- layout::RowMajor,
65
- arch::OpClassTensorOp,
66
- ArchTag,
67
- ThreadblockShape,
68
- WarpShape,
69
- InstructionShape,
70
- 2,
71
- Operator>;
72
-
73
- public:
74
- // Define the MmaCore components
75
- using MmaCore = typename Mma::MmaCore;
76
-
77
- // Define iterators over tiles from the A operand
78
- using IteratorA = typename Mma::IteratorA;
79
-
80
- // Define iterators over tiles from the B operand
81
- using IteratorB = typename Mma::IteratorB;
82
-
83
- // Define the threadblock-scoped pipelined matrix multiply
84
- using ThreadblockMma = typename Mma::ThreadblockMma;
85
- };
86
-
87
- ////////////////////////////////////////////////////////////////////////////////
88
- /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
89
- template<
90
- /// Layout type for A matrix operand
91
- typename LayoutA,
92
- /// Access granularity of A matrix in units of elements
93
- int kAlignmentA,
94
- /// Layout type for B matrix operand
95
- typename LayoutB,
96
- /// Access granularity of B matrix in units of elements
97
- int kAlignmentB,
98
- /// Element type for internal accumulation
99
- typename ElementAccumulator,
100
- /// Tag indicating architecture to tune for
101
- typename ArchTag,
102
- /// Threadblock-level tile size (concept: GemmShape)
103
- typename ThreadblockShape,
104
- /// Warp-level tile size (concept: GemmShape)
105
- typename WarpShape,
106
- /// Instruction-level tile size (concept: GemmShape)
107
- typename InstructionShape,
108
- /// Operation performed by GEMM
109
- typename Operator>
110
- struct DefaultMma<cutlass::half_t,
111
- LayoutA,
112
- kAlignmentA,
113
- uint4b_t,
114
- LayoutB,
115
- kAlignmentB,
116
- ElementAccumulator,
117
- layout::RowMajor,
118
- arch::OpClassTensorOp,
119
- ArchTag,
120
- ThreadblockShape,
121
- WarpShape,
122
- InstructionShape,
123
- 2,
124
- Operator> {
125
-
126
- private:
127
- static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
128
-
129
- using Mma = DqMma<half_t,
130
- LayoutA,
131
- kAlignmentA,
132
- uint4b_t,
133
- LayoutB,
134
- kAlignmentB,
135
- half_t,
136
- layout::RowMajor,
137
- kAlignmentScale,
138
- ElementAccumulator,
139
- layout::RowMajor,
140
- arch::OpClassTensorOp,
141
- ArchTag,
142
- ThreadblockShape,
143
- WarpShape,
144
- InstructionShape,
145
- 2,
146
- Operator>;
147
-
148
- public:
149
- // Define the MmaCore components
150
- using MmaCore = typename Mma::MmaCore;
151
-
152
- // Define iterators over tiles from the A operand
153
- using IteratorA = typename Mma::IteratorA;
154
-
155
- // Define iterators over tiles from the B operand
156
- using IteratorB = typename Mma::IteratorB;
157
-
158
- // Define the threadblock-scoped pipelined matrix multiply
159
- using ThreadblockMma = typename Mma::ThreadblockMma;
160
- };
161
-
162
- template<
163
- /// Layout type for A matrix operand
164
- typename LayoutA,
165
- /// Access granularity of A matrix in units of elements
166
- int kAlignmentA,
167
- /// Layout type for B matrix operand
168
- typename LayoutB,
169
- /// Access granularity of B matrix in units of elements
170
- int kAlignmentB,
171
- /// Element type for internal accumulation
172
- typename ElementAccumulator,
173
- /// Tag indicating architecture to tune for
174
- typename ArchTag,
175
- /// Threadblock-level tile size (concept: GemmShape)
176
- typename ThreadblockShape,
177
- /// Warp-level tile size (concept: GemmShape)
178
- typename WarpShape,
179
- /// Instruction-level tile size (concept: GemmShape)
180
- typename InstructionShape,
181
- /// Operation performed by GEMM
182
- typename Operator,
183
- ///
184
- int kStages,
185
- /// Shared memory clear option
186
- SharedMemoryClearOption SharedMemoryClear>
187
- struct DefaultMma<cutlass::half_t,
188
- LayoutA,
189
- kAlignmentA,
190
- uint8_t,
191
- LayoutB,
192
- kAlignmentB,
193
- ElementAccumulator,
194
- layout::RowMajor,
195
- arch::OpClassTensorOp,
196
- ArchTag,
197
- ThreadblockShape,
198
- WarpShape,
199
- InstructionShape,
200
- kStages,
201
- Operator,
202
- false,
203
- SharedMemoryClear> {
204
-
205
- private:
206
- static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
207
-
208
- using Mma = DqMma<half_t,
209
- LayoutA,
210
- kAlignmentA,
211
- uint8_t,
212
- LayoutB,
213
- kAlignmentB,
214
- half_t,
215
- layout::RowMajor,
216
- kAlignmentScale,
217
- ElementAccumulator,
218
- layout::RowMajor,
219
- arch::OpClassTensorOp,
220
- ArchTag,
221
- ThreadblockShape,
222
- WarpShape,
223
- InstructionShape,
224
- kStages,
225
- Operator,
226
- SharedMemoryClear>;
227
-
228
- public:
229
- // Define the MmaCore components
230
- using MmaCore = typename Mma::MmaCore;
231
-
232
- // Define iterators over tiles from the A operand
233
- using IteratorA = typename Mma::IteratorA;
234
-
235
- // Define iterators over tiles from the B operand
236
- using IteratorB = typename Mma::IteratorB;
237
-
238
- // Define the threadblock-scoped pipelined matrix multiply
239
- using ThreadblockMma = typename Mma::ThreadblockMma;
240
- };
241
-
242
- ////////////////////////////////////////////////////////////////////////////////
243
- /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
244
- template<
245
- /// Layout type for A matrix operand
246
- typename LayoutA,
247
- /// Access granularity of A matrix in units of elements
248
- int kAlignmentA,
249
- /// Layout type for B matrix operand
250
- typename LayoutB,
251
- /// Access granularity of B matrix in units of elements
252
- int kAlignmentB,
253
- /// Element type for internal accumulation
254
- typename ElementAccumulator,
255
- /// Tag indicating architecture to tune for
256
- typename ArchTag,
257
- /// Threadblock-level tile size (concept: GemmShape)
258
- typename ThreadblockShape,
259
- /// Warp-level tile size (concept: GemmShape)
260
- typename WarpShape,
261
- /// Instruction-level tile size (concept: GemmShape)
262
- typename InstructionShape,
263
- /// Operation performed by GEMM
264
- typename Operator,
265
- ///
266
- int kStages,
267
- /// Shared memory clear option
268
- SharedMemoryClearOption SharedMemoryClear>
269
- struct DefaultMma<cutlass::half_t,
270
- LayoutA,
271
- kAlignmentA,
272
- uint4b_t,
273
- LayoutB,
274
- kAlignmentB,
275
- ElementAccumulator,
276
- layout::RowMajor,
277
- arch::OpClassTensorOp,
278
- ArchTag,
279
- ThreadblockShape,
280
- WarpShape,
281
- InstructionShape,
282
- kStages,
283
- Operator,
284
- false,
285
- SharedMemoryClear> {
286
-
287
- private:
288
- static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
289
-
290
- using Mma = DqMma<half_t,
291
- LayoutA,
292
- kAlignmentA,
293
- uint4b_t,
294
- LayoutB,
295
- kAlignmentB,
296
- half_t,
297
- layout::RowMajor,
298
- kAlignmentScale,
299
- ElementAccumulator,
300
- layout::RowMajor,
301
- arch::OpClassTensorOp,
302
- ArchTag,
303
- ThreadblockShape,
304
- WarpShape,
305
- InstructionShape,
306
- kStages,
307
- Operator,
308
- SharedMemoryClear>;
309
-
310
- public:
311
- // Define the MmaCore components
312
- using MmaCore = typename Mma::MmaCore;
313
-
314
- // Define iterators over tiles from the A operand
315
- using IteratorA = typename Mma::IteratorA;
316
-
317
- // Define iterators over tiles from the B operand
318
- using IteratorB = typename Mma::IteratorB;
319
-
320
- // Define the threadblock-scoped pipelined matrix multiply
321
- using ThreadblockMma = typename Mma::ThreadblockMma;
322
- };
323
-
324
- // fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
325
- // large tile when not enough shared mem is present to do 3+ stage
326
- template<
327
- /// Layout type for A matrix operand
328
- typename LayoutA,
329
- /// Access granularity of A matrix in units of elements
330
- int kAlignmentA,
331
- /// Layout type for B matrix operand
332
- typename LayoutB,
333
- /// Access granularity of B matrix in units of elements
334
- int kAlignmentB,
335
- /// Element type for internal accumulation
336
- typename ElementAccumulator,
337
- /// Threadblock-level tile size (concept: GemmShape)
338
- typename ThreadblockShape,
339
- /// Warp-level tile size (concept: GemmShape)
340
- typename WarpShape,
341
- /// Instruction-level tile size (concept: GemmShape)
342
- typename InstructionShape,
343
- /// Operation performed by GEMM
344
- typename Operator,
345
- /// Use zfill or predicate for out-of-bound cp.async
346
- SharedMemoryClearOption SharedMemoryClear,
347
- /// Gather operand A by using an index array
348
- bool GatherA,
349
- /// Gather operand B by using an index array
350
- bool GatherB>
351
- struct DefaultMma<half_t,
352
- LayoutA,
353
- kAlignmentA,
354
- half_t,
355
- LayoutB,
356
- kAlignmentB,
357
- ElementAccumulator,
358
- layout::RowMajor,
359
- arch::OpClassTensorOp,
360
- arch::Sm80,
361
- ThreadblockShape,
362
- WarpShape,
363
- InstructionShape,
364
- 2,
365
- Operator,
366
- false,
367
- SharedMemoryClear,
368
- GatherA,
369
- GatherB> {
370
-
371
- // Define the MmaCore components
372
- // 3 is used on purpose here to trigger components for mma multistage
373
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
374
- WarpShape,
375
- InstructionShape,
376
- half_t,
377
- LayoutA,
378
- half_t,
379
- LayoutB,
380
- ElementAccumulator,
381
- layout::RowMajor,
382
- arch::OpClassTensorOp,
383
- 3,
384
- Operator>;
385
-
386
- // Define iterators over tiles from the A operand
387
- using ThreadMapA = typename MmaCore::IteratorThreadMapA;
388
- using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
389
- using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
390
- cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
391
- half_t,
392
- LayoutA,
393
- 1,
394
- ThreadMapA,
395
- AccessTypeA,
396
- GatherA>;
397
-
398
- // Define iterators over tiles from the B operand
399
- using ThreadMapB = typename MmaCore::IteratorThreadMapB;
400
- using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
401
- using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
402
- cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
403
- half_t,
404
- LayoutB,
405
- 0,
406
- ThreadMapB,
407
- AccessTypeB,
408
- GatherB>;
409
-
410
- // Define the threadblock-scoped multistage matrix multiply
411
- using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
412
- IteratorA,
413
- typename MmaCore::SmemIteratorA,
414
- MmaCore::kCacheOpA,
415
- IteratorB,
416
- typename MmaCore::SmemIteratorB,
417
- MmaCore::kCacheOpB,
418
- ElementAccumulator,
419
- layout::RowMajor,
420
- typename MmaCore::MmaPolicy,
421
- 2>;
422
- };
423
-
424
- } // namespace threadblock
425
- } // namespace gemm
426
- } // namespace cutlass
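
One detail worth keeping in mind when reading the DefaultMma specializations above: the scale-vector alignment is always derived from a 128-bit (16-byte) vectorized access, kAlignmentScale = 128 / sizeof_bits<Element>::value. A standalone sketch of that arithmetic, using plain sizeof instead of CUTLASS's sizeof_bits (so it only covers byte-sized element types, not uint4b_t):

#include <cstdint>

// Number of elements that fit in one 16-byte (128-bit) access.
template <typename Element>
constexpr int alignment_for_128bit_access()
{
    return 128 / (8 * static_cast<int>(sizeof(Element)));
}

// half_t / bfloat16_t are 16 bits wide, so a 128-bit vector carries 8 scales.
static_assert(alignment_for_128bit_access<std::uint16_t>() == 8, "8 x 16-bit scales per 128-bit access");
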
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h DELETED
@@ -1,527 +0,0 @@
1
- #pragma once
2
-
3
- #include "cutlass/gemm/threadblock/default_mma.h"
4
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
5
- #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
6
-
7
- namespace cutlass {
8
- namespace gemm {
9
- namespace threadblock {
10
-
11
- ////////////////////////////////////////////////////////////////////////////////
12
-
13
- /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
14
- template<
15
- /// Layout type for A matrix operand
16
- typename LayoutA,
17
- /// Access granularity of A matrix in units of elements
18
- int kAlignmentA,
19
- /// Layout type for B matrix operand
20
- typename LayoutB,
21
- /// Access granularity of B matrix in units of elements
22
- int kAlignmentB,
23
- /// Element type for internal accumulation
24
- typename ElementAccumulator,
25
- /// Tag indicating architecture to tune for
26
- typename ArchTag,
27
- /// Threadblock-level tile size (concept: GemmShape)
28
- typename ThreadblockShape,
29
- /// Warp-level tile size (concept: GemmShape)
30
- typename WarpShape,
31
- /// Instruction-level tile size (concept: GemmShape)
32
- typename InstructionShape,
33
- /// Operation performed by GEMM
34
- typename Operator,
35
- /// Use zfill or predicate for out-of-bound cp.async
36
- SharedMemoryClearOption SharedMemoryClear,
37
- /// Gather operand A by using an index array
38
- bool GatherA,
39
- /// Gather operand B by using an index array
40
- bool GatherB>
41
- struct DefaultMma<bfloat16_t,
42
- LayoutA,
43
- kAlignmentA,
44
- bfloat16_t,
45
- LayoutB,
46
- kAlignmentB,
47
- ElementAccumulator,
48
- layout::RowMajor,
49
- arch::OpClassTensorOp,
50
- ArchTag,
51
- ThreadblockShape,
52
- WarpShape,
53
- InstructionShape,
54
- 2,
55
- Operator,
56
- false,
57
- SharedMemoryClear,
58
- GatherA,
59
- GatherB> {
60
-
61
- private:
62
- // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
63
- static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
64
- using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
65
- using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
66
-
67
- public:
68
- // Define the MmaCore components
69
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
70
- WarpShape,
71
- InstructionShape,
72
- MmaElementA,
73
- LayoutA,
74
- MmaElementB,
75
- LayoutB,
76
- ElementAccumulator,
77
- layout::RowMajor,
78
- arch::OpClassTensorOp,
79
- 2,
80
- Operator>;
81
-
82
- using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
83
- cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
84
- bfloat16_t,
85
- LayoutA,
86
- 1,
87
- typename MmaCore::IteratorThreadMapA,
88
- kAlignmentA,
89
- GatherA>;
90
-
91
- // Define iterators over tiles from the B operand
92
- using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
93
- cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
94
- bfloat16_t,
95
- LayoutB,
96
- 0,
97
- typename MmaCore::IteratorThreadMapB,
98
- kAlignmentB,
99
- GatherB>;
100
-
101
- // Define the threadblock-scoped pipelined matrix multiply
102
- using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
103
- IteratorA,
104
- typename MmaCore::SmemIteratorA,
105
- IteratorB,
106
- typename MmaCore::SmemIteratorB,
107
- ElementAccumulator,
108
- layout::RowMajor,
109
- typename MmaCore::MmaPolicy>;
110
- };
111
-
112
- // bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
113
- // large tile when not enough shared mem is present to do 3+ stage
114
- template<
115
- /// Layout type for A matrix operand
116
- typename LayoutA,
117
- /// Access granularity of A matrix in units of elements
118
- int kAlignmentA,
119
- /// Layout type for B matrix operand
120
- typename LayoutB,
121
- /// Access granularity of B matrix in units of elements
122
- int kAlignmentB,
123
- /// Element type for internal accumulation
124
- typename ElementAccumulator,
125
- /// Threadblock-level tile size (concept: GemmShape)
126
- typename ThreadblockShape,
127
- /// Warp-level tile size (concept: GemmShape)
128
- typename WarpShape,
129
- /// Instruction-level tile size (concept: GemmShape)
130
- typename InstructionShape,
131
- /// Operation performed by GEMM
132
- typename Operator,
133
- /// Use zfill or predicate for out-of-bound cp.async
134
- SharedMemoryClearOption SharedMemoryClear,
135
- /// Gather operand A by using an index array
136
- bool GatherA,
137
- /// Gather operand B by using an index array
138
- bool GatherB>
139
- struct DefaultMma<bfloat16_t,
140
- LayoutA,
141
- kAlignmentA,
142
- bfloat16_t,
143
- LayoutB,
144
- kAlignmentB,
145
- ElementAccumulator,
146
- layout::RowMajor,
147
- arch::OpClassTensorOp,
148
- arch::Sm80,
149
- ThreadblockShape,
150
- WarpShape,
151
- InstructionShape,
152
- 2,
153
- Operator,
154
- false,
155
- SharedMemoryClear,
156
- GatherA,
157
- GatherB> {
158
-
159
- // Define the MmaCore components
160
- // 3 is used on purpose here to trigger components for mma multistage
161
- using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
162
- WarpShape,
163
- InstructionShape,
164
- bfloat16_t,
165
- LayoutA,
166
- bfloat16_t,
167
- LayoutB,
168
- ElementAccumulator,
169
- layout::RowMajor,
170
- arch::OpClassTensorOp,
171
- 3,
172
- Operator>;
173
-
174
- // Define iterators over tiles from the A operand
175
- using ThreadMapA = typename MmaCore::IteratorThreadMapA;
176
- using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
177
- using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
178
- cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
179
- bfloat16_t,
180
- LayoutA,
181
- 1,
182
- ThreadMapA,
183
- AccessTypeA,
184
- GatherA>;
185
-
186
- // Define iterators over tiles from the B operand
187
- using ThreadMapB = typename MmaCore::IteratorThreadMapB;
188
- using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
189
- using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
190
- cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
191
- bfloat16_t,
192
- LayoutB,
193
- 0,
194
- ThreadMapB,
195
- AccessTypeB,
196
- GatherB>;
197
-
198
- // Define the threadblock-scoped multistage matrix multiply
199
- using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
200
- IteratorA,
201
- typename MmaCore::SmemIteratorA,
202
- MmaCore::kCacheOpA,
203
- IteratorB,
204
- typename MmaCore::SmemIteratorB,
205
- MmaCore::kCacheOpB,
206
- ElementAccumulator,
207
- layout::RowMajor,
208
- typename MmaCore::MmaPolicy,
209
- 2>;
210
- };
211
-
212
- ////////////////////////////////////////////////////////////////////////////////
213
-
214
- /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
215
- template<
216
- /// Layout type for A matrix operand
217
- typename LayoutA,
218
- /// Access granularity of A matrix in units of elements
219
- int kAlignmentA,
220
- /// Layout type for B matrix operand
221
- typename LayoutB,
222
- /// Access granularity of B matrix in units of elements
223
- int kAlignmentB,
224
- /// Element type for internal accumulation
225
- typename ElementAccumulator,
226
- /// Tag indicating architecture to tune for
227
- typename ArchTag,
228
- /// Threadblock-level tile size (concept: GemmShape)
229
- typename ThreadblockShape,
230
- /// Warp-level tile size (concept: GemmShape)
231
- typename WarpShape,
232
- /// Instruction-level tile size (concept: GemmShape)
233
- typename InstructionShape,
234
- /// Operation performed by GEMM
235
- typename Operator>
236
- struct DefaultMma<cutlass::bfloat16_t,
237
- LayoutA,
238
- kAlignmentA,
239
- uint8_t,
240
- LayoutB,
241
- kAlignmentB,
242
- ElementAccumulator,
243
- layout::RowMajor,
244
- arch::OpClassTensorOp,
245
- ArchTag,
246
- ThreadblockShape,
247
- WarpShape,
248
- InstructionShape,
249
- 2,
250
- Operator> {
251
-
252
- private:
253
- static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
254
-
255
- using Mma = DqMma<bfloat16_t,
256
- LayoutA,
257
- kAlignmentA,
258
- uint8_t,
259
- LayoutB,
260
- kAlignmentB,
261
- bfloat16_t,
262
- layout::RowMajor,
263
- kAlignmentScale,
264
- ElementAccumulator,
265
- layout::RowMajor,
266
- arch::OpClassTensorOp,
267
- ArchTag,
268
- ThreadblockShape,
269
- WarpShape,
270
- InstructionShape,
271
- 2,
272
- Operator>;
273
-
274
- public:
275
- // Define the MmaCore components
276
- using MmaCore = typename Mma::MmaCore;
277
-
278
- // Define iterators over tiles from the A operand
279
- using IteratorA = typename Mma::IteratorA;
280
-
281
- // Define iterators over tiles from the B operand
282
- using IteratorB = typename Mma::IteratorB;
283
-
284
- // Define the threadblock-scoped pipelined matrix multiply
285
- using ThreadblockMma = typename Mma::ThreadblockMma;
286
- };
287
-
288
- ////////////////////////////////////////////////////////////////////////////////
289
- /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
290
- template<
291
- /// Layout type for A matrix operand
292
- typename LayoutA,
293
- /// Access granularity of A matrix in units of elements
294
- int kAlignmentA,
295
- /// Layout type for B matrix operand
296
- typename LayoutB,
297
- /// Access granularity of B matrix in units of elements
298
- int kAlignmentB,
299
- /// Element type for internal accumulation
300
- typename ElementAccumulator,
301
- /// Tag indicating architecture to tune for
302
- typename ArchTag,
303
- /// Threadblock-level tile size (concept: GemmShape)
304
- typename ThreadblockShape,
305
- /// Warp-level tile size (concept: GemmShape)
306
- typename WarpShape,
307
- /// Instruction-level tile size (concept: GemmShape)
308
- typename InstructionShape,
309
- /// Operation performed by GEMM
310
- typename Operator>
311
- struct DefaultMma<cutlass::bfloat16_t,
312
- LayoutA,
313
- kAlignmentA,
314
- uint4b_t,
315
- LayoutB,
316
- kAlignmentB,
317
- ElementAccumulator,
318
- layout::RowMajor,
319
- arch::OpClassTensorOp,
320
- ArchTag,
321
- ThreadblockShape,
322
- WarpShape,
323
- InstructionShape,
324
- 2,
325
- Operator> {
326
-
327
- private:
328
- static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
329
-
330
- using Mma = DqMma<bfloat16_t,
331
- LayoutA,
332
- kAlignmentA,
333
- uint4b_t,
334
- LayoutB,
335
- kAlignmentB,
336
- bfloat16_t,
337
- layout::RowMajor,
338
- kAlignmentScale,
339
- ElementAccumulator,
340
- layout::RowMajor,
341
- arch::OpClassTensorOp,
342
- ArchTag,
343
- ThreadblockShape,
344
- WarpShape,
345
- InstructionShape,
346
- 2,
347
- Operator>;
348
-
349
- public:
350
- // Define the MmaCore components
351
- using MmaCore = typename Mma::MmaCore;
352
-
353
- // Define iterators over tiles from the A operand
354
- using IteratorA = typename Mma::IteratorA;
355
-
356
- // Define iterators over tiles from the B operand
357
- using IteratorB = typename Mma::IteratorB;
358
-
359
- // Define the threadblock-scoped pipelined matrix multiply
360
- using ThreadblockMma = typename Mma::ThreadblockMma;
361
- };
362
-
363
- template<
364
- /// Layout type for A matrix operand
365
- typename LayoutA,
366
- /// Access granularity of A matrix in units of elements
367
- int kAlignmentA,
368
- /// Layout type for B matrix operand
369
- typename LayoutB,
370
- /// Access granularity of B matrix in units of elements
371
- int kAlignmentB,
372
- /// Element type for internal accumulation
373
- typename ElementAccumulator,
374
- /// Tag indicating architecture to tune for
375
- typename ArchTag,
376
- /// Threadblock-level tile size (concept: GemmShape)
377
- typename ThreadblockShape,
378
- /// Warp-level tile size (concept: GemmShape)
379
- typename WarpShape,
380
- /// Instruction-level tile size (concept: GemmShape)
381
- typename InstructionShape,
382
- /// Operation performed by GEMM
383
- typename Operator,
384
- ///
385
- int kStages,
386
- /// Shared memory clear option
387
- SharedMemoryClearOption SharedMemoryClear>
388
- struct DefaultMma<cutlass::bfloat16_t,
389
- LayoutA,
390
- kAlignmentA,
391
- uint8_t,
392
- LayoutB,
393
- kAlignmentB,
394
- ElementAccumulator,
395
- layout::RowMajor,
396
- arch::OpClassTensorOp,
397
- ArchTag,
398
- ThreadblockShape,
399
- WarpShape,
400
- InstructionShape,
401
- kStages,
402
- Operator,
403
- false,
404
- SharedMemoryClear> {
405
-
406
- private:
407
- static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
408
-
409
- using Mma = DqMma<bfloat16_t,
410
- LayoutA,
411
- kAlignmentA,
412
- uint8_t,
413
- LayoutB,
414
- kAlignmentB,
415
- bfloat16_t,
416
- layout::RowMajor,
417
- kAlignmentScale,
418
- ElementAccumulator,
419
- layout::RowMajor,
420
- arch::OpClassTensorOp,
421
- ArchTag,
422
- ThreadblockShape,
423
- WarpShape,
424
- InstructionShape,
425
- kStages,
426
- Operator,
427
- SharedMemoryClear>;
428
-
429
- public:
430
- // Define the MmaCore components
431
- using MmaCore = typename Mma::MmaCore;
432
-
433
- // Define iterators over tiles from the A operand
434
- using IteratorA = typename Mma::IteratorA;
435
-
436
- // Define iterators over tiles from the B operand
437
- using IteratorB = typename Mma::IteratorB;
438
-
439
- // Define the threadblock-scoped pipelined matrix multiply
440
- using ThreadblockMma = typename Mma::ThreadblockMma;
441
- };
442
-
443
- ////////////////////////////////////////////////////////////////////////////////
444
- /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
445
- template<
446
- /// Layout type for A matrix operand
447
- typename LayoutA,
448
- /// Access granularity of A matrix in units of elements
449
- int kAlignmentA,
450
- /// Layout type for B matrix operand
451
- typename LayoutB,
452
- /// Access granularity of B matrix in units of elements
453
- int kAlignmentB,
454
- /// Element type for internal accumulation
455
- typename ElementAccumulator,
456
- /// Tag indicating architecture to tune for
457
- typename ArchTag,
458
- /// Threadblock-level tile size (concept: GemmShape)
459
- typename ThreadblockShape,
460
- /// Warp-level tile size (concept: GemmShape)
461
- typename WarpShape,
462
- /// Instruction-level tile size (concept: GemmShape)
463
- typename InstructionShape,
464
- /// Operation performed by GEMM
465
- typename Operator,
466
- ///
467
- int kStages,
468
- /// Shared memory clear option
469
- SharedMemoryClearOption SharedMemoryClear>
470
- struct DefaultMma<cutlass::bfloat16_t,
471
- LayoutA,
472
- kAlignmentA,
473
- uint4b_t,
474
- LayoutB,
475
- kAlignmentB,
476
- ElementAccumulator,
477
- layout::RowMajor,
478
- arch::OpClassTensorOp,
479
- ArchTag,
480
- ThreadblockShape,
481
- WarpShape,
482
- InstructionShape,
483
- kStages,
484
- Operator,
485
- false,
486
- SharedMemoryClear> {
487
-
488
- private:
489
- static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
490
-
491
- using Mma = DqMma<bfloat16_t,
492
- LayoutA,
493
- kAlignmentA,
494
- uint4b_t,
495
- LayoutB,
496
- kAlignmentB,
497
- bfloat16_t,
498
- layout::RowMajor,
499
- kAlignmentScale,
500
- ElementAccumulator,
501
- layout::RowMajor,
502
- arch::OpClassTensorOp,
503
- ArchTag,
504
- ThreadblockShape,
505
- WarpShape,
506
- InstructionShape,
507
- kStages,
508
- Operator,
509
- SharedMemoryClear>;
510
-
511
- public:
512
- // Define the MmaCore components
513
- using MmaCore = typename Mma::MmaCore;
514
-
515
- // Define iterators over tiles from the A operand
516
- using IteratorA = typename Mma::IteratorA;
517
-
518
- // Define iterators over tiles from the B operand
519
- using IteratorB = typename Mma::IteratorB;
520
-
521
- // Define the threadblock-scoped pipelined matrix multiply
522
- using ThreadblockMma = typename Mma::ThreadblockMma;
523
- };
524
-
525
- } // namespace threadblock
526
- } // namespace gemm
527
- } // namespace cutlass
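
The interleaved-B specializations in the files above reshape the B tile seen by the global-memory iterator: ColumnsInterleaved columns are stored back to back, so the iterator walks a (K * ColumnsInterleaved) x (N / ColumnsInterleaved) view instead of K x N, and the warp thread arrangement is rescaled by the same factor. The sketch below restates just that shape arithmetic; the interleaving itself is presumably produced offline by the weight preprocessing code in cutlass_preprocessors.cc, and the names here are illustrative.

// Shape of the global-memory view used by the interleaved-B iterator, as computed in the
// removed specializations.
struct InterleavedTileShape {
    int contiguous;  // rows of the pitch-linear view
    int strided;     // columns of the pitch-linear view
};

constexpr InterleavedTileShape interleaved_b_view(int tile_k, int tile_n, int columns_interleaved)
{
    return InterleavedTileShape{tile_k * columns_interleaved, tile_n / columns_interleaved};
}

// Example: a 64 x 128 B tile with 4-way column interleaving is iterated as a 256 x 32 view.
static_assert(interleaved_b_view(64, 128, 4).contiguous == 256
                  && interleaved_b_view(64, 128, 4).strided == 32,
              "interleaving multiplies the contiguous extent and divides the strided one");
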
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h DELETED
@@ -1,236 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/aligned_buffer.h"
38
- #include "cutlass/arch/memory.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/gemm/threadblock/mma_base.h"
43
- #include "cutlass/matrix_shape.h"
44
- #include "cutlass/numeric_types.h"
45
-
46
- ////////////////////////////////////////////////////////////////////////////////
47
-
48
- namespace cutlass {
49
- namespace gemm {
50
- namespace threadblock {
51
-
52
- ////////////////////////////////////////////////////////////////////////////////
53
- // SFINAE trick so I can keep the same loop code for Volta and dispatch to the
54
- // correct warp level mma. On volta, all data is stored to shared memory as FP16.
55
- template<typename WarpMma, int kExpansionFactor = 1>
56
- CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
57
- typename WarpMma::FragmentC& D,
58
- typename WarpMma::FragmentA const& A,
59
- typename WarpMma::FragmentB const& B,
60
- typename WarpMma::FragmentC const& C,
61
- const int warp_tileB_k_offset)
62
- {
63
- warp_mma(D, A, B, C);
64
- }
65
-
66
- template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
67
- CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
68
- typename WarpMma::FragmentC& D,
69
- typename WarpMma::TransformedFragmentA const& A,
70
- typename WarpMma::TransformedFragmentB const& B,
71
- typename WarpMma::FragmentC const& C,
72
- const int warp_tileB_k_offset)
73
- {
74
- warp_mma(D, A, B, C, warp_tileB_k_offset);
75
- }
76
- ////////////////////////////////////////////////////////////////////////////////
77
-
78
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math
79
- /// instructions.
80
- template<
81
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
82
- typename Shape_,
83
- /// Policy describing tuning details (concept: MmaPolicy)
84
- typename Policy_,
85
- /// The type of the scales
86
- typename ElementScale_,
87
- /// Number of stages,
88
- int Stages,
89
- /// Used for partial specialization
90
- typename Enable = bool>
91
- class DqMmaBase {
92
- public:
93
- ///< Size of the Gemm problem - concept: gemm::GemmShape<>
94
- using Shape = Shape_;
95
-
96
- ///< Policy describing tuning details
97
- using Policy = Policy_;
98
-
99
- ///< Type of the scale to be loaded
100
- using ElementScale = ElementScale_;
101
-
102
- //
103
- // Dependent types
104
- //
105
-
106
- /// Warp-level Mma
107
- using Operator = typename Policy::Operator;
108
-
109
- /// Shape describing the overall GEMM computed from shared memory
110
- /// by each warp.
111
- using WarpGemm = typename Policy::Operator::Shape;
112
-
113
- /// Shape describing the number of warps filling the CTA
114
- using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
115
-
116
- /// Number of warp-level GEMM operations
117
- static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
118
-
119
- static constexpr int kNumKIterationsPerWarpBLoad =
120
- Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
121
-
122
- static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
123
- static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
124
-
125
- /// Number of stages
126
- static int const kStages = Stages;
127
-
128
- /// Tensor reference to the A operand
129
- using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
130
-
131
- /// Tensor reference to the B operand
132
- using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
133
-
134
- //
135
- // Nested structs
136
- //
137
-
138
- /// Shared storage object needed by threadblock-scoped GEMM
139
- class SharedStorage {
140
- public:
141
- //
142
- // Type definitions
143
- //
144
-
145
- /// Shape of the A matrix operand in shared memory
146
- using ShapeA =
147
- MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
148
-
149
- /// Shape of the B matrix operand in shared memory
150
- using ShapeB =
151
- MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
152
-
153
- public:
154
- //
155
- // Data members
156
- //
157
-
158
- /// Buffer for A operand
159
- AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
160
-
161
- /// Buffer for B operand
162
- AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
163
-
164
- /// Buffer to hold scales for threadblock
165
- AlignedBuffer<ElementScale, Shape::kN> operand_scale;
166
-
167
- public:
168
- //
169
- // Methods
170
- //
171
-
172
- /// Returns a layout object for the A matrix
173
- CUTLASS_DEVICE
174
- static typename Operator::LayoutA LayoutA()
175
- {
176
- return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
177
- }
178
-
179
- /// Returns a layout object for the B matrix
180
- CUTLASS_HOST_DEVICE
181
- static typename Operator::LayoutB LayoutB()
182
- {
183
- return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
184
- }
185
-
186
- /// Returns a TensorRef to the A operand
187
- CUTLASS_HOST_DEVICE
188
- TensorRefA operand_A_ref()
189
- {
190
- return TensorRefA{operand_A.data(), LayoutA()};
191
- }
192
-
193
- /// Returns a TensorRef to the B operand
194
- CUTLASS_HOST_DEVICE
195
- TensorRefB operand_B_ref()
196
- {
197
- return TensorRefB{operand_B.data(), LayoutB()};
198
- }
199
- };
200
-
201
- protected:
202
- //
203
- // Data members
204
- //
205
-
206
- /// Iterator to load a warp-scoped tile of A operand from shared memory
207
- typename Operator::IteratorA warp_tile_iterator_A_;
208
-
209
- /// Iterator to load a warp-scoped tile of B operand from shared memory
210
- typename Operator::IteratorB warp_tile_iterator_B_;
211
-
212
- public:
213
- /// Construct from tensor references
214
- CUTLASS_DEVICE
215
- DqMmaBase(
216
- ///< Shared storage needed for internal use by threadblock-scoped GEMM
217
- SharedStorage& shared_storage,
218
- ///< ID within the threadblock
219
- int thread_idx,
220
- ///< ID of warp
221
- int warp_idx,
222
- ///< ID of each thread within a warp
223
- int lane_idx):
224
- warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
225
- warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
226
- {
227
- }
228
- };
229
-
230
- /////////////////////////////////////////////////////////////////////////////////////////////////
231
-
232
- } // namespace threadblock
233
- } // namespace gemm
234
- } // namespace cutlass
235
-
236
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h DELETED
@@ -1,599 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/aligned_buffer.h"
38
- #include "cutlass/arch/memory.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/cutlass.h"
41
- #include "cutlass/gemm/gemm.h"
42
- #include "cutlass/matrix_shape.h"
43
- #include "cutlass/numeric_types.h"
44
-
45
- #include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
46
- #include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
47
- #include "cutlass_extensions/interleaved_numeric_conversion.h"
48
-
49
- /////////////////////////////////////////////////////////////////////////////////////////////////
50
-
51
- namespace cutlass {
52
- namespace gemm {
53
- namespace threadblock {
54
-
55
- /////////////////////////////////////////////////////////////////////////////////////////////////
56
-
57
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math
58
- /// instructions.
59
- template<
60
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
61
- typename Shape_,
62
- /// Iterates over tiles of A operand in global memory
63
- // (concept: ReadableTileIterator | ForwardTileIterator |
64
- // MaskedTileIterator)
65
- typename IteratorA_,
66
- /// Iterates over tiles of A operand in shared memory
67
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
68
- typename SmemIteratorA_,
69
- /// Cache operation for operand A
70
- cutlass::arch::CacheOperation::Kind CacheOpA,
71
- /// Iterates over tiles of B operand in global memory
72
- // (concept: ReadableTileIterator | ForwardTileIterator |
73
- // MaskedTileIterator)
74
- typename IteratorB_,
75
- /// Iterates over tiles of B operand in shared memory
76
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
77
- typename SmemIteratorB_,
78
- /// Cache operation for operand B
79
- cutlass::arch::CacheOperation::Kind CacheOpB,
80
- /// Data type for the scales
81
- typename IteratorScale_,
82
- /// Iterators over scales in shared memory
83
- typename SmemIteratorScale_,
84
- /// Data type of accumulator matrix
85
- typename ElementC_,
86
- /// Data type of accumulator matrix
87
- typename LayoutC_,
88
- /// Policy describing tuning details (concept: MmaPolicy)
89
- typename Policy_,
90
- /// Number of stages,
91
- int Stages,
92
- /// Converter for B matrix applited immediately after the LDS
93
- typename TransformBAfterLDS_,
94
- /// Use zfill or predicate for out-of-bound cp.async
95
- SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
96
- /// Used for partial specialization
97
- typename Enable = bool>
98
- class DqMmaMultistage: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages> {
99
- public:
100
- ///< Base class
101
- using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
102
- ///< Size of the Gemm problem - concept: gemm::GemmShape<>
103
- using Shape = Shape_;
104
- ///< Iterates over tiles of A operand in global memory
105
- using IteratorA = IteratorA_;
106
- ///< Iterates over tiles of B operand in global memory
107
- using IteratorB = IteratorB_;
108
- ///< Data type of accumulator matrix
109
- using ElementC = ElementC_;
110
- ///< Layout of accumulator matrix
111
- using LayoutC = LayoutC_;
112
- ///< Policy describing tuning details
113
- using Policy = Policy_;
114
-
115
- using IteratorScale = IteratorScale_;
116
- using ElementScale = typename IteratorScale::Element;
117
- using LayoutScale = typename IteratorScale::Layout;
118
-
119
- using SmemIteratorA = SmemIteratorA_;
120
- using SmemIteratorB = SmemIteratorB_;
121
- using SmemIteratorScale = SmemIteratorScale_;
122
-
123
- static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
124
- static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
125
-
126
- using TransformBAfterLDS = TransformBAfterLDS_;
127
-
128
- //
129
- // Dependent types
130
- //
131
-
132
- /// Fragment of operand Scale loaded from global memory;
133
- using FragmentScale = typename IteratorScale::Fragment;
134
-
135
- /// Fragment of accumulator tile
136
- using FragmentC = typename Policy::Operator::FragmentC;
137
-
138
- /// Warp-level Mma
139
- using Operator = typename Policy::Operator;
140
-
141
- /// Minimum architecture is Sm80 to support cp.async
142
- using ArchTag = arch::Sm80;
143
-
144
- using Dequantizer =
145
- warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale, LayoutScale, 32>;
146
-
147
- /// Complex transform on A operand
148
- static ComplexTransform const kTransformA = Operator::kTransformA;
149
-
150
- /// Complex transform on B operand
151
- static ComplexTransform const kTransformB = Operator::kTransformB;
152
-
153
- /// Internal structure exposed for introspection.
154
- struct Detail {
155
-
156
- static_assert(Base::kWarpGemmIterations > 1,
157
- "The pipelined structure requires at least two warp-level "
158
- "GEMM operations.");
159
-
160
- /// Number of cp.async instructions to load one stage of operand A
161
- static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
162
-
163
- /// Number of cp.async instructions to load one stage of operand B
164
- static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
165
-
166
- /// Number of stages
167
- static int const kStages = Stages;
168
-
169
- /// Number of cp.async instructions to load on group of operand A
170
- static int const kAccessesPerGroupA =
171
- (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
172
-
173
- /// Number of cp.async instructions to load on group of operand B
174
- static int const kAccessesPerGroupB =
175
- (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
176
- };
177
-
178
- private:
179
- using WarpFragmentA = typename Operator::FragmentA;
180
- using WarpFragmentB = typename Operator::FragmentB;
181
- Dequantizer warp_dequantizer_;
182
-
183
- using ElementB = typename IteratorB::Element;
184
- using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
185
-
186
- static constexpr bool RequiresTileInterleave =
187
- layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
188
- static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
189
- "Layout K must match threadblockK");
190
-
191
- private:
192
- //
193
- // Data members
194
- //
195
-
196
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
197
- SmemIteratorA smem_iterator_A_;
198
-
199
- /// Iterator to write threadblock-scoped tile of B operand to shared memory
200
- SmemIteratorB smem_iterator_B_;
201
-
202
- /// Iterator to write threadblock-scoped tile of scale operand to shared memory
203
- SmemIteratorScale smem_iterator_scale_;
204
-
205
- public:
206
- /// Construct from tensor references
207
- CUTLASS_DEVICE
208
- DqMmaMultistage(
209
- ///< Shared storage needed for internal use by threadblock-scoped GEMM
210
- typename Base::SharedStorage& shared_storage,
211
- ///< ID within the threadblock
212
- int thread_idx,
213
- ///< ID of warp
214
- int warp_idx,
215
- ///< ID of each thread within a warp
216
- int lane_idx):
217
- Base(shared_storage, thread_idx, warp_idx, lane_idx),
218
- warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
219
- (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
220
- lane_idx),
221
- smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
222
- smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
223
- smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
224
- {
225
- // Compute warp location within threadblock tile by mapping the warp_id to
226
- // three coordinates:
227
- // _m: the warp's position within the threadblock along the M dimension
228
- // _n: the warp's position within the threadblock along the N dimension
229
- // _k: the warp's position within the threadblock along the K dimension
230
-
231
- int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
232
- int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
233
-
234
- int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
235
- int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
236
-
237
- // Add per-warp offsets in units of warp-level tiles
238
- this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
239
- this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
240
- }
241
-
242
- CUTLASS_DEVICE
243
- void
244
- copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
245
- {
246
- iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
247
- this->smem_iterator_A_.set_iteration_index(group_start_A);
248
-
249
- // Async Copy for operand A
250
- CUTLASS_PRAGMA_UNROLL
251
- for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
252
- if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
253
- typename IteratorA::AccessType* dst_ptr =
254
- reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
255
-
256
- int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
257
- * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
258
-
259
- CUTLASS_PRAGMA_UNROLL
260
- for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
261
- auto gmem_ptr = iterator_A.get();
262
-
263
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
264
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
265
- }
266
- else {
267
- cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
268
- }
269
-
270
- ++iterator_A;
271
- }
272
-
273
- ++this->smem_iterator_A_;
274
- }
275
- }
276
-
277
- iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
278
- this->smem_iterator_B_.set_iteration_index(group_start_B);
279
-
280
- // Async Copy for operand B
281
- CUTLASS_PRAGMA_UNROLL
282
- for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
283
- if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
284
- typename IteratorB::AccessType* dst_ptr =
285
- reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
286
-
287
- int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
288
- * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
289
-
290
- CUTLASS_PRAGMA_UNROLL
291
- for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
292
- auto gmem_ptr = iterator_B.get();
293
-
294
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
295
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
296
- }
297
- else {
298
- cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
299
- }
300
-
301
- ++iterator_B;
302
- }
303
- ++this->smem_iterator_B_;
304
- }
305
- }
306
- }
307
-
308
- /// Perform a threadblock-scoped matrix multiply-accumulate
309
- CUTLASS_DEVICE
310
- void operator()(
311
- ///< problem size of GEMM
312
- int gemm_k_iterations,
313
- ///< destination accumulator tile
314
- FragmentC& accum,
315
- ///< iterator over A operand in global memory
316
- IteratorA iterator_A,
317
- ///< iterator over B operand in global memory
318
- IteratorB iterator_B,
319
- ///< iterator over scale operand in global memory
320
- IteratorScale iterator_scale,
321
- ///< initial value of accumulator
322
- FragmentC const& src_accum)
323
- {
324
-
325
- //
326
- // Prologue
327
- //
328
-
329
- TransformBAfterLDS lds_converter;
330
-
331
- // NOTE - switch to ldg.sts
332
- // Issue this first, so cp.async.commit_group will commit this load as well.
333
- // Note: we do not commit here and this load will commit in the same group as
334
- // the first load of A.
335
- FragmentScale tb_frag_scales;
336
- tb_frag_scales.clear();
337
- iterator_scale.load(tb_frag_scales);
338
- this->smem_iterator_scale_.store(tb_frag_scales);
339
-
340
- // Issue several complete stages
341
- CUTLASS_PRAGMA_UNROLL
342
- for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
343
-
344
- iterator_A.clear_mask(gemm_k_iterations == 0);
345
- iterator_B.clear_mask(gemm_k_iterations == 0);
346
-
347
- iterator_A.set_iteration_index(0);
348
- this->smem_iterator_A_.set_iteration_index(0);
349
-
350
- // Async Copy for operand A
351
- CUTLASS_PRAGMA_UNROLL
352
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
353
- typename IteratorA::AccessType* dst_ptr =
354
- reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
355
-
356
- CUTLASS_PRAGMA_UNROLL
357
- for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
358
- int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
359
- * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector
360
- / 8;
361
-
362
- int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
363
-
364
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
365
- dst_ptr + v, iterator_A.get(), iterator_A.valid());
366
-
367
- ++iterator_A;
368
- }
369
-
370
- ++this->smem_iterator_A_;
371
- }
372
-
373
- iterator_B.set_iteration_index(0);
374
- this->smem_iterator_B_.set_iteration_index(0);
375
-
376
- // Async Copy for operand B
377
- CUTLASS_PRAGMA_UNROLL
378
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
379
- typename IteratorB::AccessType* dst_ptr =
380
- reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
381
-
382
- CUTLASS_PRAGMA_UNROLL
383
- for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
384
- int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
385
- * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector
386
- / 8;
387
-
388
- cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
389
- dst_ptr + v, iterator_B.get(), iterator_B.valid());
390
-
391
- ++iterator_B;
392
- }
393
-
394
- ++this->smem_iterator_B_;
395
- }
396
-
397
- // Move to the next stage
398
- iterator_A.add_tile_offset({0, 1});
399
- iterator_B.add_tile_offset({1, 0});
400
-
401
- this->smem_iterator_A_.add_tile_offset({0, 1});
402
- this->smem_iterator_B_.add_tile_offset({1, 0});
403
-
404
- // Defines the boundary of a stage of cp.async.
405
- cutlass::arch::cp_async_fence();
406
- }
407
-
408
- // Perform accumulation in the 'd' output operand
409
- accum = src_accum;
410
-
411
- //
412
- // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
413
- // so that all accumulator elements outside the GEMM footprint are zero.
414
- //
415
-
416
- if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
417
-
418
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
419
- SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
420
-
421
- typename IteratorA::AccessType zero_A;
422
- zero_A.clear();
423
-
424
- last_smem_iterator_A.set_iteration_index(0);
425
-
426
- // Async Copy for operand A
427
- CUTLASS_PRAGMA_UNROLL
428
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
429
-
430
- typename IteratorA::AccessType* dst_ptr =
431
- reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
432
-
433
- *dst_ptr = zero_A;
434
-
435
- ++last_smem_iterator_A;
436
- }
437
-
438
- /// Iterator to write threadblock-scoped tile of B operand to shared memory
439
- SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
440
- typename IteratorB::AccessType zero_B;
441
-
442
- zero_B.clear();
443
- last_smem_iterator_B.set_iteration_index(0);
444
-
445
- // Async Copy for operand B
446
- CUTLASS_PRAGMA_UNROLL
447
- for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
448
-
449
- typename IteratorB::AccessType* dst_ptr =
450
- reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
451
-
452
- *dst_ptr = zero_B;
453
-
454
- ++last_smem_iterator_B;
455
- }
456
- }
457
-
458
- // Waits until kStages-2 stages have committed.
459
- cutlass::arch::cp_async_wait<Base::kStages - 2>();
460
- __syncthreads();
461
-
462
- // Pair of fragments used to overlap shared memory loads and math
463
- // instructions
464
- WarpFragmentA warp_frag_A[2];
465
- WarpFragmentB warp_frag_B[2];
466
- typename Dequantizer::FragmentScale warp_frag_scales;
467
-
468
- Operator warp_mma;
469
-
470
- this->warp_tile_iterator_A_.set_kgroup_index(0);
471
- this->warp_tile_iterator_B_.set_kgroup_index(0);
472
-
473
- this->warp_tile_iterator_A_.load(warp_frag_A[0]);
474
- this->warp_tile_iterator_B_.load(warp_frag_B[0]);
475
- warp_dequantizer_.load(warp_frag_scales);
476
-
477
- ++this->warp_tile_iterator_A_;
478
- ++this->warp_tile_iterator_B_;
479
-
480
- iterator_A.clear_mask(gemm_k_iterations == 0);
481
- iterator_B.clear_mask(gemm_k_iterations == 0);
482
-
483
- int smem_write_stage_idx = Base::kStages - 1;
484
- int smem_read_stage_idx = 0;
485
-
486
- //
487
- // Mainloop
488
- //
489
-
490
- CUTLASS_GEMM_LOOP
491
- for (; gemm_k_iterations > (-Base::kStages + 1);) {
492
- //
493
- // Loop over GEMM K dimension
494
- //
495
-
496
- // Computes a warp-level GEMM on data held in shared memory
497
- // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
498
- CUTLASS_PRAGMA_UNROLL
499
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
500
-
501
- // Load warp-level tiles from shared memory, wrapping to k offset if
502
- // this is the last group as the case may be.
503
-
504
- this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
505
- this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
506
- ++this->warp_tile_iterator_A_;
507
-
508
- const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
509
- const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
510
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
511
- this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
512
- % Base::kWarpGemmIterationsForB);
513
- this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
514
- ++this->warp_tile_iterator_B_;
515
- }
516
-
517
- typename TransformBAfterLDS::result_type converted_frag_B =
518
- lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
519
- warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
520
-
521
- run_warp_mma(
522
- warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
523
-
524
- // Issue global->shared copies for the this stage
525
- if (warp_mma_k < Base::kWarpGemmIterations - 1) {
526
- int group_start_iteration_A, group_start_iteration_B;
527
-
528
- group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
529
- group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
530
-
531
- copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
532
- }
533
-
534
- if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
535
- int group_start_iteration_A, group_start_iteration_B;
536
- group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
537
- group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
538
-
539
- copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
540
-
541
- // Inserts a memory fence between stages of cp.async instructions.
542
- cutlass::arch::cp_async_fence();
543
-
544
- // Waits until kStages-2 stages have committed.
545
- arch::cp_async_wait<Base::kStages - 2>();
546
- __syncthreads();
547
-
548
- // Move to the next stage
549
- iterator_A.add_tile_offset({0, 1});
550
- iterator_B.add_tile_offset({1, 0});
551
-
552
- this->smem_iterator_A_.add_tile_offset({0, 1});
553
- this->smem_iterator_B_.add_tile_offset({1, 0});
554
-
555
- // Add negative offsets to return iterators to the 'start' of the
556
- // circular buffer in shared memory
557
- if (smem_write_stage_idx == (Base::kStages - 1)) {
558
- this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
559
- this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
560
- smem_write_stage_idx = 0;
561
- }
562
- else {
563
- ++smem_write_stage_idx;
564
- }
565
-
566
- if (smem_read_stage_idx == (Base::kStages - 1)) {
567
- this->warp_tile_iterator_A_.add_tile_offset(
568
- {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
569
- this->warp_tile_iterator_B_.add_tile_offset(
570
- {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
571
- smem_read_stage_idx = 0;
572
- }
573
- else {
574
- ++smem_read_stage_idx;
575
- }
576
-
577
- --gemm_k_iterations;
578
- iterator_A.clear_mask(gemm_k_iterations == 0);
579
- iterator_B.clear_mask(gemm_k_iterations == 0);
580
- }
581
- }
582
- }
583
-
584
- if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
585
- // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
586
- cutlass::arch::cp_async_fence();
587
- cutlass::arch::cp_async_wait<0>();
588
- __syncthreads();
589
- }
590
- }
591
- };
592
-
593
- /////////////////////////////////////////////////////////////////////////////////////////////////
594
-
595
- } // namespace threadblock
596
- } // namespace gemm
597
- } // namespace cutlass
598
-
599
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h DELETED
@@ -1,385 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Template for a double-buffered threadblock-scoped GEMM kernel.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/aligned_buffer.h"
38
- #include "cutlass/array.h"
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/numeric_conversion.h"
41
-
42
- #include "cutlass/matrix_shape.h"
43
- #include "cutlass/numeric_types.h"
44
-
45
- #include "cutlass/gemm/gemm.h"
46
-
47
- #include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
48
- #include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
49
- #include "cutlass_extensions/interleaved_numeric_conversion.h"
50
-
51
- #include "cutlass_extensions/ft_gemm_configs.h"
52
- #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
53
-
54
- /////////////////////////////////////////////////////////////////////////////////////////////////
55
-
56
- namespace cutlass {
57
- namespace gemm {
58
- namespace threadblock {
59
-
60
- /////////////////////////////////////////////////////////////////////////////////////////////////
61
-
62
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
63
- template<
64
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
65
- typename Shape_,
66
- /// Iterates over tiles of A operand in global memory
67
- // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
68
- typename IteratorA_,
69
- /// Iterates over tiles of A operand in shared memory
70
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
71
- typename SmemIteratorA_,
72
- /// Iterates over tiles of B operand in global memory
73
- // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
74
- typename IteratorB_,
75
- /// Iterates over tiles of B operand in shared memory
76
- /// (concept: WriteableTileIterator | RandomAccessTileIterator)
77
- typename SmemIteratorB_,
78
- /// Data type for the scales
79
- typename IteratorScale_,
80
- /// Iterators over scales in shared memory
81
- typename SmemIteratorScale_,
82
- /// Data type of accumulator matrix
83
- typename ElementC_,
84
- /// Data type of accumulator matrix
85
- typename LayoutC_,
86
- /// Policy describing tuning details (concept: MmaPolicy)
87
- typename Policy_,
88
- /// Converter for B matrix applied immediately after the LDG (before STS)
89
- typename TransformBAfterLDG_,
90
- /// Converter for B matrix applited immediately after the LDS
91
- typename TransformBAfterLDS_,
92
- /// Used for partial specialization
93
- typename Enable = bool>
94
- class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
95
- public:
96
- ///< Base class
97
- using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;
98
-
99
- using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
100
- using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
101
- using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
102
- using ElementC = ElementC_; ///< Data type of accumulator matrix
103
- using LayoutC = LayoutC_; ///< Layout of accumulator matrix
104
- using Policy = Policy_; ///< Policy describing tuning details
105
-
106
- using IteratorScale = IteratorScale_;
107
- using ElementScale = typename IteratorScale::Element;
108
- using LayoutScale = typename IteratorScale::Layout;
109
-
110
- using SmemIteratorA = SmemIteratorA_;
111
- using SmemIteratorB = SmemIteratorB_;
112
- using SmemIteratorScale = SmemIteratorScale_;
113
-
114
- using TransformBAfterLDG = TransformBAfterLDG_;
115
- using TransformBAfterLDS = TransformBAfterLDS_;
116
-
117
- //
118
- // Dependent types
119
- //
120
-
121
- /// Fragment of operand A loaded from global memory
122
- using FragmentA = typename IteratorA::Fragment;
123
-
124
- /// Fragment of operand B loaded from global memory
125
- using FragmentB = typename IteratorB::Fragment;
126
-
127
- /// Fragment of operand Scale loaded from global memory;
128
- using FragmentScale = typename IteratorScale::Fragment;
129
-
130
- /// Fragment of accumulator tile
131
- using FragmentC = typename Policy::Operator::FragmentC;
132
-
133
- /// Warp-level Mma
134
- using Operator = typename Policy::Operator;
135
-
136
- /// Obtain the arch tag from the warp-level operator
137
- using ArchTag = typename Policy::Operator::ArchTag;
138
-
139
- using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
140
- typename Base::WarpGemm,
141
- Operand::kB,
142
- typename SmemIteratorScale::Fragment::Element,
143
- LayoutScale,
144
- 32>;
145
-
146
- /// Complex transform on A operand
147
- static ComplexTransform const kTransformA = Operator::kTransformA;
148
-
149
- /// Complex transform on B operand
150
- static ComplexTransform const kTransformB = Operator::kTransformB;
151
-
152
- // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
153
- static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
154
-
155
- private:
156
- using WarpFragmentA = typename Operator::FragmentA;
157
- using WarpFragmentB = typename Operator::FragmentB;
158
- Dequantizer warp_dequantizer_;
159
-
160
- using ElementB = typename IteratorB::Element;
161
- using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
162
-
163
- static constexpr bool RequiresTileInterleave =
164
- layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
165
- static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
166
- "Layout K must match threadblockK");
167
-
168
- protected:
169
- /// Iterator to write threadblock-scoped tile of A operand to shared memory
170
- SmemIteratorA smem_iterator_A_;
171
-
172
- /// Iterator to write threadblock-scoped tile of B operand to shared memory
173
- SmemIteratorB smem_iterator_B_;
174
-
175
- /// Iterator to write threadblock-scoped tile of scale operand to shared memory
176
- SmemIteratorScale smem_iterator_scale_;
177
-
178
- public:
179
- /// Construct from tensor references
180
- CUTLASS_DEVICE
181
- DqMmaPipelined(typename Base::SharedStorage&
182
- shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
183
- int thread_idx, ///< ID within the threadblock
184
- int warp_idx, ///< ID of warp
185
- int lane_idx ///< ID of each thread within a warp
186
- ):
187
- Base(shared_storage, thread_idx, warp_idx, lane_idx),
188
- warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
189
- (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
190
- lane_idx),
191
- smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
192
- smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
193
- smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
194
- {
195
-
196
- // Compute warp location within threadblock tile by mapping the warp_id to
197
- // three coordinates:
198
- // _m: the warp's position within the threadblock along the M dimension
199
- // _n: the warp's position within the threadblock along the N dimension
200
- // _k: the warp's position within the threadblock along the K dimension
201
-
202
- int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
203
- int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
204
-
205
- int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
206
- int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
207
-
208
- // Add per-warp offsets in units of warp-level tiles
209
- this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
210
- this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
211
- }
212
-
213
- /// Perform a threadblock-scoped matrix multiply-accumulate
214
- CUTLASS_DEVICE
215
- void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
216
- FragmentC& accum, ///< destination accumulator tile
217
- IteratorA iterator_A, ///< iterator over A operand in global memory
218
- IteratorB iterator_B, ///< iterator over B operand in global memory
219
- IteratorScale iterator_scale, ///< iterator over scale operand in global memory
220
- FragmentC const& src_accum)
221
- { ///< source accumulator tile
222
-
223
- //
224
- // Prologue
225
- //
226
- TransformBAfterLDG ldg_converter;
227
- TransformBAfterLDS lds_converter;
228
-
229
- using TransformA =
230
- NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
231
-
232
- using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
233
- typename FragmentScale::Element,
234
- FragmentScale::kElements>;
235
-
236
- // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
237
- // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
238
- TransformA transformA;
239
- TransformScale transformScale;
240
-
241
- // Perform accumulation in the 'd' output operand
242
- accum = src_accum;
243
-
244
- FragmentA tb_frag_A;
245
- FragmentB tb_frag_B;
246
- FragmentScale tb_frag_scales;
247
-
248
- using WarpFragmentScale = typename Dequantizer::FragmentScale;
249
- WarpFragmentScale warp_frag_scales;
250
-
251
- tb_frag_A.clear();
252
- tb_frag_B.clear();
253
- tb_frag_scales.clear();
254
-
255
- // The last kblock is loaded in the prolog
256
- iterator_A.load(tb_frag_A);
257
- iterator_B.load(tb_frag_B);
258
- iterator_scale.load(tb_frag_scales);
259
-
260
- ++iterator_A;
261
- ++iterator_B;
262
-
263
- this->smem_iterator_A_.store(transformA(tb_frag_A));
264
- this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
265
- this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
266
-
267
- ++this->smem_iterator_A_;
268
- ++this->smem_iterator_B_;
269
-
270
- __syncthreads();
271
-
272
- warp_dequantizer_.load(warp_frag_scales);
273
-
274
- // Pair of fragments used to overlap shared memory loads and math instructions
275
- WarpFragmentA warp_frag_A[2];
276
- WarpFragmentB warp_frag_B[2];
277
-
278
- this->warp_tile_iterator_A_.set_kgroup_index(0);
279
- this->warp_tile_iterator_B_.set_kgroup_index(0);
280
-
281
- this->warp_tile_iterator_A_.load(warp_frag_A[0]);
282
- this->warp_tile_iterator_B_.load(warp_frag_B[0]);
283
-
284
- ++this->warp_tile_iterator_A_;
285
- ++this->warp_tile_iterator_B_;
286
-
287
- Operator warp_mma;
288
-
289
- int smem_write_stage_idx = 1;
290
-
291
- // Avoid reading out of bounds
292
- iterator_A.clear_mask(gemm_k_iterations <= 1);
293
- iterator_B.clear_mask(gemm_k_iterations <= 1);
294
-
295
- // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
296
- // shared memory loads (which have the tighest latency requirement).
297
-
298
- //
299
- // Mainloop
300
- //
301
-
302
- // Note: The main loop does not support Base::kWarpGemmIterations == 2.
303
- CUTLASS_GEMM_LOOP
304
- for (; gemm_k_iterations > 0; --gemm_k_iterations) {
305
- //
306
- // Loop over GEMM K dimension
307
- //
308
-
309
- CUTLASS_PRAGMA_UNROLL
310
- for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
311
-
312
- // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
313
- // as the case may be.
314
-
315
- if (warp_mma_k == Base::kWarpGemmIterations - 1) {
316
-
317
- // Write fragments to shared memory
318
- this->smem_iterator_A_.store(transformA(tb_frag_A));
319
-
320
- this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
321
-
322
- __syncthreads();
323
-
324
- ++this->smem_iterator_A_;
325
- ++this->smem_iterator_B_;
326
-
327
- // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
328
- if (smem_write_stage_idx == 1) {
329
- this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
330
- this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
331
- }
332
- else {
333
- this->warp_tile_iterator_A_.add_tile_offset(
334
- {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
335
- this->warp_tile_iterator_B_.add_tile_offset(
336
- {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
337
- }
338
-
339
- smem_write_stage_idx ^= 1;
340
- }
341
-
342
- this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
343
- this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
344
- ++this->warp_tile_iterator_A_;
345
-
346
- const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
347
- const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
348
- // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
349
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
350
- this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
351
- % Base::kWarpGemmIterationsForB);
352
- this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
353
- ++this->warp_tile_iterator_B_;
354
- }
355
-
356
- if (warp_mma_k == 0) {
357
-
358
- iterator_A.load(tb_frag_A);
359
- iterator_B.load(tb_frag_B);
360
-
361
- ++iterator_A;
362
- ++iterator_B;
363
-
364
- // Avoid reading out of bounds if this was the last loop iteration
365
- iterator_A.clear_mask(gemm_k_iterations <= 2);
366
- iterator_B.clear_mask(gemm_k_iterations <= 2);
367
- }
368
-
369
- typename TransformBAfterLDS::result_type converted_frag_B =
370
- lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
371
- warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
372
- run_warp_mma(
373
- warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
374
- }
375
- }
376
- }
377
- };
378
-
379
- /////////////////////////////////////////////////////////////////////////////////////////////////
380
-
381
- } // namespace threadblock
382
- } // namespace gemm
383
- } // namespace cutlass
384
-
385
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h DELETED
@@ -1,127 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/cutlass.h"
38
- #include "cutlass/gemm/warp/default_mma_tensor_op.h"
39
- #include "cutlass/gemm/warp/mma_tensor_op.h"
40
-
41
- #include "cutlass_extensions/arch/mma.h"
42
- #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
43
-
44
- namespace cutlass {
45
- namespace gemm {
46
- namespace warp {
47
-
48
- /////////////////////////////////////////////////////////////////////////////////////////////////
49
-
50
- /// Partial specialization for m-by-n-by-kgroup
51
- template<
52
- /// Shape of one matrix production operation (concept: GemmShape)
53
- typename WarpShape_,
54
- /// Shape of one matrix production operation (concept: GemmShape)
55
- typename InstructionShape_,
56
- /// Data type of A elements,
57
- typename ElementA,
58
- /// Layout of A matrix (concept: MatrixLayout)
59
- typename LayoutA,
60
- /// Data type of B elements
61
- typename ElementB,
62
- /// Layout of B matrix (concept: MatrixLayout)
63
- typename LayoutB,
64
- /// Element type of C matrix
65
- typename ElementC,
66
- /// Layout of C matrix (concept: MatrixLayout)
67
- typename LayoutC,
68
- /// Number of partitions along K dimension
69
- int PartitionsK,
70
- /// Store the accumulators in row major or column major. Row major is used
71
- /// when output layout is interleaved.
72
- bool AccumulatorsInRowMajor>
73
- struct DefaultMmaTensorOp<WarpShape_,
74
- InstructionShape_,
75
- ElementA,
76
- LayoutA,
77
- ElementB,
78
- LayoutB,
79
- ElementC,
80
- LayoutC,
81
- arch::OpMultiplyAddDequantizeInterleavedBToA,
82
- PartitionsK,
83
- AccumulatorsInRowMajor> {
84
-
85
- private:
86
- // Shape for computing the FP16s
87
- using ComputeInstructionShape = InstructionShape_;
88
-
89
- // Chosen so we get K=16 for int8 and K=32 for int4.
90
- static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
91
-
92
- // Shape for loading the narrow data type from shared memory
93
- using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
94
-
95
- public:
96
- using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
97
- 32,
98
- ElementA,
99
- cutlass::layout::RowMajor,
100
- ElementA,
101
- cutlass::layout::ColumnMajor,
102
- ElementC,
103
- cutlass::layout::RowMajor,
104
- arch::OpMultiplyAdd>,
105
- cutlass::MatrixShape<1, 1>>;
106
-
107
- // Define the warp-level tensor op
108
- using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
109
- ElementA,
110
- LayoutA,
111
- ElementB,
112
- LayoutB,
113
- ElementC,
114
- LayoutC,
115
- Policy,
116
- LoadInstructionShape,
117
- PartitionsK,
118
- AccumulatorsInRowMajor>;
119
- };
120
-
121
- /////////////////////////////////////////////////////////////////////////////////////////////////
122
-
123
- } // namespace warp
124
- } // namespace gemm
125
- } // namespace cutlass
126
-
127
- /////////////////////////////////////////////////////////////////////////////////////////////////
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h DELETED
@@ -1,313 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
33
- Tensor Cores.
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/array.h"
39
- #include "cutlass/cutlass.h"
40
- #include "cutlass/platform/platform.h"
41
-
42
- #include "cutlass/matrix_shape.h"
43
- #include "cutlass/numeric_conversion.h"
44
- #include "cutlass/numeric_types.h"
45
-
46
- #include "cutlass/arch/memory_sm75.h"
47
- #include "cutlass/arch/mma_sm75.h"
48
- #include "cutlass/arch/mma_sm80.h"
49
-
50
- #include "cutlass/gemm/gemm.h"
51
- #include "cutlass/gemm/warp/mma.h"
52
-
53
- #include "cutlass/gemm/warp/mma_tensor_op_policy.h"
54
-
55
- #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
56
- #include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
57
-
58
- /////////////////////////////////////////////////////////////////////////////////////////////////
59
-
60
- namespace cutlass {
61
- namespace gemm {
62
- namespace warp {
63
-
64
- /////////////////////////////////////////////////////////////////////////////////////////////////
65
- /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
66
- template<
67
- /// Size of the Gemm problem - concept: gemm::GemmShape<>
68
- typename Shape_,
69
- /// Data type of A elements
70
- typename ElementA_,
71
- /// Layout of A matrix (concept: MatrixLayout)
72
- typename LayoutA_,
73
- /// Data type of B elements
74
- typename ElementB_,
75
- /// Layout of B matrix (concept: MatrixLayout)
76
- typename LayoutB_,
77
- /// Element type of C matrix
78
- typename ElementC_,
79
- /// Layout of C matrix (concept: MatrixLayout)
80
- typename LayoutC_,
81
- /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
82
- typename Policy_,
83
- /// Instruction shape to override shared memory iterators with
84
- typename SharedMemoryInstructionShape_,
85
- /// Number of partitions along K dimension
86
- int PartitionsK_ = 1,
87
- /// Store the accumulators in row major or column major. Row major is used
88
- /// when output layout is interleaved.
89
- bool AccumulatorsInRowMajor = false,
90
- /// Used for partial specialization
91
- typename Enable = bool>
92
- class MmaTensorOpComputeBWithF16 {
93
- public:
94
- /// Shape of warp-level matrix operation (concept: GemmShape)
95
- using Shape = Shape_;
96
-
97
- /// Data type of multiplicand A
98
- using ElementA = ElementA_;
99
-
100
- /// Layout of multiplicand A
101
- using LayoutA = LayoutA_;
102
-
103
- /// Data type of multiplicand B
104
- using ElementB = ElementB_;
105
-
106
- /// Layout of multiplicand B
107
- using LayoutB = LayoutB_;
108
-
109
- /// Data type of accumulator matrix C
110
- using ElementC = ElementC_;
111
-
112
- /// Layout of accumulator matrix C
113
- using LayoutC = LayoutC_;
114
-
115
- /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
116
- using Policy = Policy_;
117
-
118
- /// Underlying matrix multiply operator (concept: arch::Mma)
119
- using ArchMmaOperator = typename Policy::Operator;
120
-
121
- /// Indicates math operator
122
- using MathOperator = typename ArchMmaOperator::Operator;
123
-
124
- /// Architecture tag from underlying instruction
125
- using ArchTag = typename ArchMmaOperator::ArchTag;
126
- static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
127
- && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
128
- || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
129
- && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
130
- && ArchTag::kMinComputeCapability >= 80),
131
- "MmaTensorOpCvtBToA only supports underlying HMMA");
132
-
133
- static_assert(platform::is_same<ElementA, half_t>::value
134
- || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
135
- "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
136
-
137
- /// Indicates class of matrix operator
138
- using OperatorClass = arch::OpClassTensorOp;
139
-
140
- /// Shape of underlying instruction
141
- using InstructionShape = typename ArchMmaOperator::Shape;
142
-
143
- /// Instruction shape to override shared memory iterators with
144
- using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
145
-
146
- static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
147
- "M dimension of compute instruction must match load");
148
- static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
149
- "N dimension of compute instruction must match load");
150
-
151
- static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
152
-
153
- static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
154
-
155
- /// Complex transform on A operand
156
- static ComplexTransform const kTransformA = ComplexTransform::kNone;
157
-
158
- /// Complex transform on B operand
159
- static ComplexTransform const kTransformB = ComplexTransform::kNone;
160
-
161
- /// Number of threads participating in warp-level matrix product
162
- static int const kThreadCount = 32;
163
-
164
- /// Number of partitions along K dimension
165
- static int const kPartitionsK = PartitionsK_;
166
-
167
- public:
168
- /// Iterates over the A operand in memory
169
- using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
170
- Operand::kA,
171
- ElementA,
172
- LayoutA,
173
- MatrixShape<InstructionShape::kM, InstructionShape::kK>,
174
- Policy::OpDelta::kRow,
175
- kThreadCount,
176
- kPartitionsK>;
177
-
178
- /// Storage for A tile
179
- using FragmentA = typename IteratorA::Fragment;
180
-
181
- /// Storage for transformed A tile
182
- using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
183
-
184
- /// Iterates over the B operand in memory
185
- using IteratorB =
186
- MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
187
- Operand::kB,
188
- ElementB,
189
- LayoutB,
190
- MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
191
- Policy::OpDelta::kRow,
192
- kThreadCount,
193
- kPartitionsK>;
194
-
195
- /// Storage for B tile
196
- using FragmentB = typename IteratorB::Fragment;
197
-
198
- /// Storage for transformed B tile
199
- using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
200
-
201
- /// Iterates over the C operand in memory
202
- using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
203
- ElementC,
204
- LayoutC,
205
- typename ArchMmaOperator::Shape,
206
- typename Policy::OpDelta>;
207
-
208
- /// Storage for C tile
209
- using FragmentC = typename IteratorC::Fragment;
210
-
211
- /// Number of mma operations performed
212
- using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
213
- (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
214
-
215
- public:
216
- /// Underlying matrix multiply operator (concept: arch::Mma)
217
- ArchMmaOperator mma;
218
-
219
- public:
220
- //
221
- // Methods
222
- //
223
-
224
- /// Ctor
225
- CUTLASS_DEVICE
226
- MmaTensorOpComputeBWithF16() {}
227
-
228
- /// Performs a warp-level matrix multiply-accumulate operation
229
- CUTLASS_DEVICE
230
- void operator()(FragmentC& D,
231
- TransformedFragmentA const& A,
232
- TransformedFragmentB const& B,
233
- FragmentC const& C,
234
- const int warp_tileB_k_offset) const
235
- {
236
-
237
- using MmaOperandA = typename ArchMmaOperator::FragmentA;
238
- using MmaOperandB = typename ArchMmaOperator::FragmentB;
239
- using MmaOperandC = typename ArchMmaOperator::FragmentC;
240
-
241
- static_assert(
242
- TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
243
- "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");
244
-
245
- D = C;
246
-
247
- MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
248
- MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
249
- MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
250
-
251
- #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
252
- // Serpentine visitation order maximizing reuse of Rb
253
- CUTLASS_PRAGMA_UNROLL
254
- for (int n = 0; n < MmaIterations::kColumn; ++n) {
255
-
256
- CUTLASS_PRAGMA_UNROLL
257
- for (int m = 0; m < MmaIterations::kRow; ++m) {
258
-
259
- int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
260
-
261
- int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
262
- if (AccumulatorsInRowMajor) { // matrix B is reordered
263
- mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
264
- ptr_A[m_serpentine],
265
- ptr_B[n_offsetB],
266
- ptr_D[n + m_serpentine * MmaIterations::kColumn]);
267
- }
268
- else {
269
- mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
270
- ptr_A[m_serpentine],
271
- ptr_B[n_offsetB],
272
- ptr_D[m_serpentine + n * MmaIterations::kRow]);
273
- }
274
- }
275
- }
276
- #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
277
- // Serpentine visitation order maximizing reuse of Ra
278
- CUTLASS_PRAGMA_UNROLL
279
- for (int m = 0; m < MmaIterations::kRow; ++m) {
280
-
281
- CUTLASS_PRAGMA_UNROLL
282
- for (int n = 0; n < MmaIterations::kColumn; ++n) {
283
-
284
- int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
285
-
286
- int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
287
- if (AccumulatorsInRowMajor) { // matrix B is reordered
288
- mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
289
- ptr_A[m],
290
- ptr_B[n_serpentine_offsetB],
291
- ptr_D[n_serpentine + m * MmaIterations::kColumn]);
292
- }
293
- else {
294
- mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
295
- ptr_A[m],
296
- ptr_B[n_serpentine_offsetB],
297
- ptr_D[m + n_serpentine * MmaIterations::kRow]);
298
- }
299
- }
300
- }
301
- #else
302
- assert(0);
303
- #endif
304
- }
305
- };
306
-
307
- /////////////////////////////////////////////////////////////////////////////////////////////////
308
-
309
- } // namespace warp
310
- } // namespace gemm
311
- } // namespace cutlass
312
-
313
- /////////////////////////////////////////////////////////////////////////////////////////////////
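The MmaTensorOpComputeBWithF16::operator() removed above walks the MMA tile grid in a serpentine order so that the operand fragment shared between consecutive steps stays in registers. Below is a minimal host-side sketch (plain C++, not CUTLASS code; the 4x4 iteration shape and the printout are illustrative assumptions) that reproduces the SM80 index transform to show the visit order it produces.

    // Toy reproduction of the serpentine column traversal from the SM80 branch above.
    #include <cstdio>

    int main() {
        const int kRow = 4, kColumn = 4;   // stand-in for MmaIterations::kRow / kColumn
        for (int m = 0; m < kRow; ++m) {
            for (int n = 0; n < kColumn; ++n) {
                // Same index transform as the >= SM80 path: odd rows walk columns backwards,
                // so the last B fragment touched on one row is the first one touched on the next.
                int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
                std::printf("mma tile (m=%d, n=%d)\n", m, n_serpentine);
            }
        }
        return 0;
    }
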
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h DELETED
@@ -1,469 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
33
- */
34
-
35
- #pragma once
36
-
37
- #include "cutlass/cutlass.h"
38
-
39
- #include "cutlass/array.h"
40
- #include "cutlass/matrix_shape.h"
41
- #include "cutlass/numeric_types.h"
42
- #include "cutlass/tensor_ref.h"
43
-
44
- #include "cutlass/arch/arch.h"
45
- #include "cutlass/arch/memory_sm75.h"
46
- #include "cutlass/gemm/gemm.h"
47
-
48
- #include "cutlass/layout/matrix.h"
49
- #include "cutlass/layout/pitch_linear.h"
50
- #include "cutlass/layout/tensor.h"
51
-
52
- #include "cutlass/functional.h"
53
- #include "cutlass/platform/platform.h"
54
-
55
-
56
- ////////////////////////////////////////////////////////////////////////////////
57
-
58
- namespace cutlass {
59
- namespace gemm {
60
- namespace warp {
61
-
62
- ////////////////////////////////////////////////////////////////////////////////
63
-
64
- template<
65
- /// Matrix multiply operator
66
- typename MmaOperator_,
67
- /// Size of the matrix to load (concept: MatrixShape)
68
- typename Shape_,
69
- /// Operand identity
70
- Operand Operand,
71
- /// Data type of Scale elements
72
- typename Element_,
73
- /// Layout of operand
74
- typename Layout_,
75
- /// Number of threads participating in one matrix operation
76
- int Threads,
77
- ///
78
- typename Enable = void>
79
- class MmaTensorOpDequantizer;
80
-
81
- ////////////////////////////////////////////////////////////////////////////////
82
- // Bfloat specialization for Ampere
83
- template<
84
- /// Underlying matrix multiply operator (concept: MmaTensorOp)
85
- typename MmaOperator_,
86
- /// Shape of the warp level matrix multiply (concept: GemmShape)
87
- typename Shape_>
88
- class MmaTensorOpDequantizer<
89
- MmaOperator_,
90
- Shape_,
91
- Operand::kB,
92
- bfloat16_t,
93
- layout::RowMajor,
94
- 32,
95
- typename platform::enable_if<
96
- MmaOperator_::ArchTag::kMinComputeCapability >= 80
97
- && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
98
-
99
- public:
100
- /// Mma Operator
101
- using MmaOperator = MmaOperator_;
102
-
103
- // The architecture specific mma ooperator being used
104
- using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
105
-
106
- // Mma Instruction Shape
107
- using InstructionShape = typename ArchMmaOperator::Shape;
108
-
109
- // This is the ratio of the load instruction vs the compute instruction.
110
- static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
111
-
112
- /// Type of the scales
113
- using ElementScale = bfloat16_t;
114
-
115
- /// Fragment to hold B data before Mma
116
- using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
117
-
118
- // Fragment to hold scale data to apply to B before mma
119
- // We need 1 fp16 per matrix iteration in the N dimension
120
- static constexpr int kColsPerMmaPerThread = 1;
121
- using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
122
-
123
- /// Warp mma shape
124
- using Shape = Shape_;
125
-
126
- /// Layout of the scales in shared memory
127
- using Layout = layout::RowMajor;
128
-
129
- /// TensorRef type for loading element from a tensor
130
- using TensorRef = TensorRef<ElementScale, Layout>;
131
-
132
- CUTLASS_DEVICE
133
- MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
134
- {
135
- const int warp_offset = warp_idx_n * Shape::kN;
136
- const int quad = lane_idx / 4;
137
- const int thread_offset = warp_offset + quad;
138
- pointer_ = smem_scales.data() + thread_offset;
139
- }
140
-
141
- CUTLASS_DEVICE
142
- void load(FragmentScale& scale_frag)
143
- {
144
-
145
- CUTLASS_PRAGMA_UNROLL
146
- for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
147
- scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
148
- }
149
- }
150
-
151
- CUTLASS_DEVICE
152
- void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
153
- {
154
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
155
- using _MmaOperandB = typename ArchMmaOperator::FragmentB;
156
- using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
157
- static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
158
- == FragmentDequantizedOperand::kElements,
159
- "");
160
-
161
- const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
162
-
163
- ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
164
- CUTLASS_PRAGMA_UNROLL
165
- for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
166
- static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
167
-
168
- __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
169
- __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
170
- CUTLASS_PRAGMA_UNROLL
171
- for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
172
- operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
173
- }
174
- }
175
- #else
176
- // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
177
- // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
178
- // numerous conversion instructions in GEMM main loop.
179
- arch::device_breakpoint();
180
- #endif
181
- }
182
-
183
- private:
184
- ElementScale const* pointer_;
185
- };
186
-
187
- ////////////////////////////////////////////////////////////////////////////////
188
-
189
- // Specialization for Turing & Ampere
190
- template<
191
- /// Underlying matrix multiply operator (concept: MmaTensorOp)
192
- typename MmaOperator_,
193
- /// Shape of the warp level matrix multiply (concept: GemmShape)
194
- typename Shape_>
195
- class MmaTensorOpDequantizer<
196
- MmaOperator_,
197
- Shape_,
198
- Operand::kB,
199
- half_t,
200
- layout::RowMajor,
201
- 32,
202
- typename platform::enable_if<
203
- MmaOperator_::ArchTag::kMinComputeCapability >= 75
204
- && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
205
-
206
- public:
207
- /// Mma Operator
208
- using MmaOperator = MmaOperator_;
209
-
210
- // The architecture specific mma ooperator being used
211
- using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
212
-
213
- // Mma Instruction Shape
214
- using InstructionShape = typename ArchMmaOperator::Shape;
215
-
216
- // This is the ratio of the load instruction vs the compute instruction.
217
- static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
218
-
219
- /// Type of the scales
220
- using ElementScale = half_t;
221
-
222
- /// Fragment to hold B data before Mma
223
- using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
224
-
225
- // Fragment to hold scale data to apply to B before mma
226
- // We need 1 fp16 per matrix iteration in the N dimension
227
- static constexpr int kColsPerMmaPerThread = 1;
228
- using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
229
-
230
- /// Warp mma shape
231
- using Shape = Shape_;
232
-
233
- /// Layout of the scales in shared memory
234
- using Layout = layout::RowMajor;
235
-
236
- /// TensorRef type for loading element from a tensor
237
- using TensorRef = TensorRef<ElementScale, Layout>;
238
-
239
- CUTLASS_DEVICE
240
- MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
241
- {
242
- const int warp_offset = warp_idx_n * Shape::kN;
243
- const int quad = lane_idx / 4;
244
- const int thread_offset = warp_offset + quad;
245
- pointer_ = smem_scales.data() + thread_offset;
246
- }
247
-
248
- CUTLASS_DEVICE
249
- void load(FragmentScale& scale_frag)
250
- {
251
-
252
- CUTLASS_PRAGMA_UNROLL
253
- for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
254
- scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
255
- }
256
- }
257
-
258
- CUTLASS_DEVICE
259
- void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
260
- {
261
- using _MmaOperandB = typename ArchMmaOperator::FragmentB;
262
- using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
263
- static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
264
- == FragmentDequantizedOperand::kElements,
265
- "");
266
-
267
- multiplies<ExpandedMmaOperandB> mul_op;
268
-
269
- ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
270
- CUTLASS_PRAGMA_UNROLL
271
- for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
272
- operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
273
- }
274
- }
275
-
276
- private:
277
- ElementScale const* pointer_;
278
- };
279
-
280
- ////////////////////////////////////////////////////////////////////////////////
281
-
282
- // Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
283
- template<
284
- /// Underlying matrix multiply operator (concept: MmaTensorOp)
285
- typename MmaOperator_,
286
- /// Shape of the warp level matrix multiply (concept: GemmShape)
287
- typename Shape_>
288
- class MmaTensorOpDequantizer<
289
- MmaOperator_,
290
- Shape_,
291
- Operand::kB,
292
- half_t,
293
- layout::RowMajor,
294
- 32,
295
- typename platform::enable_if<
296
- platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
297
- && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {
298
-
299
- public:
300
- static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");
301
-
302
- /// Mma Operator
303
- using MmaOperator = MmaOperator_;
304
-
305
- // The architecture specific mma ooperator being used
306
- using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
307
-
308
- // Mma Instruction Shape
309
- using InstructionShape = typename ArchMmaOperator::Shape;
310
-
311
- /// Type of the scales
312
- using ElementScale = half_t;
313
-
314
- /// Fragment to hold B data before Mma
315
- using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
316
-
317
- /// Warp mma shape
318
- using Shape = Shape_;
319
-
320
- // Fragment to hold scale data to apply to B before mma
321
- // Each 32x32x4 matmul uses 8 elements from B.
322
- static constexpr int ColsPerMmaTile = 32;
323
- static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
324
- using FragmentScale = Array<ElementScale, TileNIterations * 8>;
325
- using AccessType = Array<ElementScale, 8>;
326
-
327
- /// Layout of the scales in shared memory
328
- using Layout = layout::RowMajor;
329
-
330
- /// TensorRef type for loading element from a tensor
331
- using TensorRef = TensorRef<ElementScale, Layout>;
332
-
333
- CUTLASS_DEVICE
334
- MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
335
- {
336
- const int warp_offset = warp_idx_n * Shape::kN;
337
- const int base_col = lane_idx & 0xF8;
338
- const int thread_offset = warp_offset + base_col;
339
- pointer_ = smem_scales.data() + thread_offset;
340
- }
341
-
342
- CUTLASS_DEVICE
343
- void load(FragmentScale& scale_frag)
344
- {
345
- AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);
346
-
347
- CUTLASS_PRAGMA_UNROLL
348
- for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
349
- // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
350
- scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
351
- }
352
- }
353
-
354
- CUTLASS_DEVICE
355
- void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
356
- {
357
- static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");
358
-
359
- multiplies<FragmentDequantizedOperand> mul_op;
360
- operand_frag = mul_op(operand_frag, scale_frag);
361
- }
362
-
363
- private:
364
- ElementScale const* pointer_;
365
- };
366
-
367
- ////////////////////////////////////////////////////////////////////////////////
368
-
369
- // Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
370
- template<
371
- /// Underlying matrix multiply operator (concept: MmaTensorOp)
372
- typename MmaOperator_,
373
- /// Shape of the warp level matrix multiply (concept: GemmShape)
374
- typename Shape_>
375
- class MmaTensorOpDequantizer<
376
- MmaOperator_,
377
- Shape_,
378
- Operand::kB,
379
- half_t,
380
- layout::RowMajor,
381
- 32,
382
- typename platform::enable_if<
383
- platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
384
- && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
385
-
386
- public:
387
- static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");
388
-
389
- /// Mma Operator
390
- using MmaOperator = MmaOperator_;
391
-
392
- // The architecture specific mma ooperator being used
393
- using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
394
-
395
- // Mma Instruction Shape
396
- using InstructionShape = typename ArchMmaOperator::Shape;
397
-
398
- /// Type of the scales
399
- using ElementScale = half_t;
400
-
401
- /// Fragment to hold B data before Mma
402
- using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
403
-
404
- /// Warp mma shape
405
- using Shape = Shape_;
406
-
407
- // Fragment to hold scale data to apply to B before mma
408
- // Each 32x32x4 matmul uses 8 elements from B.
409
- static constexpr int ColsPerMmaTile = 32;
410
- static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
411
- using FragmentScale = Array<ElementScale, TileNIterations * 2>;
412
-
413
- /// Layout of the scales in shared memory
414
- using Layout = layout::RowMajor;
415
-
416
- /// TensorRef type for loading element from a tensor
417
- using TensorRef = TensorRef<ElementScale, Layout>;
418
-
419
- CUTLASS_DEVICE
420
- MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
421
- {
422
- const int warp_offset = warp_idx_n * Shape::kN;
423
- const int base_col = lane_idx & 0xF8 + lane_idx % 4;
424
- const int thread_offset = warp_offset + base_col;
425
- pointer_ = smem_scales.data() + thread_offset;
426
- }
427
-
428
- CUTLASS_DEVICE
429
- void load(FragmentScale& scale_frag)
430
- {
431
- CUTLASS_PRAGMA_UNROLL
432
- for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
433
- // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
434
- // For col major B, each thread will jump 4 cols to get its next value inside
435
- // of the super mma.
436
- CUTLASS_PRAGMA_UNROLL
437
- for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
438
- scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
439
- }
440
- }
441
- }
442
-
443
- CUTLASS_DEVICE
444
- void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
445
- {
446
- using MmaOperandB = typename ArchMmaOperator::FragmentB;
447
- static constexpr int total_n_mmas = 2 * TileNIterations;
448
- static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");
449
-
450
- multiplies<MmaOperandB> mul_op;
451
-
452
- MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
453
- CUTLASS_PRAGMA_UNROLL
454
- for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
455
- operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
456
- }
457
- }
458
-
459
- private:
460
- ElementScale const* pointer_;
461
- };
462
-
463
- ////////////////////////////////////////////////////////////////////////////////
464
-
465
- } // namespace warp
466
- } // namespace gemm
467
- } // namespace cutlass
468
-
469
- ////////////////////////////////////////////////////////////////////////////////
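The dequantize() methods deleted above all reduce to the same pattern: the fragment of B held by a thread is grouped by MMA column iteration, and every element in a group is multiplied by the single scale loaded for that column. A minimal host-side sketch of that pattern (plain C++; the 2x4 fragment shape, values, and scales are made-up examples, not CUTLASS types):

    // Per-column scale application, mirroring the dequantize() loops above.
    #include <array>
    #include <cstdio>

    int main() {
        constexpr int kColumns = 2;          // toy MmaIterations::kColumn
        constexpr int kElemsPerColumn = 4;   // toy fragment size per column iteration
        std::array<float, kColumns * kElemsPerColumn> operand{
            1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};  // stand-in for the B fragment
        std::array<float, kColumns> scales{0.5f, 0.25f};

        for (int n = 0; n < kColumns; ++n)
            for (int e = 0; e < kElemsPerColumn; ++e)
                operand[n * kElemsPerColumn + e] *= scales[n];  // one scale per column

        for (float v : operand) std::printf("%g ", v);
        std::printf("\n");
        return 0;
    }
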
cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h DELETED
@@ -1,429 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*!
32
- \file
33
- \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
34
- */
35
-
36
- #pragma once
37
-
38
- #include "cutlass/arch/arch.h"
39
- #include "cutlass/array.h"
40
- #include "cutlass/half.h"
41
- #include "cutlass/numeric_types.h"
42
-
43
- namespace cutlass {
44
-
45
- // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
46
- // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
47
- // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
48
- // This converter will uninterleave the data and subtract the bias while converting to the result type.
49
- template<typename T, typename S, int N>
50
- struct FastInterleavedAndBiasedNumericArrayConverter {
51
- };
52
-
53
- template<>
54
- struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
55
- using result_type = Array<half_t, 4>;
56
- using source_type = Array<uint8_t, 4>;
57
-
58
- CUTLASS_DEVICE
59
- static result_type convert(source_type const& source)
60
- {
61
- result_type result;
62
-
63
- uint32_t* h = reinterpret_cast<uint32_t*>(&result);
64
- uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
65
-
66
- static constexpr uint32_t mask_for_elt_01 = 0x5250;
67
- static constexpr uint32_t mask_for_elt_23 = 0x5351;
68
- static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
69
- asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
70
- asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
71
-
72
- // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
73
- static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
74
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
75
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
76
-
77
- return result;
78
- }
79
-
80
- CUTLASS_DEVICE
81
- result_type operator()(source_type const& s)
82
- {
83
- return convert(s);
84
- }
85
- };
86
-
87
- template<int N>
88
- struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
89
- static constexpr int VEC_WIDTH = 4;
90
- static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
91
-
92
- using result_type = Array<half_t, N>;
93
- using source_type = Array<uint8_t, N>;
94
-
95
- CUTLASS_DEVICE
96
- static result_type convert(source_type const& source)
97
- {
98
- using scalar_result_type = typename result_type::Element;
99
- using scalar_source_type = typename source_type::Element;
100
- FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
101
- convert_vector_;
102
-
103
- result_type result;
104
- using vec_result = Array<scalar_result_type, VEC_WIDTH>;
105
- using vec_source = Array<scalar_source_type, VEC_WIDTH>;
106
-
107
- vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
108
- vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
109
-
110
- CUTLASS_PRAGMA_UNROLL
111
- for (int i = 0; i < N / VEC_WIDTH; ++i) {
112
- result_ptr[i] = convert_vector_(source_ptr[i]);
113
- }
114
-
115
- return result;
116
- }
117
-
118
- CUTLASS_DEVICE
119
- result_type operator()(source_type const& s)
120
- {
121
- return convert(s);
122
- }
123
- };
124
-
125
- template<>
126
- struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
127
- using result_type = Array<bfloat16_t, 4>;
128
- using source_type = Array<uint8_t, 4>;
129
-
130
- CUTLASS_DEVICE
131
- static result_type convert(source_type const& source)
132
- {
133
- result_type result;
134
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
135
-
136
- uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
137
- uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
138
-
139
- static constexpr uint32_t fp32_base = 0x4B000000;
140
- float fp32_intermediates[4];
141
-
142
- // Construct FP32s, bfloat does not have enough mantissa for IADD trick
143
- uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
144
- fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
145
- fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
146
- fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
147
- fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
148
-
149
- // Subtract out fp32_base + 128 to make the unsigned integer signed.
150
- CUTLASS_PRAGMA_UNROLL
151
- for (int ii = 0; ii < 4; ++ii) {
152
- fp32_intermediates[ii] -= 8388736.f;
153
- }
154
-
155
- // Truncate the fp32 representation and pack up as bfloat16s.
156
- CUTLASS_PRAGMA_UNROLL
157
- for (int ii = 0; ii < 2; ++ii) {
158
- bf16_result_ptr[ii] =
159
- __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
160
- }
161
- #else
162
- // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
163
- // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
164
- result.clear(); // Suppress compiler warning
165
- arch::device_breakpoint();
166
- #endif
167
- return result;
168
- }
169
-
170
- CUTLASS_DEVICE
171
- result_type operator()(source_type const& s)
172
- {
173
- return convert(s);
174
- }
175
- };
176
-
177
- template<int N>
178
- struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
179
- static constexpr int VEC_WIDTH = 4;
180
- static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
181
-
182
- using result_type = Array<bfloat16_t, N>;
183
- using source_type = Array<uint8_t, N>;
184
-
185
- CUTLASS_DEVICE
186
- static result_type convert(source_type const& source)
187
- {
188
- using scalar_result_type = typename result_type::Element;
189
- using scalar_source_type = typename source_type::Element;
190
- FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
191
- convert_vector_;
192
-
193
- result_type result;
194
- using vec_result = Array<scalar_result_type, VEC_WIDTH>;
195
- using vec_source = Array<scalar_source_type, VEC_WIDTH>;
196
-
197
- vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
198
- vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
199
-
200
- CUTLASS_PRAGMA_UNROLL
201
- for (int i = 0; i < N / VEC_WIDTH; ++i) {
202
- result_ptr[i] = convert_vector_(source_ptr[i]);
203
- }
204
-
205
- return result;
206
- }
207
-
208
- CUTLASS_DEVICE
209
- result_type operator()(source_type const& s)
210
- {
211
- return convert(s);
212
- }
213
- };
214
-
215
- template<>
216
- struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
217
- using result_type = Array<half_t, 8>;
218
- using source_type = Array<uint4b_t, 8>;
219
-
220
- CUTLASS_DEVICE
221
- static result_type convert(source_type const& source)
222
- {
223
- result_type result;
224
-
225
- uint32_t* h = reinterpret_cast<uint32_t*>(&result);
226
- uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
227
-
228
- // First, we extract the i4s and construct an intermediate fp16 number.
229
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
230
- static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
231
- static constexpr uint32_t TOP_MASK = 0x00f000f0;
232
- static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
233
-
234
- // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
235
- // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
236
- // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
237
- // elt_67 to fp16 without having to shift them to the bottom bits before hand.
238
-
239
- // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
240
- // immediately before required.
241
- const uint32_t top_i4s = i4s >> 8;
242
- // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
243
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
244
- : "=r"(h[0])
245
- : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
246
- // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
247
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
248
- : "=r"(h[1])
249
- : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
250
- // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
251
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
252
- : "=r"(h[2])
253
- : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
254
- // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
255
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
256
- : "=r"(h[3])
257
- : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
258
-
259
- // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
260
- // half2 ctor. In this case, I chose performance reliability over code readability.
261
-
262
- // This is the half2 {1032, 1032} represented as an integer.
263
- static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
264
- // This is the half2 {1 / 16, 1 / 16} represented as an integer.
265
- static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
266
- // This is the half2 {-72, -72} represented as an integer.
267
- static constexpr uint32_t NEG_72 = 0xd480d480;
268
-
269
- // Finally, we construct the output numbers.
270
- // Convert elt_01
271
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
272
- // Convert elt_23
273
- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
274
- // Convert elt_45
275
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
276
- // Convert elt_67
277
- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
278
-
279
- return result;
280
- }
281
-
282
- CUTLASS_DEVICE
283
- result_type operator()(source_type const& s)
284
- {
285
- return convert(s);
286
- }
287
- };
288
-
289
- template<int N>
290
- struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
291
- static constexpr int VEC_WIDTH = 8;
292
- static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
293
-
294
- using result_type = Array<half_t, N>;
295
- using source_type = Array<uint4b_t, N>;
296
-
297
- CUTLASS_DEVICE
298
- static result_type convert(source_type const& source)
299
- {
300
- using scalar_result_type = typename result_type::Element;
301
- using scalar_source_type = typename source_type::Element;
302
- FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
303
- convert_vector_;
304
-
305
- result_type result;
306
- using vec_result = Array<scalar_result_type, VEC_WIDTH>;
307
- using vec_source = Array<scalar_source_type, VEC_WIDTH>;
308
-
309
- vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
310
- vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
311
-
312
- CUTLASS_PRAGMA_UNROLL
313
- for (int i = 0; i < N / VEC_WIDTH; ++i) {
314
- result_ptr[i] = convert_vector_(source_ptr[i]);
315
- }
316
-
317
- return result;
318
- }
319
-
320
- CUTLASS_DEVICE
321
- result_type operator()(source_type const& s)
322
- {
323
- return convert(s);
324
- }
325
- };
326
-
327
- template<>
328
- struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
329
- using result_type = Array<bfloat16_t, 8>;
330
- using source_type = Array<uint4b_t, 8>;
331
-
332
- CUTLASS_DEVICE
333
- static result_type convert(source_type const& source)
334
- {
335
- result_type result;
336
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
337
-
338
- uint32_t* h = reinterpret_cast<uint32_t*>(&result);
339
- uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
340
-
341
- // First, we extract the i4s and construct an intermediate fp16 number.
342
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
343
- static constexpr uint32_t MASK = 0x000f000f;
344
- static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
345
-
346
- // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
347
- // No shift needed for first item.
348
- uint32_t i4s = source_i4s;
349
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
350
- : "=r"(h[0])
351
- : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
352
- CUTLASS_PRAGMA_UNROLL
353
- for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
354
- i4s >>= sizeof_bits<typename source_type::Element>::value;
355
- // (i4s & 0x000f000f) | 0x43004300
356
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
357
- : "=r"(h[ii])
358
- : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
359
- }
360
-
361
- // This is the BF16 {-136, -136} represented as an integer.
362
- static constexpr uint32_t BF16_BIAS = 0xC308C308;
363
- static constexpr uint32_t BF16_ONE = 0x3F803F80;
364
-
365
- // Finally, we construct the output numbers.
366
- CUTLASS_PRAGMA_UNROLL
367
- for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
368
- // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
369
- asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
370
- }
371
- #else
372
- // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
373
- // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
374
- arch::device_breakpoint();
375
- result.clear(); // Suppress compiler warning.
376
- #endif
377
- return result;
378
- }
379
-
380
- CUTLASS_DEVICE
381
- result_type operator()(source_type const& s)
382
- {
383
- return convert(s);
384
- }
385
- };
386
-
387
- template<int N>
388
- struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
389
- static constexpr int VEC_WIDTH = 8;
390
- static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
391
-
392
- using result_type = Array<bfloat16_t, N>;
393
- using source_type = Array<uint4b_t, N>;
394
-
395
- CUTLASS_DEVICE
396
- static result_type convert(source_type const& source)
397
- {
398
- using scalar_result_type = typename result_type::Element;
399
- using scalar_source_type = typename source_type::Element;
400
- FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
401
- convert_vector_;
402
-
403
- result_type result;
404
- using vec_result = Array<scalar_result_type, VEC_WIDTH>;
405
- using vec_source = Array<scalar_source_type, VEC_WIDTH>;
406
-
407
- vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
408
- vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
409
-
410
- CUTLASS_PRAGMA_UNROLL
411
- for (int i = 0; i < N / VEC_WIDTH; ++i) {
412
- result_ptr[i] = convert_vector_(source_ptr[i]);
413
- }
414
-
415
- return result;
416
- }
417
-
418
- CUTLASS_DEVICE
419
- result_type operator()(source_type const& s)
420
- {
421
- return convert(s);
422
- }
423
- };
424
-
425
- /////////////////////////////////////////////////////////////////////////////////////////////////
426
-
427
- } // namespace cutlass
428
-
429
- /////////////////////////////////////////////////////////////////////////////////////////////////
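The converters deleted above rely on the weights having been stored with a +2^(b-1) bias so they are unsigned in memory; the magic constants (0x64806480 for the fp16 path, 0x4B000000 and 8388736.f for the bf16 path) undo that bias during conversion. The sketch below (plain C++; decode_biased_byte is a hypothetical helper name, not part of the library) reproduces the fp32-base variant of that arithmetic for a single byte so the constants can be checked on the host.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Packing byte u into the mantissa of a float with bit pattern 0x4B000000
    // produces the value 8388608 + u (2^23 plus the byte); subtracting
    // 8388736 = 8388608 + 128 removes both the base and the +128 storage bias.
    float decode_biased_byte(uint8_t u) {
        uint32_t bits = 0x4B000000u | u;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f - 8388736.0f;
    }

    int main() {
        for (int s : {-128, -1, 0, 1, 127}) {
            uint8_t u = static_cast<uint8_t>(s + 128);   // how the weight is stored
            std::printf("stored %3u -> decoded %g\n",
                        static_cast<unsigned>(u), decode_biased_byte(u));
        }
        return 0;
    }
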
cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h DELETED
@@ -1,61 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- /*! \file
32
- \brief Defines new layouts needed for MoE
33
- */
34
- #pragma once
35
-
36
- #include "cutlass/cutlass.h"
37
- #include "cutlass/fast_math.h"
38
- #include "cutlass/matrix_coord.h"
39
- #include "cutlass/pitch_linear_coord.h"
40
-
41
- namespace cutlass {
42
- namespace layout {
43
-
44
- template<int RowsPerTile, int ColumnsInterleaved>
45
- class ColumnMajorTileInterleave {
46
- static constexpr int kRowsPerTile = RowsPerTile;
47
- static constexpr int kColumnsInterleaved = ColumnsInterleaved;
48
- };
49
-
50
- template<class T>
51
- struct IsColumnMajorTileInterleave {
52
- static constexpr bool value = false;
53
- };
54
-
55
- template<int U, int V>
56
- struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
57
- static constexpr bool value = true;
58
- };
59
-
60
- } // namespace layout
61
- } // namespace cutlass
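The trait removed above exists so that kernels can detect at compile time whether B uses an interleaved layout and, if so, expect preprocessed weights. A minimal sketch of that compile-time branch (plain C++17 with stand-in types; describe_layout and the local ColumnMajor are illustrative, not the CUTLASS definitions):

    #include <cstdio>

    template <int RowsPerTile, int ColumnsInterleaved>
    struct ColumnMajorTileInterleave {};   // stand-in for the deleted class

    struct ColumnMajor {};                 // stand-in for a plain layout tag

    template <class T>
    struct IsColumnMajorTileInterleave { static constexpr bool value = false; };

    template <int U, int V>
    struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
        static constexpr bool value = true;
    };

    template <class LayoutB>
    void describe_layout() {
        if constexpr (IsColumnMajorTileInterleave<LayoutB>::value)
            std::printf("B is tile-interleaved; weights must be preprocessed.\n");
        else
            std::printf("B is a plain layout; no interleaving step needed.\n");
    }

    int main() {
        describe_layout<ColumnMajorTileInterleave<64, 2>>();
        describe_layout<ColumnMajor>();
        return 0;
    }
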
cutlass_kernels/cutlass_heuristic.cu DELETED
@@ -1,208 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include "cutlass_heuristic.h"
18
- #include "cutlass/gemm/gemm.h"
19
- #include <cuda_runtime_api.h>
20
-
21
- #include <vector>
22
- #include <stdexcept>
23
-
24
- namespace fastertransformer {
25
-
26
- struct TileShape {
27
- int m;
28
- int n;
29
- };
30
-
31
- TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
32
- {
33
- switch (tile_config) {
34
- case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
35
- return TileShape{32, 128};
36
- case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
37
- case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
38
- return TileShape{64, 128};
39
- case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
40
- case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
41
- case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
42
- return TileShape{128, 128};
43
- default:
44
- throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
45
- }
46
- }
47
-
48
- bool is_valid_split_k_factor(const int64_t m,
49
- const int64_t n,
50
- const int64_t k,
51
- const TileShape tile_shape,
52
- const int split_k_factor,
53
- const size_t workspace_bytes,
54
- const bool is_weight_only)
55
- {
56
-
57
- // All tile sizes have a k_tile of 64.
58
- static constexpr int k_tile = 64;
59
-
60
- // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
61
- if (is_weight_only) {
62
- if ((k % k_tile) != 0) {
63
- return false;
64
- }
65
-
66
- if ((k % split_k_factor) != 0) {
67
- return false;
68
- }
69
-
70
- const int k_elements_per_split = k / split_k_factor;
71
- if ((k_elements_per_split % k_tile) != 0) {
72
- return false;
73
- }
74
- }
75
-
76
- // Check that the workspace has sufficient space for this split-k factor
77
- const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
78
- const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
79
- const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
80
-
81
- if (required_ws_bytes > workspace_bytes) {
82
- return false;
83
- }
84
-
85
- return true;
86
- }
87
-
88
- std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
89
- {
90
-
91
- std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
92
-
93
- std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
94
- CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
95
- CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
96
-
97
- std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
98
- CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
99
- CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
100
-
101
- const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
102
- return simt_configs_only ? simt_configs : allowed_configs;
103
- }
104
-
105
- std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
106
- {
107
- std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
108
-
109
- std::vector<CutlassGemmConfig> candidate_configs;
110
- const int min_stages = 2;
111
- const int max_stages = sm >= 80 ? 4 : 2;
112
-
113
- for (const auto& tile_config : tiles) {
114
- for (int stages = min_stages; stages <= max_stages; ++stages) {
115
- CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
116
- candidate_configs.push_back(config);
117
- }
118
- }
119
-
120
- return candidate_configs;
121
- }
122
-
123
- CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
124
- const std::vector<int>& occupancies,
125
- const int64_t m,
126
- const int64_t n,
127
- const int64_t k,
128
- const int64_t num_experts,
129
- const int split_k_limit,
130
- const size_t workspace_bytes,
131
- const int multi_processor_count,
132
- const int is_weight_only)
133
- {
134
-
135
- if (occupancies.size() != candidate_configs.size()) {
136
- throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and "
137
- "candidate configs vectors must have equal length.");
138
- }
139
-
140
- CutlassGemmConfig best_config;
141
- // Score will be [0, 1]. The objective is to minimize this score.
142
- // It represents the fraction of SM resources unused in the last wave.
143
- float config_score = 1.0f;
144
- int config_waves = INT_MAX;
145
- int current_m_tile = 0;
146
-
147
- const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
148
- for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
149
- CutlassGemmConfig candidate_config = candidate_configs[ii];
150
- TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
151
- int occupancy = occupancies[ii];
152
-
153
- if (occupancy == 0) {
154
- continue;
155
- }
156
-
157
- // Keep small tile sizes when possible.
158
- if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
159
- && current_m_tile < tile_shape.m) {
160
- continue;
161
- }
162
-
163
- const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
164
- const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
165
-
166
- for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
167
- if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
168
- const int ctas_per_wave = occupancy * multi_processor_count;
169
- const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
170
-
171
- const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
172
- const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
173
- const float current_score = float(num_waves_total) - num_waves_fractional;
174
-
175
- const float score_slack = 0.1f;
176
- if (current_score < config_score
177
- || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
178
- config_score = current_score;
179
- config_waves = num_waves_total;
180
- SplitKStyle split_style =
181
- split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
182
- best_config = CutlassGemmConfig{
183
- candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
184
- current_m_tile = tile_shape.m;
185
- }
186
- else if (current_score == config_score
187
- && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
188
- || current_m_tile < tile_shape.m)) {
189
- // Prefer deeper pipeline or smaller split-k
190
- SplitKStyle split_style =
191
- split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
192
- best_config = CutlassGemmConfig{
193
- candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
194
- current_m_tile = tile_shape.m;
195
- config_waves = num_waves_total;
196
- }
197
- }
198
- }
199
- }
200
-
201
- if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
202
- throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config.");
203
- }
204
-
205
- return best_config;
206
- }
207
-
208
- } // namespace fastertransformer
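The heuristic deleted above ranks configurations by wave quantization: the score is the idle fraction of the last wave of CTAs, and smaller is better. A standalone sketch of the same arithmetic (plain C++; wave_quantization_score and the example numbers are illustrative assumptions, not the library's API):

    #include <cstdio>

    float wave_quantization_score(int ctas_for_problem, int occupancy_per_sm, int sm_count) {
        const int ctas_per_wave = occupancy_per_sm * sm_count;
        const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
        const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
        // Fraction of the final wave's CTA slots that go unused.
        return float(num_waves_total) - num_waves_fractional;
    }

    int main() {
        // Example: 300 CTAs on 108 SMs at occupancy 1 is ~2.78 waves,
        // so roughly 0.22 of the last wave's SMs would sit idle.
        std::printf("score = %.3f\n", wave_quantization_score(300, 1, 108));
        return 0;
    }
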
 
cutlass_kernels/cutlass_heuristic.h DELETED
@@ -1,39 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include <vector>
20
- #include <cstddef>
21
- #include <cstdint>
22
- #include "cutlass_extensions/ft_gemm_configs.h"
23
-
24
- namespace fastertransformer {
25
-
26
- std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
27
-
28
- CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
29
- const std::vector<int>& occupancies,
30
- const int64_t m,
31
- const int64_t n,
32
- const int64_t k,
33
- const int64_t num_experts,
34
- const int split_k_limit,
35
- const size_t workspace_bytes,
36
- const int multi_processor_count,
37
- const int is_weight_only);
38
-
39
- } // namespace fastertransformer
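A hedged sketch of how these two declarations are typically combined (the include path and the all-ones occupancy vector are assumptions for illustration; the real runner fills one occupancy value per candidate kernel):

    #include <vector>
    #include "cutlass_kernels/cutlass_heuristic.h"  // assumed include path

    fastertransformer::CutlassGemmConfig pick_config(int sm, int64_t m, int64_t n, int64_t k,
                                                     int multi_processor_count,
                                                     size_t workspace_bytes) {
        using namespace fastertransformer;
        auto configs = get_candidate_configs(sm, /*is_weight_only=*/true, /*simt_configs_only=*/false);
        std::vector<int> occupancies(configs.size(), 1);  // placeholder occupancies
        return estimate_best_config_from_occupancies(configs, occupancies, m, n, k,
                                                     /*num_experts=*/1, /*split_k_limit=*/7,
                                                     workspace_bytes, multi_processor_count,
                                                     /*is_weight_only=*/1);
    }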
 
cutlass_kernels/cutlass_preprocessors.cc DELETED
@@ -1,703 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
- #include "cutlass_preprocessors.h"
17
- #include "cuda_utils.h"
18
- #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
19
-
20
- #include <vector>
21
-
22
- namespace fastertransformer {
23
-
24
- int get_bits_in_quant_type(QuantType quant_type) {
25
- switch (quant_type) {
26
- case QuantType::INT8_WEIGHT_ONLY:
27
- return 8;
28
- case QuantType::PACKED_INT4_WEIGHT_ONLY:
29
- return 4;
30
- default:
31
- return -1;
32
- }
33
- }
34
-
35
- struct LayoutDetails {
36
- enum class Layout {
37
- UNKNOWN,
38
- ROW_MAJOR,
39
- COLUMN_MAJOR
40
- };
41
-
42
- Layout layoutB = Layout::UNKNOWN;
43
- int rows_per_column_tile = 1;
44
- int columns_interleaved = 1;
45
-
46
- bool uses_imma_ldsm = false;
47
- };
48
-
49
- template<typename Layout>
50
- struct getLayoutDetails {
51
- };
52
-
53
- template<>
54
- struct getLayoutDetails<cutlass::layout::RowMajor> {
55
- LayoutDetails operator()()
56
- {
57
- LayoutDetails layout_details;
58
- layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
59
- return layout_details;
60
- }
61
- };
62
-
63
- template<>
64
- struct getLayoutDetails<cutlass::layout::ColumnMajor> {
65
- LayoutDetails operator()()
66
- {
67
- LayoutDetails layout_details;
68
- layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
69
- return layout_details;
70
- }
71
- };
72
-
73
- template<int RowsPerTile, int ColumnsInterleaved>
74
- struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
75
- LayoutDetails operator()()
76
- {
77
- LayoutDetails layout_details;
78
- layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
79
- layout_details.rows_per_column_tile = RowsPerTile;
80
- layout_details.columns_interleaved = ColumnsInterleaved;
81
- return layout_details;
82
- }
83
- };
84
-
85
- template<typename cutlassArch, typename TypeB>
86
- LayoutDetails getLayoutDetailsForArchAndQuantType()
87
- {
88
-
89
- using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
90
- using LayoutB = typename CompileTraits::Layout;
91
- using MmaOperator = typename CompileTraits::Operator;
92
- LayoutDetails details = getLayoutDetails<LayoutB>()();
93
- details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
94
- return details;
95
- }
96
-
97
- template<typename cutlassArch>
98
- LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
99
- {
100
- LayoutDetails details;
101
- if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
102
- details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
103
- }
104
- else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
105
- details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
106
- }
107
- else {
108
- FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
109
- }
110
- return details;
111
- }
112
-
113
- LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
114
- {
115
- if (arch >= 70 && arch < 75) {
116
- return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
117
- }
118
- else if (arch >= 75 && arch < 80) {
119
- return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
120
- }
121
- else if (arch >= 80 && arch < 90) {
122
- return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
123
- }
124
- else {
125
- FT_CHECK_WITH_INFO(false, "Unsupported Arch");
126
- return LayoutDetails();
127
- }
128
- }
129
-
130
- // Permutes the rows of B for Turing and Ampere. Throws an error for other
131
- // architectures. The data is permuted such that: For int8, each group of 16
132
- // rows is permuted using the map below:
133
- // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
134
- // For int4, each group of 32 rows is permuted using the map below:
135
- // 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22
136
- // 23 30 31
137
- void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
138
- const int8_t *quantized_tensor,
139
- const std::vector<size_t> &shape,
140
- QuantType quant_type,
141
- const int64_t arch_version) {
142
- const size_t num_rows = shape[0];
143
- const size_t num_cols = shape[1];
144
-
145
- const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
146
- const int K = 16 / BITS_PER_ELT;
147
- const int ELTS_PER_REG = 32 / BITS_PER_ELT;
148
-
149
- const uint32_t *input_byte_ptr =
150
- reinterpret_cast<const uint32_t *>(quantized_tensor);
151
- uint32_t *output_byte_ptr =
152
- reinterpret_cast<uint32_t *>(permuted_quantized_tensor);
153
-
154
- int MMA_SHAPE_N = 8;
155
- int B_ROWS_PER_MMA = 8 * K;
156
- const int elts_in_int32 = 32 / BITS_PER_ELT;
157
-
158
- const int num_vec_cols = num_cols / elts_in_int32;
159
-
160
- FT_CHECK_WITH_INFO(arch_version >= 75,
161
- "Unsupported Arch. Pre-volta not supported. Column "
162
- "interleave not needed on Volta.");
163
-
164
- FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
165
- fmtstr("Invalid shape for quantized tensor. Number of "
166
- "rows of quantized matrix must be a multiple of %d",
167
- B_ROWS_PER_MMA));
168
-
169
- FT_CHECK_WITH_INFO(
170
- num_cols % MMA_SHAPE_N == 0,
171
- fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number "
172
- "of cols must be a multiple of %d.",
173
- MMA_SHAPE_N));
174
-
175
- // The code is written as below so it works for both int8
176
- // and packed int4.
177
- for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
178
- for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
179
-
180
- for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
181
- const int write_row = base_row + tile_row;
182
- const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
183
- tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
184
- const int read_row = base_row + tile_read_row;
185
- const int read_col = write_col;
186
-
187
- const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
188
- const int64_t write_offset =
189
- int64_t(write_row) * num_vec_cols + write_col;
190
-
191
- output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
192
- }
193
- }
194
- }
195
- }
196
-
197
- // We need to use this transpose to correctly handle packed int4 and int8 data
198
- // The reason this code is relatively complex is that the "trivial" loops took a
199
- // substantial amount of time to transpose, leading to long preprocessing times.
200
- // This seemed to be a big issue for relatively large models.
201
- template <QuantType quant_type>
202
- void subbyte_transpose_impl(int8_t *transposed_quantized_tensor,
203
- const int8_t *quantized_tensor,
204
- const std::vector<size_t> &shape) {
205
- const int bits_per_elt = get_bits_in_quant_type(quant_type);
206
- const size_t num_rows = shape[0];
207
- const size_t num_cols = shape[1];
208
-
209
- const size_t col_bytes = num_cols * bits_per_elt / 8;
210
- const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
211
-
212
- const uint8_t *input_byte_ptr =
213
- reinterpret_cast<const uint8_t *>(quantized_tensor);
214
- uint8_t *output_byte_ptr =
215
- reinterpret_cast<uint8_t *>(transposed_quantized_tensor);
216
-
217
- static constexpr int ELTS_PER_BYTE =
218
- quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
219
-
220
- static constexpr int M_TILE_L1 = 64;
221
- static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
222
- uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
223
-
224
- static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
225
-
226
- // We assume the dims are a multiple of vector width. Our kernels only handle
227
- // dims which are multiples of 64 for weight-only quantization. As a result,
228
- // this seemed like a reasonable tradeoff because it allows GCC to emit vector
229
- // instructions.
230
- FT_CHECK_WITH_INFO(
231
- !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
232
- fmtstr("Number of bytes for rows and cols must be a multiple of %d. "
233
- "However, num_rows_bytes = %ld and num_col_bytes = %d.",
234
- VECTOR_WIDTH, col_bytes_trans, col_bytes));
235
-
236
- for (size_t row_tile_start = 0; row_tile_start < num_rows;
237
- row_tile_start += M_TILE_L1) {
238
- for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
239
- col_tile_start_byte += N_TILE_L1) {
240
-
241
- const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
242
- const int col_limit =
243
- std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
244
-
245
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
246
- const int row = row_tile_start + ii;
247
-
248
- for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
249
- const int col = col_tile_start_byte + jj;
250
-
251
- const size_t logical_src_offset = row * col_bytes + col;
252
-
253
- if (row < row_limit && col < col_limit) {
254
- for (int v = 0; v < VECTOR_WIDTH; ++v) {
255
- cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
256
- }
257
- }
258
- }
259
- }
260
-
261
- if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
262
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
263
- for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
264
- std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
265
- }
266
- }
267
- } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
268
-
269
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
270
- // Using M_TILE_L1 here is deliberate since we assume that the cache
271
- // tile is square in the number of elements (not necessarily the
272
- // number of bytes).
273
- for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
274
- const int ii_byte = ii / ELTS_PER_BYTE;
275
- const int ii_bit_offset = ii % ELTS_PER_BYTE;
276
-
277
- const int jj_byte = jj / ELTS_PER_BYTE;
278
- const int jj_bit_offset = jj % ELTS_PER_BYTE;
279
-
280
- uint8_t src_elt =
281
- 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
282
- uint8_t tgt_elt =
283
- 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
284
-
285
- cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
286
- cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
287
-
288
- cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
289
- cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
290
- }
291
- }
292
- } else {
293
- FT_CHECK_WITH_INFO(false, "Unsupported quantization type.");
294
- }
295
-
296
- const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
297
- const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
298
-
299
- const int row_limit_trans =
300
- std::min(row_tile_start_trans + M_TILE_L1, num_cols);
301
- const int col_limit_trans =
302
- std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
303
-
304
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
305
- const int row = row_tile_start_trans + ii;
306
- for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
307
- const int col = col_tile_start_byte_trans + jj;
308
-
309
- const size_t logical_tgt_offset = row * col_bytes_trans + col;
310
-
311
- if (row < row_limit_trans && col < col_limit_trans) {
312
- for (int v = 0; v < VECTOR_WIDTH; ++v) {
313
- output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
314
- }
315
- }
316
- }
317
- }
318
- }
319
- }
320
- }
321
-
322
- void subbyte_transpose(int8_t *transposed_quantized_tensor,
323
- const int8_t *quantized_tensor,
324
- const std::vector<size_t> &shape, QuantType quant_type) {
325
-
326
- if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
327
- subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(
328
- transposed_quantized_tensor, quantized_tensor, shape);
329
- } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
330
- subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
331
- transposed_quantized_tensor, quantized_tensor, shape);
332
- } else {
333
- FT_CHECK_WITH_INFO(false, "Invalid quant_tye");
334
- }
335
- }
336
-
337
- void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor,
338
- const size_t num_elts) {
339
- for (size_t ii = 0; ii < num_elts; ++ii) {
340
- int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
341
- }
342
-
343
- // Step 2 will transform the layout of a 32-bit register in CUDA in order to
344
- // match the int4 layout. This has no performance benefit and is purely so
345
- // that int4 and int8 have the same layout. Pictorially, this does the
346
- // following: bit 32 0
347
- // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
348
- //
349
- // And it will rearrange the output 32 bit register to be the following:
350
- // bit 32 0
351
- // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
352
-
353
- FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a "
354
- "multiple of 4 for register relayout");
355
- for (size_t base = 0; base < num_elts; base += 4) {
356
- std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
357
- }
358
- }
359
-
360
- void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
361
- const size_t num_elts) {
362
- const size_t num_bytes = num_elts / 2;
363
-
364
- // Step 1 will be to transform all the int4s to unsigned in order to make the
365
- // dequantize take as little instructions as possible in the CUDA code.
366
- for (size_t ii = 0; ii < num_bytes; ++ii) {
367
- int8_t transformed_packed_int4s = 0;
368
- int8_t transformed_first_elt =
369
- (int8_t(packed_int4_tensor[ii] << 4) >> 4) +
370
- 8; // The double shift here is to ensure sign extension
371
- int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;
372
-
373
- FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
374
- transformed_first_elt <= 15,
375
- "Illegal result for int4 transform (first elt)");
376
- FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
377
- transformed_second_elt <= 15,
378
- "Illegal result for int4 transform (second elt)");
379
-
380
- // We don't need to mask in these ops since everything should be in the
381
- // range 0-15
382
- transformed_packed_int4s |= transformed_first_elt;
383
- transformed_packed_int4s |= (transformed_second_elt << 4);
384
- packed_int4_tensor[ii] = transformed_packed_int4s;
385
- }
386
-
387
- // Step 2 will transform the layout of a 32-bit register in CUDA in order to
388
- // minimize the number of shift & logical instructions That are needed to
389
- // extract the int4s in the GEMM main loop. Pictorially, the loop below will
390
- // do the following: Take as input a 32 bit register with layout: bit 32 0
391
- // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt
392
- // occupies 4 bits)
393
- //
394
- // And it will rearrange the output 32 bit register to be the following:
395
- // bit 32 0
396
- // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt
397
- // occupies 4 bits)
398
-
399
- FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
400
- "multiple of 8 for register relayout");
401
- const size_t num_registers = num_bytes / 4;
402
-
403
- uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
404
- for (size_t ii = 0; ii < num_registers; ++ii) {
405
- const uint32_t current_register = register_ptr[ii];
406
- uint32_t transformed_register = 0;
407
-
408
- for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
409
- const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
410
- const int src_shift = 4 * src_idx;
411
- const int dest_shift = 4 * dest_idx;
412
-
413
- const uint32_t src_bits = (current_register >> src_shift) & 0xF;
414
- transformed_register |= (src_bits << dest_shift);
415
- }
416
- register_ptr[ii] = transformed_register;
417
- }
418
- }
419
-
420
- void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor,
421
- const size_t num_elts,
422
- QuantType quant_type) {
423
- if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
424
- add_bias_and_interleave_int8s_inplace(tensor, num_elts);
425
- } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
426
- add_bias_and_interleave_int4s_inplace(tensor, num_elts);
427
- } else {
428
- FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving.");
429
- }
430
- }
431
-
432
- void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor,
433
- const int8_t *quantized_tensor,
434
- const std::vector<size_t> &shape,
435
- QuantType quant_type,
436
- LayoutDetails details) {
437
- // We only want to run this step for weight only quant.
438
- FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY ||
439
- quant_type == QuantType::INT8_WEIGHT_ONLY);
440
- FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
441
-
442
- const size_t num_rows = shape[0];
443
- const size_t num_cols = shape[1];
444
-
445
- const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
446
- const int elts_in_int32 = 32 / BITS_PER_ELT;
447
-
448
- const int rows_per_tile = details.rows_per_column_tile;
449
-
450
- FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
451
- fmtstr("The number of rows must be a multiple of %d but "
452
- "the number of rows is %d.",
453
- elts_in_int32, num_rows));
454
-
455
- FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
456
- fmtstr("The number of columns must be a multiple of %d "
457
- "but the number of columns is %ld",
458
- rows_per_tile, num_cols));
459
-
460
- const uint32_t *input_byte_ptr =
461
- reinterpret_cast<const uint32_t *>(quantized_tensor);
462
- uint32_t *output_byte_ptr =
463
- reinterpret_cast<uint32_t *>(interleaved_quantized_tensor);
464
-
465
- FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
466
- fmtstr("The number of columns must be a multiple of %d "
467
- "but the number of columns is %d.",
468
- rows_per_tile, num_cols));
469
-
470
- const int num_vec_rows = num_rows / elts_in_int32;
471
- const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
472
- const int interleave = details.columns_interleaved;
473
-
474
- for (size_t read_col = 0; read_col < num_cols; ++read_col) {
475
- const auto write_col = read_col / interleave;
476
- for (int base_vec_row = 0; base_vec_row < num_vec_rows;
477
- base_vec_row += vec_rows_per_tile) {
478
- for (int vec_read_row = base_vec_row;
479
- vec_read_row <
480
- std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
481
- ++vec_read_row) {
482
- const int64_t vec_write_row =
483
- interleave * base_vec_row +
484
- vec_rows_per_tile * (read_col % interleave) +
485
- vec_read_row % vec_rows_per_tile;
486
-
487
- const int64_t read_offset =
488
- int64_t(read_col) * num_vec_rows + vec_read_row;
489
- const int64_t write_offset =
490
- int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
491
- output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
492
- }
493
- }
494
- }
495
- }
496
-
497
- void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
498
- const int8_t *row_major_quantized_weight,
499
- const std::vector<size_t> &shape,
500
- QuantType quant_type, int arch) {
501
- LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
502
-
503
- FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
504
-
505
- size_t num_elts = 1;
506
- for (const auto &dim : shape) {
507
- num_elts *= dim;
508
- }
509
-
510
- const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
511
-
512
- std::vector<int8_t> src_buf(num_bytes);
513
- std::vector<int8_t> dst_buf(num_bytes);
514
- std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());
515
-
516
- // Works on row major data, so issue this permutation first.
517
- if (details.uses_imma_ldsm) {
518
- permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
519
- src_buf.swap(dst_buf);
520
- }
521
-
522
- if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
523
- subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
524
- src_buf.swap(dst_buf);
525
- }
526
-
527
- if (details.columns_interleaved > 1) {
528
- interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
529
- src_buf.swap(dst_buf);
530
- }
531
-
532
- add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
533
- std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
534
- }
535
-
536
- void preprocess_weights(int8_t *preprocessed_quantized_weight,
537
- const int8_t *row_major_quantized_weight, size_t rows,
538
- size_t cols, bool is_int4, int arch) {
539
- QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY
540
- : QuantType::INT8_WEIGHT_ONLY;
541
- preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight,
542
- row_major_quantized_weight, {rows, cols},
543
- qtype, arch);
544
- }
545
-
546
- /*
547
- Arguments:
548
- input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16.
549
-
550
- quant_type - the type of the output quantization weight.
551
-
552
- This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the
553
- zero-point is zero and will automatically construct the scales.
554
-
555
- It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is
556
- viewed as a stack of matrices and a scale is produced for each column of every matrix.
557
-
558
- Outputs
559
- processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM
560
- unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking.
561
- scale_ptr - scales for the quantized weight.
562
-
563
- Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data
564
- layout may not make sense if printed.
565
-
566
- Shapes:
567
- quant_type == int8:
568
- If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n]
569
- If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n]
570
- quant_type == int4:
571
- If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n]
572
- If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape
573
- [b,n]
574
-
575
- The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the
576
- reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind
577
- of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors
578
- must have a dimension of 1, which breaks the semantics we need for batched weights.
579
- */
580
-
581
- template<typename ComputeType, typename WeightType>
582
- void symmetric_quantize(int8_t* processed_quantized_weight,
583
- int8_t* unprocessed_quantized_weight,
584
- ComputeType* scale_ptr,
585
- const WeightType* input_weight_ptr,
586
- const std::vector<size_t>& shape,
587
- QuantType quant_type)
588
- {
589
-
590
- FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
591
- FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
592
- FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");
593
-
594
- FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
595
- const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
596
- const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
597
- const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
598
-
599
- const int bits_in_type = get_bits_in_quant_type(quant_type);
600
- const int bytes_per_out_col = num_cols * bits_in_type / 8;
601
-
602
- std::vector<int8_t> weight_buf;
603
- if (unprocessed_quantized_weight == nullptr) {
604
- weight_buf.resize(num_experts * num_rows * num_cols);
605
- unprocessed_quantized_weight = weight_buf.data();
606
- }
607
-
608
- const int input_mat_size = num_rows * num_cols;
609
- const int quantized_mat_size = num_rows * bytes_per_out_col;
610
- const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
611
-
612
- std::vector<float> per_col_max(num_cols);
613
-
614
- for (int expert = 0; expert < num_experts; ++expert) {
615
- const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
616
- int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
617
-
618
- // First we find the per column max for this expert weight.
619
- for (int jj = 0; jj < num_cols; ++jj) {
620
- per_col_max[jj] = 0.f;
621
- }
622
-
623
- for (int ii = 0; ii < num_rows; ++ii) {
624
- const WeightType* current_weight_row = current_weight + ii * num_cols;
625
- for (int jj = 0; jj < num_cols; ++jj) {
626
- per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
627
- }
628
- }
629
-
630
- // Then, we construct the scales
631
- ComputeType* current_scales = scale_ptr + expert * num_cols;
632
- for (int jj = 0; jj < num_cols; ++jj) {
633
- per_col_max[jj] *= quant_range_scale;
634
- current_scales[jj] = ComputeType(per_col_max[jj]);
635
- }
636
-
637
- // Finally, construct the weights.
638
- for (int ii = 0; ii < num_rows; ++ii) {
639
- int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
640
- const WeightType* current_weight_row = current_weight + ii * num_cols;
641
- for (int jj = 0; jj < bytes_per_out_col; ++jj) {
642
-
643
- if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
644
- const float col_scale = per_col_max[jj];
645
- const float weight_elt = float(current_weight_row[jj]);
646
- const float scaled_weight = round(weight_elt / col_scale);
647
- const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
648
- current_quantized_weight_row[jj] = clipped_weight;
649
- }
650
- else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
651
-
652
- // We will pack two int4 elements per iteration of the inner loop.
653
- int8_t packed_int4s = 0;
654
- for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
655
- const int input_idx = 2 * jj + packed_idx;
656
- if (input_idx < num_cols) {
657
- const float col_scale = per_col_max[input_idx];
658
- const float weight_elt = float(current_weight_row[input_idx]);
659
- const float scaled_weight = round(weight_elt / col_scale);
660
- int int_weight = int(scaled_weight);
661
- const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
662
-
663
- // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
664
- // if packing the second int4 and or the bits into the final result.
665
- packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
666
- }
667
- }
668
- current_quantized_weight_row[jj] = packed_int4s;
669
- }
670
- else {
671
- FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
672
- }
673
- }
674
- }
675
- }
676
- const int arch = fastertransformer::getSMVersion();
677
- preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
678
- }
679
-
680
- template void
681
- symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
682
-
683
- template void
684
- symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
685
-
686
-
687
- template<typename ComputeType, typename WeightType>
688
- void symmetric_quantize(int8_t* processed_quantized_weight,
689
- ComputeType* scale_ptr,
690
- const WeightType* input_weight_ptr,
691
- const std::vector<size_t>& shape,
692
- QuantType quant_type)
693
- {
694
- symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
695
- }
696
-
697
- template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
698
-
699
- template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
700
-
701
- template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
702
-
703
- } // namespace fastertransformer
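The quantization above reduces to a per-column symmetric scheme: scale_j = max_i |W_ij| / 2^(bits-1) and q_ij = clamp(round(W_ij / scale_j)), with the zero-point fixed at zero. A minimal float-to-int8 sketch of only that step, without the packing or CUTLASS layout preprocessing (names here are illustrative, not from this file):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    // Hedged sketch: symmetric per-column int8 quantization of a row-major
    // num_rows x num_cols float matrix. `q` holds num_rows * num_cols int8
    // values, `scales` holds one scale per column.
    void quantize_int8_per_column(const float* w, std::size_t num_rows, std::size_t num_cols,
                                  std::int8_t* q, float* scales) {
        for (std::size_t j = 0; j < num_cols; ++j) scales[j] = 0.f;
        for (std::size_t i = 0; i < num_rows; ++i)
            for (std::size_t j = 0; j < num_cols; ++j)
                scales[j] = std::max(scales[j], std::fabs(w[i * num_cols + j]));
        for (std::size_t j = 0; j < num_cols; ++j)
            scales[j] /= 128.f;  // full int8 range, zero-point assumed zero
        for (std::size_t i = 0; i < num_rows; ++i)
            for (std::size_t j = 0; j < num_cols; ++j) {
                const float v = scales[j] > 0.f ? std::round(w[i * num_cols + j] / scales[j]) : 0.f;
                q[i * num_cols + j] = std::int8_t(std::max(-128.f, std::min(127.f, v)));
            }
    }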
 
cutlass_kernels/cutlass_preprocessors.h DELETED
@@ -1,33 +0,0 @@
1
- #pragma once
2
- #pragma GCC diagnostic ignored "-Wstrict-aliasing"
3
-
4
- #include <cstddef>
5
- #include <cstdint>
6
- #include <vector>
7
-
8
- namespace fastertransformer {
9
-
10
- enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY };
11
-
12
- int get_bits_in_quant_type(QuantType quant_type);
13
-
14
- void preprocess_weights(int8_t *preprocessed_quantized_weight,
15
- const int8_t *row_major_quantized_weight, size_t rows,
16
- size_t cols, bool is_int4, int arch);
17
-
18
- template<typename ComputeType, typename WeightType>
19
- void symmetric_quantize(int8_t* processed_quantized_weight,
20
- ComputeType* scale_ptr,
21
- const WeightType* input_weight_ptr,
22
- const std::vector<size_t>& shape,
23
- QuantType quant_type);
24
-
25
-
26
- template<typename ComputeType, typename WeightType>
27
- void symmetric_quantize(int8_t* processed_quantized_weight,
28
- int8_t* unprocessed_quantized_weight,
29
- ComputeType* scale_ptr,
30
- const WeightType* input_weight_ptr,
31
- const std::vector<size_t>& shape,
32
- QuantType quant_type);
33
- } // namespace fastertransformer
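A hedged sketch of driving the single-call form of symmetric_quantize declared above for packed int4 weights (the include path, the use of CUDA's half type, and the buffer sizing are assumptions based on the shape rules documented in cutlass_preprocessors.cc):

    #include <cuda_fp16.h>
    #include <vector>
    #include "cutlass_kernels/cutlass_preprocessors.h"  // assumed include path

    // Sketch only: quantize a row-major [rows, cols] fp16 weight to packed int4
    // and lay it out for the mixed-type GEMM in one call. Packed int4 needs
    // rows * ceil(cols / 2) bytes and one scale per column.
    void prepare_int4_weight(const half* weight_fp16, size_t rows, size_t cols,
                             std::vector<int8_t>& processed, std::vector<half>& scales) {
        processed.resize(rows * ((cols + 1) / 2));
        scales.resize(cols);
        fastertransformer::symmetric_quantize<half, half>(
            processed.data(), scales.data(), weight_fp16, {rows, cols},
            fastertransformer::QuantType::PACKED_INT4_WEIGHT_ONLY);
    }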
 
cutlass_kernels/fpA_intB_gemm.cu DELETED
@@ -1,99 +0,0 @@
1
- #include "fpA_intB_gemm.h"
2
- #include "fpA_intB_gemm/fpA_intB_gemm_template.h"
3
-
4
- namespace fastertransformer
5
- {
6
-
7
- ActivationType get_activation(const std::string &activation_name)
8
- {
9
- if (activation_name == "identity")
10
- return ActivationType::Identity;
11
- if (activation_name == "relu")
12
- return ActivationType::Relu;
13
- if (activation_name == "silu")
14
- return ActivationType::Silu;
15
- if (activation_name == "gelu")
16
- return ActivationType::Gelu;
17
- // todo: more
18
- return ActivationType::InvalidType;
19
- }
20
-
21
- void gemm_fp16_int(const half *A,
22
- const uint8_t *B,
23
- const half *weight_scales,
24
- half *C,
25
- int m, int n, int k,
26
- char *workspace_ptr,
27
- size_t workspace_bytes,
28
- cudaStream_t stream)
29
- {
30
- CutlassFpAIntBGemmRunner<half, uint8_t> runner;
31
- runner.gemm(A, B, weight_scales,
32
- C, m, n, k, workspace_ptr, workspace_bytes, stream);
33
- }
34
-
35
- template <typename WeightType>
36
- void gemm_fp16_int_bias_act(const half *A,
37
- const WeightType *B,
38
- const half *weight_scales,
39
- const half *bias,
40
- half *C,
41
- std::optional<std::string> activation,
42
- int m, int n, int k, int bias_stride, char *workspace_ptr,
43
- size_t workspace_bytes, cudaStream_t stream)
44
- {
45
- CutlassFpAIntBGemmRunner<half, WeightType> runner;
46
-
47
- if (!activation && bias == nullptr)
48
- {
49
- runner.gemm(A, B, weight_scales,
50
- C, m, n, k, workspace_ptr, workspace_bytes, stream);
51
- }
52
- else if (!activation)
53
- {
54
- runner.gemm_bias_act(A, B, weight_scales, bias,
55
- C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
56
- }
57
- else
58
- {
59
- runner.gemm_bias_act(A, B, weight_scales, bias,
60
- C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
61
- }
62
- }
63
-
64
- template <typename WeightType>
65
- void gemm_fp16_int_bias_act_residual(
66
- const half *A, const WeightType *B, const half *weight_scales,
67
- const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
68
- const std::string &unary_op, int m, int n,
69
- int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
70
- {
71
- CutlassFpAIntBGemmRunner<half, WeightType> runner;
72
-
73
- runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
74
- C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
75
- }
76
-
77
- template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
78
- const half *weight_scales, const half *bias,
79
- half *C, std::optional<std::string> activation, int m,
80
- int n, int k, int bias_stride, char *workspace_ptr,
81
- size_t workspace_bytes, cudaStream_t stream);
82
-
83
- template void gemm_fp16_int_bias_act_residual<uint4b_t>(
84
- const half *A, const uint4b_t *B, const half *weight_scales,
85
- const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
86
- const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
87
-
88
- template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
89
- const half *weight_scales, const half *bias,
90
- half *C, std::optional<std::string> activation, int m,
91
- int n, int k, int bias_stride, char *workspace_ptr,
92
- size_t workspace_bytes, cudaStream_t stream);
93
-
94
- template void gemm_fp16_int_bias_act_residual<uint8_t>(
95
- const half *A, const uint8_t *B, const half *weight_scales,
96
- const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
97
- const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
98
-
99
- } // namespace fastertransformer
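A hedged call sketch for the fused bias + activation wrapper above (all pointers are assumed to be device pointers, with B already preprocessed into the mixed-type layout; bias_stride = 0 is an assumption meaning a single broadcast bias row):

    #include <string>
    #include "cutlass_kernels/fpA_intB_gemm.h"  // assumed include path

    using fastertransformer::half;  // cutlass::half_t in this header

    // Sketch only: C[m, n] = gelu(A[m, k] * dequant(B) + bias).
    void run_int8_gemm_bias_gelu(const half* A, const uint8_t* B, const half* scales,
                                 const half* bias, half* C, int m, int n, int k,
                                 char* workspace, size_t workspace_bytes, cudaStream_t stream) {
        fastertransformer::gemm_fp16_int_bias_act<uint8_t>(
            A, B, scales, bias, C, std::string("gelu"), m, n, k,
            /*bias_stride=*/0, workspace, workspace_bytes, stream);
    }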
 
cutlass_kernels/fpA_intB_gemm.h DELETED
@@ -1,36 +0,0 @@
1
- #pragma once
2
-
3
- #include <string>
4
- #include <optional>
5
-
6
- #include <cuda_runtime.h>
7
- #include "cutlass/numeric_types.h"
8
- #include "cutlass/half.h"
9
- #include "cutlass/integer_subbyte.h"
10
-
11
- namespace fastertransformer {
12
-
13
- using half = cutlass::half_t;
14
- using uint4b_t = cutlass::uint4b_t;
15
-
16
- // TODO: Support more general bias shape
17
-
18
- // base gemm
19
- void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales,
20
- half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
21
-
22
- template <typename WeightType>
23
- void gemm_fp16_int_bias_act(const half *A, const WeightType *B,
24
- const half *weight_scales, const half *bias,
25
- half *C, std::optional<std::string> activation, int m,
26
- int n, int k, int bias_stride, char *workspace_ptr,
27
- size_t workspace_bytes, cudaStream_t stream);
28
-
29
- template <typename WeightType>
30
- void gemm_fp16_int_bias_act_residual(
31
- const half *A, const WeightType *B, const half *weight_scales,
32
- const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
33
- const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
34
-
35
-
36
- } // namespace fastertransformer
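For the residual-fused entry point declared above, a hedged call sketch; the activation/binary_op/unary_op strings below are placeholder values, since the supported set is defined by the epilogue helpers in cutlass_extensions:

    #include <string>
    #include "cutlass_kernels/fpA_intB_gemm.h"  // assumed include path

    using fastertransformer::half;

    // Sketch only: C = unary_op(binary_op(act(A * dequant(B) + bias), residual)).
    void run_int8_gemm_residual(const half* A, const uint8_t* B, const half* scales,
                                const half* bias, const half* residual, half* C,
                                int m, int n, int k, char* workspace, size_t workspace_bytes,
                                cudaStream_t stream) {
        fastertransformer::gemm_fp16_int_bias_act_residual<uint8_t>(
            A, B, scales, bias, residual, C, "gelu", "plus", "identity",
            m, n, k, workspace, workspace_bytes, stream);
    }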
 
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h DELETED
@@ -1,118 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"
20
- #include "utils/activation_types.h"
21
- #include <cuda_runtime_api.h>
22
-
23
- namespace fastertransformer {
24
-
25
- /*
26
- This runner only supports:
27
- T in {half, __nv_bfloat16}; WeightType in {int8_t, cutlass::uint4b_t}
28
-
29
- Activations, biases, scales and outputs are all assumed to be row-major.
30
-
31
- However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.
32
- In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor
33
- will instantiate the layout and preprocess based on the instantiation, so layout changes should only require
34
- modifications to mix_gemm_B_layout.h.
35
- */
36
-
37
- template<typename T, typename WeightType>
38
- class CutlassFpAIntBGemmRunner {
39
- public:
40
- CutlassFpAIntBGemmRunner();
41
- ~CutlassFpAIntBGemmRunner();
42
-
43
- void gemm(const T* A,
44
- const WeightType* B,
45
- const T* weight_scales,
46
- T* C,
47
- int m,
48
- int n,
49
- int k,
50
- char* workspace_ptr,
51
- const size_t workspace_bytes,
52
- cudaStream_t stream);
53
-
54
- void gemm_bias_act(const T* A,
55
- const WeightType* B,
56
- const T* weight_scales,
57
- const T* biases,
58
- T* C,
59
- int m,
60
- int n,
61
- int k,
62
- int bias_stride,
63
- ActivationType activation_type,
64
- char* workspace_ptr,
65
- const size_t workspace_bytes,
66
- cudaStream_t stream);
67
-
68
- void gemm_bias_act_residual(const T *A, const WeightType *B,
69
- const T *weight_scales, const T *biases,
70
- const T *residual, T *C, int m, int n, int k,
71
- const std::string& activation, const std::string& binary_op,
72
- const std::string& unary_op,
73
- char *workspace_ptr,
74
- const size_t workspace_bytes,
75
- cudaStream_t stream);
76
-
77
- // Returns desired workspace size in bytes.
78
- int getWorkspaceSize(const int m, const int n, const int k);
79
-
80
- private:
81
- template<typename EpilogueTag>
82
- void dispatch_to_arch(const T* A,
83
- const WeightType* B,
84
- const T* weight_scales,
85
- const T* biases,
86
- T* C,
87
- int m,
88
- int n,
89
- int k,
90
- int bias_stride,
91
- CutlassGemmConfig gemm_config,
92
- char* workspace_ptr,
93
- const size_t workspace_bytes,
94
- cudaStream_t stream,
95
- int* occupancy = nullptr);
96
-
97
- template<typename EpilogueTag>
98
- void run_gemm(const T* A,
99
- const WeightType* B,
100
- const T* weight_scales,
101
- const T* biases,
102
- T* C,
103
- int m,
104
- int n,
105
- int k,
106
- int bias_stride,
107
- char* workspace_ptr,
108
- const size_t workspace_bytes,
109
- cudaStream_t stream);
110
-
111
- private:
112
- static constexpr int split_k_limit = 7;
113
-
114
- int sm_;
115
- int multi_processor_count_;
116
- };
117
-
118
- } // namespace fastertransformer
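A hedged sketch of using the runner class above directly, including the workspace query (cudaMalloc/cudaFree are for illustration only; a real caller would reuse a preallocated workspace and check return codes):

    #include <cuda_runtime_api.h>
    #include "cutlass/half.h"
    #include "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"  // assumed include path

    // Sketch only: plain GEMM through the runner with T = cutlass::half_t and
    // int8 (uint8_t) weights, matching the instantiation used in fpA_intB_gemm.cu.
    void run_with_runner(const cutlass::half_t* A, const uint8_t* B,
                         const cutlass::half_t* scales, cutlass::half_t* C,
                         int m, int n, int k, cudaStream_t stream) {
        fastertransformer::CutlassFpAIntBGemmRunner<cutlass::half_t, uint8_t> runner;
        const int workspace_bytes = runner.getWorkspaceSize(m, n, k);
        char* workspace = nullptr;
        cudaMalloc(&workspace, workspace_bytes);
        runner.gemm(A, B, scales, C, m, n, k, workspace,
                    static_cast<size_t>(workspace_bytes), stream);
        cudaFree(workspace);
    }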
 
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h DELETED
@@ -1,858 +0,0 @@
1
- /*
2
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma GCC diagnostic push
18
- #pragma GCC diagnostic ignored "-Wstrict-aliasing"
19
-
20
- #include "cutlass/gemm/device/gemm_universal_base.h"
21
- #include "cutlass/gemm/kernel/default_gemm.h"
22
- #include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
23
- #include "cutlass/epilogue/thread/linear_combination_residual_block.h"
24
- #include "cutlass_extensions/compute_occupancy.h"
25
-
26
- #include "cutlass_extensions/epilogue_helpers.h"
27
- #include "cutlass_extensions/ft_gemm_configs.h"
28
- #include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
29
- #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h"
30
- #include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h"
31
- #include "cutlass_extensions/gemm/threadblock/default_mma.h"
32
-
33
- #pragma GCC diagnostic pop
34
-
35
- #include "../cutlass_heuristic.h"
36
- #include "fpA_intB_gemm.h"
37
- #include "cuda_utils.h"
38
-
39
- namespace fastertransformer {
40
-
41
- template <typename T,
42
- typename WeightType,
43
- typename arch,
44
- typename EpilogueTag,
45
- typename ThreadblockShape,
46
- typename WarpShape,
47
- int Stages>
48
- void generic_mixed_gemm_kernelLauncher(const T *A,
49
- const WeightType *B,
50
- const T *weight_scales,
51
- const T *biases,
52
- T *C,
53
- int m,
54
- int n,
55
- int k,
56
- int bias_stride,
57
- CutlassGemmConfig gemm_config,
58
- char *workspace,
59
- size_t workspace_bytes,
60
- cudaStream_t stream,
61
- int *occupancy = nullptr)
62
- {
63
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
64
- static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
65
- "Specialized for half, float");
66
-
67
- static_assert(cutlass::platform::is_same<T, WeightType>::value || cutlass::platform::is_same<WeightType, uint8_t>::value || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
68
- "");
69
-
70
- // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
71
- using ElementType_ =
72
- typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
73
- using ElementType = ElementType_;
74
-
75
- using CutlassWeightType_ = typename cutlass::platform::
76
- conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, WeightType>::type;
77
- using CutlassWeightType = CutlassWeightType_;
78
-
79
- // We need separate config for each architecture since we will target different tensorcore instructions. For float,
80
- // we do not target TCs.
81
- using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
82
- using ElementAccumulator = typename MixedGemmArchTraits::AccType;
83
-
84
- using EpilogueOp =
85
- typename Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
86
-
87
- using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
88
- ElementType,
89
- cutlass::layout::RowMajor,
90
- MixedGemmArchTraits::ElementsPerAccessA,
91
- CutlassWeightType,
92
- typename MixedGemmArchTraits::LayoutB,
93
- MixedGemmArchTraits::ElementsPerAccessB,
94
- ElementType,
95
- cutlass::layout::RowMajor,
96
- ElementAccumulator,
97
- cutlass::arch::OpClassTensorOp,
98
- arch,
99
- ThreadblockShape,
100
- WarpShape,
101
- typename MixedGemmArchTraits::InstructionShape,
102
- EpilogueOp,
103
- typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
104
- Stages,
105
- true,
106
- typename MixedGemmArchTraits::Operator>::GemmKernel;
107
-
108
- using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<typename GemmKernel_::Mma,
109
- typename GemmKernel_::Epilogue,
110
- typename GemmKernel_::ThreadblockSwizzle,
111
- arch, // Ensure top level arch is used for dispatch
112
- GemmKernel_::kSplitKSerial>;
113
-
114
- if (occupancy != nullptr)
115
- {
116
- *occupancy = compute_occupancy_for_kernel<GemmKernel>();
117
- return;
118
- }
119
-
120
- using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
121
-
122
- const int ldb =
123
- cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value ? n : k * GemmKernel::kInterleave;
124
-
125
- typename Gemm::Arguments args({m, n, k},
126
- {reinterpret_cast<ElementType *>(const_cast<T *>(A)), k},
127
- {reinterpret_cast<CutlassWeightType *>(const_cast<WeightType *>(B)), ldb},
128
- {reinterpret_cast<ElementType *>(const_cast<T *>(weight_scales)), 0},
129
- // TODO: Support more general bias shape
130
- {reinterpret_cast<ElementType *>(const_cast<T *>(biases)), bias_stride},
131
- {reinterpret_cast<ElementType *>(C), n},
132
- gemm_config.split_k_factor,
133
- {ElementAccumulator(1.f), ElementAccumulator(0.f)});
134
-
135
- // This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
136
- // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
137
- // interleaved matrix. The way masking is handled in these does not map to the interleaved layout. We need to write
138
- // our own predicated iterator in order to relax this limitation.
139
- if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK)))
140
- {
141
- throw std::runtime_error("Temp assertion: k must be multiple of threadblockK");
142
- }
143
-
144
- Gemm gemm;
145
- if (gemm.get_workspace_size(args) > workspace_bytes)
146
- {
147
- FT_LOG_WARNING(
148
- "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
149
- // If requested split-k factor will require more workspace bytes, revert to standard gemm.
150
- args.batch_count = 1;
151
- }
152
-
153
- auto can_implement = gemm.can_implement(args);
154
- if (can_implement != cutlass::Status::kSuccess)
155
- {
156
- std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement));
157
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
158
- }
159
-
160
- auto init_status = gemm.initialize(args, workspace, stream);
161
- if (init_status != cutlass::Status::kSuccess)
162
- {
163
- std::string err_msg =
164
- "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
165
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
166
- }
167
-
168
- auto run_status = gemm.run(stream);
169
- if (run_status != cutlass::Status::kSuccess)
170
- {
171
- std::string err_msg =
172
- "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
173
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
174
- }
175
- }
176
-
177
- template<typename T,
178
- typename WeightType,
179
- typename arch,
180
- typename EpilogueTag,
181
- typename ThreadblockShape,
182
- typename WarpShape,
183
- int Stages,
184
- typename Enable = void>
185
- struct dispatch_stages {
186
- static void dispatch(const T *A,
187
- const WeightType *B,
188
- const T *weight_scales,
189
- const T *biases,
190
- T *C,
191
- int m,
192
- int n,
193
- int k,
194
- int bias_stride,
195
- CutlassGemmConfig gemm_config,
196
- char *workspace,
197
- size_t workspace_bytes,
198
- cudaStream_t stream,
199
- int *occupancy = nullptr)
200
- {
201
-
202
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
203
- std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
204
- throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg);
205
- }
206
- };
207
-
208
- template<typename T,
209
- typename WeightType,
210
- typename arch,
211
- typename EpilogueTag,
212
- typename ThreadblockShape,
213
- typename WarpShape>
214
- struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2> {
215
- static void dispatch(const T *A,
216
- const WeightType *B,
217
- const T *weight_scales,
218
- const T *biases,
219
- T *C,
220
- int m,
221
- int n,
222
- int k,
223
- int bias_stride,
224
- CutlassGemmConfig gemm_config,
225
- char *workspace,
226
- size_t workspace_bytes,
227
- cudaStream_t stream,
228
- int *occupancy = nullptr)
229
- {
230
-
231
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
232
- generic_mixed_gemm_kernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(
233
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
234
- }
235
- };
236
-
237
- template<typename T,
238
- typename WeightType,
239
- typename EpilogueTag,
240
- typename ThreadblockShape,
241
- typename WarpShape,
242
- int Stages>
243
- struct dispatch_stages<T,
244
- WeightType,
245
- cutlass::arch::Sm80,
246
- EpilogueTag,
247
- ThreadblockShape,
248
- WarpShape,
249
- Stages,
250
- typename std::enable_if<(Stages > 2)>::type> {
251
- static void dispatch(const T *A,
252
- const WeightType *B,
253
- const T *weight_scales,
254
- const T *biases,
255
- T *C,
256
- int m,
257
- int n,
258
- int k,
259
- int bias_stride,
260
- CutlassGemmConfig gemm_config,
261
- char *workspace,
262
- size_t workspace_bytes,
263
- cudaStream_t stream,
264
- int *occupancy = nullptr)
265
- {
266
-
267
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
268
- generic_mixed_gemm_kernelLauncher<T,
269
- WeightType,
270
- cutlass::arch::Sm80,
271
- EpilogueTag,
272
- ThreadblockShape,
273
- WarpShape,
274
- Stages>(
275
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
276
- }
277
- };
278
-
279
- template <typename T,
280
- typename WeightType,
281
- typename arch,
282
- typename EpilogueTag,
283
- typename ThreadblockShape,
284
- typename WarpShape>
285
- void dispatch_gemm_config(const T *A,
286
- const WeightType *B,
287
- const T *weight_scales,
288
- const T *biases,
289
- T *C,
290
- int m,
291
- int n,
292
- int k,
293
- int bias_stride,
294
- CutlassGemmConfig gemm_config,
295
- char *workspace,
296
- size_t workspace_bytes,
297
- cudaStream_t stream,
298
- int *occupancy = nullptr)
299
- {
300
-
301
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
302
- switch (gemm_config.stages) {
303
- case 2:
304
- using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
305
- DispatcherStages2::dispatch(
306
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
307
- break;
308
- case 3:
309
- using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
310
- DispatcherStages3::dispatch(
311
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
312
- break;
313
- case 4:
314
- using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
315
- DispatcherStages4::dispatch(
316
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
317
- break;
318
- default:
319
- std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
320
- throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg);
321
- break;
322
- }
323
- }
324
-
325
- template <typename T, typename WeightType, typename arch, typename EpilogueTag>
326
- void dispatch_gemm_to_cutlass(const T *A,
327
- const WeightType *B,
328
- const T *weight_scales,
329
- const T *biases,
330
- T *C,
331
- int m,
332
- int n,
333
- int k,
334
- int bias_stride,
335
- char *workspace,
336
- size_t workspace_bytes,
337
- CutlassGemmConfig gemm_config,
338
- cudaStream_t stream,
339
- int *occupancy = nullptr)
340
- {
341
-
342
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
343
-
344
- // Note that SIMT configs are omitted here since they are not supported for fpA_intB.
345
- // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
346
- // for mixed type gemms.
347
- switch (gemm_config.tile_config) {
348
- case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
349
- dispatch_gemm_config<T,
350
- WeightType,
351
- arch,
352
- EpilogueTag,
353
- cutlass::gemm::GemmShape<32, 128, 64>,
354
- cutlass::gemm::GemmShape<32, 32, 64>>(
355
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
356
- break;
357
- case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
358
- dispatch_gemm_config<T,
359
- WeightType,
360
- arch,
361
- EpilogueTag,
362
- cutlass::gemm::GemmShape<64, 128, 64>,
363
- cutlass::gemm::GemmShape<64, 32, 64>>(
364
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
365
- break;
366
- case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
367
- dispatch_gemm_config<T,
368
- WeightType,
369
- arch,
370
- EpilogueTag,
371
- cutlass::gemm::GemmShape<128, 128, 64>,
372
- cutlass::gemm::GemmShape<128, 32, 64>>(
373
- A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
374
- break;
375
- case CutlassTileConfig::Undefined:
376
- throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
377
- break;
378
- case CutlassTileConfig::ChooseWithHeuristic:
379
- throw std::runtime_error(
380
- "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic.");
381
- break;
382
- default:
383
- throw std::runtime_error(
384
- "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
385
- break;
386
- }
387
- }
388
-
389
- template<typename T, typename WeightType>
390
- CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner()
391
- {
392
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
393
- int device{-1};
394
- check_cuda_error(cudaGetDevice(&device));
395
- sm_ = getSMVersion();
396
- check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
397
- }
398
-
399
- template<typename T, typename WeightType>
400
- CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner()
401
- {
402
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
403
- }
404
-
405
- template<typename T, typename WeightType>
406
- template<typename EpilogueTag>
407
- void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(const T* A,
408
- const WeightType* B,
409
- const T* weight_scales,
410
- const T* biases,
411
- T* C,
412
- int m,
413
- int n,
414
- int k,
415
- int bias_stride,
416
- CutlassGemmConfig gemm_config,
417
- char* workspace_ptr,
418
- const size_t workspace_bytes,
419
- cudaStream_t stream,
420
- int* occupancy)
421
- {
422
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
423
- if (sm_ >= 70 && sm_ < 75) {
424
- dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
425
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
426
- } else if (sm_ >= 75 && sm_ < 80) {
427
- dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
428
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
429
- } else if (sm_ >= 80 && sm_ < 90) {
430
- dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
431
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
432
- }
433
- else {
434
- throw std::runtime_error(
435
- "[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM");
436
- }
437
- }
438
-
439
- template<typename T, typename WeightType>
440
- template<typename EpilogueTag>
441
- void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(const T* A,
442
- const WeightType* B,
443
- const T* weight_scales,
444
- const T* biases,
445
- T* C,
446
- int m,
447
- int n,
448
- int k,
449
- int bias_stride,
450
- char* workspace_ptr,
451
- const size_t workspace_bytes,
452
- cudaStream_t stream)
453
- {
454
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
455
- static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
456
- std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, false);
457
- std::vector<int> occupancies(candidate_configs.size());
458
-
459
- for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
460
- dispatch_to_arch<EpilogueTag>(A,
461
- B,
462
- weight_scales,
463
- biases,
464
- C,
465
- m,
466
- n,
467
- k,
468
- bias_stride,
469
- candidate_configs[ii],
470
- workspace_ptr,
471
- workspace_bytes,
472
- stream,
473
- &occupancies[ii]);
474
- }
475
- // Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN.
476
- static constexpr int num_experts = 1;
477
- CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs,
478
- occupancies,
479
- m,
480
- n,
481
- k,
482
- num_experts,
483
- split_k_limit,
484
- workspace_bytes,
485
- multi_processor_count_,
486
- is_weight_only);
487
-
488
- dispatch_to_arch<EpilogueTag>(
489
- A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream);
490
- }
491
-
492
- template <typename T, typename WeightType>
493
- void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(const T *A,
494
- const WeightType *B,
495
- const T *weight_scales,
496
- const T *biases,
497
- T *C,
498
- int m,
499
- int n,
500
- int k,
501
- int bias_stride,
502
- ActivationType activation_type,
503
- char *workspace_ptr,
504
- const size_t workspace_bytes,
505
- cudaStream_t stream)
506
- {
507
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
508
-
509
- switch (activation_type) {
510
- case ActivationType::Relu:
511
- run_gemm<EpilogueOpBiasReLU>(
512
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
513
- break;
514
- case ActivationType::Gelu:
515
- run_gemm<EpilogueOpBiasFtGelu>(
516
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
517
- break;
518
- case ActivationType::Silu:
519
- run_gemm<EpilogueOpBiasSilu>(
520
- A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
521
- break;
522
- case ActivationType::Identity:
523
- run_gemm<EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
524
- break;
525
- case ActivationType::InvalidType:
526
- FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid.");
527
- break;
528
- default: {
529
- if (isGatedActivation(activation_type)) {
530
- FT_CHECK_WITH_INFO(false, "Fused gated activations not supported");
531
- }
532
- else {
533
- FT_CHECK_WITH_INFO(false, "Invalid activation type.");
534
- }
535
- }
536
- }
537
- }
538
-
539
- template<typename T, typename WeightType>
540
- void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
541
- const WeightType* B,
542
- const T* weight_scales,
543
- T* C,
544
- int m,
545
- int n,
546
- int k,
547
- char* workspace_ptr,
548
- const size_t workspace_bytes,
549
- cudaStream_t stream)
550
- {
551
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
552
- run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream);
553
- }
554
-
555
- template <typename T, typename WeightType, typename Arch,
556
- typename ThreadblockShape, typename WarpShape, typename EpilogueOp,
557
- int stages>
558
- void dispatch_gemm_residual(const T *A, const WeightType *B,
559
- const T *weight_scales, const T *biases,
560
- const T *residual, T *C, int m, int n, int k,
561
- char *workspace_ptr, const size_t workspace_bytes,
562
- cudaStream_t stream) {
563
- using ElementType = typename cutlass::platform::conditional<
564
- cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
565
- using ElementOutput = ElementType;
566
-
567
- using MixedGemmArchTraits =
568
- cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, WeightType, Arch>;
569
- using ElementAccumulator = typename EpilogueOp::ElementAccumulator;
570
-
571
- using Swizzle =
572
- typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
573
- using InstructionShape = typename MixedGemmArchTraits::InstructionShape;
574
-
575
- using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
576
- ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone,
577
- MixedGemmArchTraits::ElementsPerAccessA, WeightType,
578
- typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
579
- MixedGemmArchTraits::ElementsPerAccessB, ElementType,
580
- cutlass::layout::RowMajor, ElementAccumulator,
581
- cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
582
- InstructionShape, EpilogueOp, Swizzle, stages,
583
- typename MixedGemmArchTraits::Operator>::Epilogue;
584
-
585
- using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
586
- ElementType, cutlass::layout::RowMajor,
587
- MixedGemmArchTraits::ElementsPerAccessA, WeightType,
588
- typename MixedGemmArchTraits::LayoutB,
589
- MixedGemmArchTraits::ElementsPerAccessB, ElementType,
590
- cutlass::layout::RowMajor, ElementAccumulator,
591
- cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
592
- InstructionShape, EpilogueOp, Swizzle, stages, true,
593
- typename MixedGemmArchTraits::Operator>::GemmKernel;
594
-
595
- using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast<
596
- typename GemmKernel_::Mma, Epilogue,
597
- typename GemmKernel_::ThreadblockSwizzle, Arch>;
598
-
599
- using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
600
-
601
- // TODO: Support batch
602
- const int batch_count = 1;
603
- const auto lda = k;
604
- const int ldb =
605
- cutlass::platform::is_same<cutlass::layout::RowMajor,
606
- typename MixedGemmArchTraits::LayoutB>::value
607
- ? n
608
- : k * GemmKernel::kInterleave;
609
- const int ldc = n;
610
-
611
- typename Gemm::Arguments args(
612
- {m, n, k}, batch_count,
613
- {ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales,
614
- residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0);
615
-
616
- if (GemmKernel::kInterleave > 1 &&
617
- ((k % MixedGemmArchTraits::ThreadblockK) ||
618
- (k % MixedGemmArchTraits::ThreadblockK))) {
619
- throw std::runtime_error(
620
- "Temp assertion: k must be multiple of threadblockK");
621
- }
622
-
623
- Gemm gemm;
624
- auto can_implement = gemm.can_implement(args);
625
- if (can_implement != cutlass::Status::kSuccess) {
626
- std::string err_msg =
627
- "fpA_intB cutlass kernel will fail for params. Error: " +
628
- std::string(cutlassGetStatusString(can_implement));
629
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
630
- }
631
-
632
- auto init_status = gemm.initialize(args, workspace_ptr, stream);
633
- if (init_status != cutlass::Status::kSuccess) {
634
- std::string err_msg =
635
- "Failed to initialize cutlass fpA_intB gemm. Error: " +
636
- std::string(cutlassGetStatusString(init_status));
637
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
638
- }
639
-
640
- auto run_status = gemm.run(stream);
641
- if (run_status != cutlass::Status::kSuccess) {
642
- std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " +
643
- std::string(cutlassGetStatusString(run_status));
644
- throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
645
- }
646
- }
647
-
648
- template <typename T, typename WeightType, typename Arch, typename EpilogueOp,
649
- int stages>
650
- void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A,
651
- const WeightType *B, const T *weight_scales,
652
- const T *biases, const T *residual, T *C, int m,
653
- int n, int k, char *workspace_ptr,
654
- const size_t workspace_bytes, cudaStream_t stream) {
655
- if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) {
656
- dispatch_gemm_residual<
657
- T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>,
658
- cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>(
659
- A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
660
- workspace_bytes, stream);
661
- } else if (tile_config ==
662
- CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) {
663
- dispatch_gemm_residual<
664
- T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>,
665
- cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>(
666
- A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
667
- workspace_bytes, stream);
668
- } else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
669
- dispatch_gemm_residual<
670
- T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>,
671
- cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>(
672
- A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
673
- workspace_bytes, stream);
674
- }
675
- }
676
-
677
- template <typename T, typename WeightType, typename Arch, typename EpilogueOp>
678
- void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
679
- const WeightType *B, const T *weight_scales,
680
- const T *biases, const T *residual, T *C, int m,
681
- int n, int k, char *workspace_ptr,
682
- const size_t workspace_bytes, cudaStream_t stream) {
683
- if constexpr (std::is_same<Arch, cutlass::arch::Sm75>::value) {
684
- dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75, EpilogueOp, 2>(
685
- config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
686
- workspace_ptr, workspace_bytes, stream);
687
- } else if constexpr (std::is_same<Arch, cutlass::arch::Sm70>::value) {
688
- dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70, EpilogueOp, 2>(
689
- config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
690
- workspace_ptr, workspace_bytes, stream);
691
- } else {
692
- if (config.stages == 3) {
693
- dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 3>(
694
- config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
695
- workspace_ptr, workspace_bytes, stream);
696
- } else if (config.stages == 4) {
697
- dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 4>(
698
- config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
699
- workspace_ptr, workspace_bytes, stream);
700
- } else { // 2
701
- dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 2>(
702
- config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
703
- workspace_ptr, workspace_bytes, stream);
704
- }
705
- }
706
- }
707
-
708
- template <typename T, typename WeightType, typename Arch,
709
- template <typename T_> class ActivationOp,
710
- template <typename T_> class BinaryOp>
711
- inline void
712
- dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
713
- const WeightType *B, const T *weight_scales,
714
- const T *biases, const T *residual, T *C, int m, int n,
715
- int k, const std::string &unary_op, char *workspace_ptr,
716
- const size_t workspace_bytes, cudaStream_t stream) {
717
- using ElementOutput = T;
718
- using MixedGemmArchTraits =
719
- cutlass::gemm::kernel::MixedGemmArchTraits<T, WeightType, Arch>;
720
- using ElementAccumulator = typename MixedGemmArchTraits::AccType;
721
-
722
- if (unary_op == "identity") {
723
- using EpilogueOp =
724
- cutlass::epilogue::thread::LinearCombinationResidualBlock<
725
- ElementOutput, ElementAccumulator, ElementAccumulator,
726
- ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
727
- ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>;
728
- dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
729
- config, A, B, weight_scales, biases, residual, C, m, n, k,
730
- workspace_ptr, workspace_bytes, stream);
731
- } else if (unary_op == "relu") {
732
- using EpilogueOp =
733
- cutlass::epilogue::thread::LinearCombinationResidualBlock<
734
- ElementOutput, ElementAccumulator, ElementAccumulator,
735
- ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
736
- ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>;
737
- dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
738
- config, A, B, weight_scales, biases, residual, C, m, n, k,
739
- workspace_ptr, workspace_bytes, stream);
740
- } else {
741
- throw std::runtime_error(
742
- "[FT Error][Unsupported unary op after residual block] " + unary_op);
743
- }
744
- }
745
-
746
- template <typename T, typename WeightType, typename Arch,
747
- template <typename T_> class ActivationOp>
748
- void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
749
- const WeightType *B, const T *weight_scales,
750
- const T *biases, const T *residual, T *C, int m,
751
- int n, int k, const std::string &binary_op,
752
- const std::string &unary_op, char *workspace_ptr,
753
- const size_t workspace_bytes, cudaStream_t stream) {
754
- if (binary_op == "plus") {
755
- dispatch_gemm_residual<T, WeightType, Arch, ActivationOp, cutlass::plus>(
756
- config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
757
- workspace_ptr, workspace_bytes, stream);
758
- } else if (binary_op == "multiply") {
759
- dispatch_gemm_residual<T, WeightType, Arch, ActivationOp,
760
- cutlass::multiplies>(
761
- config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
762
- workspace_ptr, workspace_bytes, stream);
763
- } else {
764
- throw std::runtime_error(
765
- "[FT Error][Unsupported binary op for residual block] " + binary_op);
766
- }
767
- }
768
-
769
- template <typename T, typename WeightType, typename Arch>
770
- void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
771
- const WeightType *B, const T *weight_scales,
772
- const T *biases, const T *residual, T *C, int m,
773
- int n, int k, const std::string &activation,
774
- const std::string &binary_op,
775
- const std::string &unary_op, char *workspace_ptr,
776
- const size_t workspace_bytes, cudaStream_t stream) {
777
- if (activation == "identity") {
778
- dispatch_gemm_residual<T, WeightType, Arch,
779
- cutlass::epilogue::thread::Identity>(
780
- config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
781
- unary_op, workspace_ptr, workspace_bytes, stream);
782
- } else if (activation == "silu") {
783
- dispatch_gemm_residual<T, WeightType, Arch,
784
- cutlass::epilogue::thread::SiLu>(
785
- config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
786
- unary_op, workspace_ptr, workspace_bytes, stream);
787
- } else if (activation == "relu") {
788
- dispatch_gemm_residual<T, WeightType, Arch,
789
- cutlass::epilogue::thread::ReLu>(
790
- config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
791
- unary_op, workspace_ptr, workspace_bytes, stream);
792
- } else if (activation == "gelu") {
793
- dispatch_gemm_residual<T, WeightType, Arch,
794
- cutlass::epilogue::thread::GELU>(
795
- config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
796
- unary_op, workspace_ptr, workspace_bytes, stream);
797
- } else {
798
- throw std::runtime_error(
799
- "[FT Error][Unsupported activation before residual binary op] " +
800
- activation);
801
- }
802
- }
803
-
804
- template <typename T, typename WeightType>
805
- void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act_residual(
806
- const T *A, const WeightType *B, const T *weight_scales, const T *biases,
807
- const T *residual, T *C, int m, int n, int k, const std::string &activation,
808
- const std::string &binary_op, const std::string &unary_op,
809
- char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
810
-
811
- std::vector<CutlassGemmConfig> candidate_configs =
812
- get_candidate_configs(sm_, true, false);
813
- std::vector<int> occupancies(candidate_configs.size());
814
-
815
- for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
816
- dispatch_to_arch<EpilogueOpNoBias>(
817
- A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii],
818
- workspace_ptr, workspace_bytes, stream, &occupancies[ii]);
819
- }
820
-
821
- CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(
822
- candidate_configs, occupancies, m, n, k, 1, split_k_limit,
823
- workspace_bytes, multi_processor_count_, true);
824
-
825
- if (sm_ >= 80 && sm_ < 90) {
826
- dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm80>(
827
- chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
828
- activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
829
- stream);
830
- } else if (sm_ >= 75 && sm_ < 80) {
831
- dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75>(
832
- chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
833
- activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
834
- stream);
835
- } else if (sm_ == 70) {
836
- dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70>(
837
- chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
838
- activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
839
- stream);
840
- } else {
841
- throw std::runtime_error("[FT Error][Unsupported SM] " + std::to_string(sm_));
842
- }
843
- }
844
-
845
- template<typename T, typename WeightType>
846
- int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
847
- {
848
- FT_LOG_DEBUG(__PRETTY_FUNCTION__);
849
- // TODO(masahi): Shouldn't it be 0?
850
-
851
- // These are the min tile sizes for each config, which would launch the maximum number of blocks
852
- const int max_grid_m = (m + 31) / 32;
853
- const int max_grid_n = (n + 127) / 128;
854
- // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
855
- return max_grid_m * max_grid_n * split_k_limit * 4;
856
- }
857
-
858
- } // namespace fastertransformer
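getWorkspaceSize above bounds the split-k scratch for the smallest candidate tile (CTA M = 32, N = 128). A minimal Python sketch of that arithmetic, where split_k_limit = 7 is an assumption taken from the standalone getWorkspaceSize in fpA_intB_gemm_wrapper.cu:

# Sketch of CutlassFpAIntBGemmRunner::getWorkspaceSize; split_k_limit = 7 is an
# assumption borrowed from the free getWorkspaceSize in fpA_intB_gemm_wrapper.cu.
def workspace_bytes(m: int, n: int, split_k_limit: int = 7) -> int:
    max_grid_m = (m + 31) // 32     # grid extent for the smallest CTA tile M (32)
    max_grid_n = (n + 127) // 128   # grid extent for the smallest CTA tile N (128)
    return max_grid_m * max_grid_n * split_k_limit * 4  # 4 bytes per block, split-k in z

# Example: m = 4096, n = 11008 -> 128 * 86 * 7 * 4 = 308224 bytes (301 KiB).
print(workspace_bytes(4096, 11008))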
 
cutlass_kernels/fpA_intB_gemm_wrapper.cu DELETED
@@ -1,201 +0,0 @@
1
- #include <torch/all.h>
2
- #include "cub/cub.cuh"
3
- #include <cuda_runtime.h>
4
- #include <cuda_fp16.h>
5
- #include <c10/cuda/CUDAGuard.h>
6
- #include "fpA_intB_gemm_wrapper.h"
7
- #include "fpA_intB_gemm.h"
8
- #include "cutlass_preprocessors.h"
9
- #include "cuda_utils.h"
10
- #include "weightOnlyBatchedGemv/enabled.h"
11
- #include "weightOnlyBatchedGemv/kernelLauncher.h"
12
- #include "torch_utils.h"
13
-
14
- #include <vector>
15
-
16
- namespace ft = fastertransformer;
17
-
18
- int getWorkspaceSize(const int m, const int n, const int k)
19
- {
20
- // These are the min tile sizes for each config, which would launch the maximum number of blocks
21
- const int max_grid_m = (m + 31) / 32;
22
- const int max_grid_n = (n + 127) / 128;
23
- const int split_k_limit = 7;
24
- // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
25
- return max_grid_m * max_grid_n * split_k_limit * 4;
26
- }
27
-
28
- std::vector<torch::Tensor>
29
- symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
30
- at::ScalarType quant_type,
31
- bool return_unprocessed_quantized_tensor)
32
- {
33
- CHECK_CPU(weight);
34
- CHECK_CONTIGUOUS(weight);
35
- TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
36
- TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");
37
-
38
- auto _st = weight.scalar_type();
39
- TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
40
- TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
41
- ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);
42
-
43
- const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
44
- const size_t num_rows = weight.size(-2);
45
- const size_t num_cols = weight.size(-1);
46
-
47
- const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
48
- const size_t bytes_per_out_col = num_cols * bits_in_type / 8;
49
-
50
- const size_t input_mat_size = num_rows * num_cols;
51
- const size_t quantized_mat_size = num_rows * bytes_per_out_col;
52
-
53
- std::vector<long int> quantized_weight_shape;
54
- std::vector<long int> scale_shape;
55
- if (weight.dim() == 2) {
56
- quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
57
- scale_shape = {long(num_cols)};
58
- }
59
- else if (weight.dim() == 3) {
60
- quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
61
- scale_shape = {long(num_experts), long(num_cols)};
62
- }
63
- else {
64
- TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
65
- }
66
-
67
- torch::Tensor unprocessed_quantized_weight =
68
- torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
69
-
70
- torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
71
-
72
- torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));
73
-
74
- int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
75
- int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());
76
-
77
- if (weight.scalar_type() == at::ScalarType::Float)
78
- {
79
- ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
80
- unprocessed_quantized_weight_ptr,
81
- reinterpret_cast<float *>(scales.data_ptr()),
82
- reinterpret_cast<const float *>(weight.data_ptr()),
83
- {num_rows, num_cols},
84
- ft_quant_type);
85
- }
86
- else if (weight.scalar_type() == at::ScalarType::Half)
87
- {
88
- ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
89
- unprocessed_quantized_weight_ptr,
90
- reinterpret_cast<half *>(scales.data_ptr()),
91
- reinterpret_cast<const half *>(weight.data_ptr()),
92
- {num_rows, num_cols},
93
- ft_quant_type);
94
- }
95
- else
96
- {
97
- TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
98
- }
99
-
100
- if (return_unprocessed_quantized_tensor)
101
- {
102
- return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
103
- }
104
-
105
- return std::vector<torch::Tensor>{processed_quantized_weight, scales};
106
- }
107
-
108
- torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
109
- bool is_int4)
110
- {
111
- // guarantee the weight is cpu tensor
112
- CHECK_CPU(origin_weight);
113
-
114
- torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
115
- int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
116
- const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
117
- size_t rows = origin_weight.size(-2);
118
- size_t cols = origin_weight.size(-1);
119
- int arch = ft::getSMVersion();
120
- ft::preprocess_weights(preprocessed_quantized_weight_ptr,
121
- row_major_quantized_weight_ptr,
122
- rows,
123
- cols,
124
- is_int4,
125
- arch);
126
- return preprocessed_quantized_weight;
127
- }
128
-
129
- torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
130
- torch::Tensor const &weight,
131
- torch::Tensor const &scale)
132
- {
133
- c10::cuda::CUDAGuard device_guard(input.device());
134
- // TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());
135
- const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
136
- const int k = input.size(-1);
137
- const int n = weight.size(-1);
138
- auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
139
- torch::Tensor output = input.dim() == 2 ? torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options);
140
- const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
141
- const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
142
- const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
143
- ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
144
- // const int max_size = std::max(n, k);
145
- // size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
146
- // void *ptr = nullptr;
147
- // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;
148
- const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
149
- // const bool use_cuda_kernel = false;
150
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
151
-
152
- if(use_cuda_kernel){
153
- tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
154
- tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
155
- tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast<const uint8_t *>(scale.data_ptr()), nullptr,
156
- reinterpret_cast<half *>(input.data_ptr()), nullptr, nullptr, reinterpret_cast<half *>(output.data_ptr()), m, n, k, 0, weight_only_quant_type,
157
- tensorrt_llm::kernels::WeightOnlyType::PerChannel,
158
- tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
159
- tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
160
- }
161
- else
162
- ft::gemm_fp16_int(
163
- input_ptr,
164
- weight_ptr,
165
- scale_ptr,
166
- output_ptr,
167
- m, n, k,
168
- nullptr,
169
- 0,
170
- stream);
171
- return output;
172
- }
173
-
174
-
175
- torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
176
- torch::Tensor const &weight,
177
- torch::Tensor const &scale,
178
- torch::Tensor &output,
179
- const int64_t m,
180
- const int64_t n,
181
- const int64_t k)
182
- {
183
- c10::cuda::CUDAGuard device_guard(input.device());
184
-
185
- const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
186
- const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
187
- const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
188
- ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
189
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
190
-
191
- ft::gemm_fp16_int(
192
- input_ptr,
193
- weight_ptr,
194
- scale_ptr,
195
- output_ptr,
196
- m, n, k,
197
- nullptr,
198
- 0,
199
- stream);
200
- return output;
201
- }
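w8_a16_gemm_forward_cuda above flattens the leading input dimensions into m, takes k from the innermost input dimension and n from the weight, and routes small batches (m <= SMALL_M_FAST_PATH, defined as 4 in the header below) to the batched-GEMV kernel instead of the CUTLASS GEMM. A minimal Python sketch of just that shape and dispatch logic, with no kernel launches:

SMALL_M_FAST_PATH = 4  # from fpA_intB_gemm_wrapper.h

def plan_w8_a16_gemm(input_shape, weight_shape):
    # m is the product of all leading input dims, k the innermost one; n comes from the weight.
    *batch_dims, k = input_shape
    m = 1
    for d in batch_dims:
        m *= d
    n = weight_shape[-1]
    return {
        "m": m, "n": n, "k": k,
        "output_shape": (*batch_dims, n),
        "use_batched_gemv": m <= SMALL_M_FAST_PATH,  # small-m fast path
    }

# Example: a (1, 3, 4096) activation against a 4096 x 11008 int8 weight takes the GEMV path.
print(plan_w8_a16_gemm((1, 3, 4096), (4096, 11008)))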
 
cutlass_kernels/fpA_intB_gemm_wrapper.h DELETED
@@ -1,23 +0,0 @@
1
- #include <torch/all.h>
2
- #include <vector>
3
-
4
- #define SMALL_M_FAST_PATH 4
5
- std::vector<torch::Tensor>
6
- symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
7
- at::ScalarType quant_type,
8
- bool return_unprocessed_quantized_tensor);
9
-
10
- torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
11
- bool is_int4);
12
-
13
- torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
14
- torch::Tensor const &weight,
15
- torch::Tensor const &scale);
16
-
17
- torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
18
- torch::Tensor const &weight,
19
- torch::Tensor const &scale,
20
- torch::Tensor &output,
21
- const int64_t m,
22
- const int64_t n,
23
- const int64_t k);
 
flake.lock DELETED
@@ -1,169 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1733328505,
21
- "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1753354560,
77
- "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1753602110,
102
- "narHash": "sha256-AEt6rSqYqSTgsKZ+2BuGezurpVC2gm+Jpjqg2D54n7E=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "2021ea0f8d9e63ada986f189d077fc301cc1c3c9",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "ref": "torch-2.8",
111
- "repo": "kernel-builder",
112
- "type": "github"
113
- }
114
- },
115
- "nixpkgs": {
116
- "locked": {
117
- "lastModified": 1752785354,
118
- "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
119
- "owner": "nixos",
120
- "repo": "nixpkgs",
121
- "rev": "d38025438a6ee456758dc03188ca6873a415463b",
122
- "type": "github"
123
- },
124
- "original": {
125
- "owner": "nixos",
126
- "repo": "nixpkgs",
127
- "rev": "d38025438a6ee456758dc03188ca6873a415463b",
128
- "type": "github"
129
- }
130
- },
131
- "root": {
132
- "inputs": {
133
- "kernel-builder": "kernel-builder"
134
- }
135
- },
136
- "systems": {
137
- "locked": {
138
- "lastModified": 1681028828,
139
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
140
- "owner": "nix-systems",
141
- "repo": "default",
142
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
143
- "type": "github"
144
- },
145
- "original": {
146
- "owner": "nix-systems",
147
- "repo": "default",
148
- "type": "github"
149
- }
150
- },
151
- "systems_2": {
152
- "locked": {
153
- "lastModified": 1681028828,
154
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
155
- "owner": "nix-systems",
156
- "repo": "default",
157
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
158
- "type": "github"
159
- },
160
- "original": {
161
- "owner": "nix-systems",
162
- "repo": "default",
163
- "type": "github"
164
- }
165
- }
166
- },
167
- "root": "root",
168
- "version": 7
169
- }
 
flake.nix DELETED
@@ -1,17 +0,0 @@
1
- {
2
- description = "Flake for EETQ kernels";
3
-
4
- inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
6
- };
7
-
8
- outputs =
9
- {
10
- self,
11
- kernel-builder,
12
- }:
13
- kernel-builder.lib.genFlakeOutputs {
14
- path = ./.;
15
- rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
- };
17
- }
 
torch-ext/quantization_eetq/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights
2
-
3
- __all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"]
 
torch-ext/quantization_eetq/custom_ops.py DELETED
@@ -1,36 +0,0 @@
1
- from typing import List
2
- import torch
3
-
4
- from ._ops import ops
5
-
6
-
7
- def w8_a16_gemm(
8
- input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
9
- ) -> torch.Tensor:
10
- return ops.w8_a16_gemm(input, weight, scale)
11
-
12
-
13
- def w8_a16_gemm_(
14
- input: torch.Tensor,
15
- weight: torch.Tensor,
16
- scale: torch.Tensor,
17
- output: torch.Tensor,
18
- m: int,
19
- n: int,
20
- k: int,
21
- ) -> torch.Tensor:
22
- return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
23
-
24
-
25
- def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
26
- return ops.preprocess_weights(origin_weight, is_int4)
27
-
28
-
29
- def quant_weights(
30
- origin_weight: torch.Tensor,
31
- quant_type: torch.dtype,
32
- return_unprocessed_quantized_tensor: bool,
33
- ) -> List[torch.Tensor]:
34
- return ops.quant_weights(
35
- origin_weight, quant_type, return_unprocessed_quantized_tensor
36
- )
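A hedged end-to-end usage sketch for these wrappers, assuming the built package is importable as quantization_eetq and a CUDA device is present; the shapes and the direct hand-off from quant_weights to w8_a16_gemm are illustrative:

import torch
from quantization_eetq import quant_weights, w8_a16_gemm

# Symmetric int8 quantization of an fp16 weight with per-column scales.
# quant_weights expects a CPU tensor; the last axis is the output (n) dimension.
weight = torch.randn(4096, 11008, dtype=torch.float16)        # [k, n], CPU
qweight, scales = quant_weights(weight, torch.int8, False)

# fp16 activations x int8 weight -> fp16 output on the GPU.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")  # [m, k]
out = w8_a16_gemm(x, qweight.cuda(), scales.cuda())           # [m, n]
print(out.shape)                                              # torch.Size([8, 11008])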
 
torch-ext/torch_binding.cpp DELETED
@@ -1,19 +0,0 @@
1
- #include <torch/library.h>
2
-
3
- #include "registration.h"
4
- #include "torch_binding.h"
5
-
6
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
- ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor");
8
- ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda);
9
- ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! output,"
10
- "int m, int n, int k) -> Tensor");
11
- ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_);
12
- ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor");
13
- ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
14
- ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
15
- "bool return_unprocessed_quantized_tensor) -> Tensor[]");
16
- ops.impl("quant_weights", torch::kCPU, &symmetric_quantize_last_axis_of_tensor);
17
- }
18
-
19
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
torch-ext/torch_binding.h DELETED
@@ -1,25 +0,0 @@
1
- #pragma once
2
-
3
- #include <vector>
4
-
5
- #include <torch/torch.h>
6
-
7
- std::vector<torch::Tensor>
8
- symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
9
- at::ScalarType quant_type,
10
- bool return_unprocessed_quantized_tensor);
11
-
12
- torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
13
- bool is_int4);
14
-
15
- torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
16
- torch::Tensor const&weight,
17
- torch::Tensor const &scale);
18
-
19
- torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
20
- torch::Tensor const &weight,
21
- torch::Tensor const &scale,
22
- torch::Tensor &output,
23
- const int64_t m,
24
- const int64_t n,
25
- const int64_t k);
 
utils/activation_types.h DELETED
@@ -1,40 +0,0 @@
1
- /*
2
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include "cuda_utils.h"
20
-
21
- namespace fastertransformer {
22
-
23
- enum class ActivationType {
24
- Gelu,
25
- Relu,
26
- Silu,
27
- GeGLU,
28
- ReGLU,
29
- SiGLU,
30
- Identity,
31
- InvalidType
32
- };
33
-
34
- inline bool isGatedActivation(ActivationType activation_type)
35
- {
36
- return activation_type == ActivationType::GeGLU || activation_type == ActivationType::ReGLU
37
- || activation_type == ActivationType::SiGLU;
38
- }
39
-
40
- } // namespace fastertransformer
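A hedged Python mirror of this enum and the gated check, handy when deciding which activations gemm_bias_act can fuse (gated variants are rejected there):

from enum import Enum, auto

class ActivationType(Enum):
    Gelu = auto()
    Relu = auto()
    Silu = auto()
    GeGLU = auto()
    ReGLU = auto()
    SiGLU = auto()
    Identity = auto()
    InvalidType = auto()

def is_gated_activation(t: ActivationType) -> bool:
    # Mirrors isGatedActivation(): only the *GLU variants are gated.
    return t in (ActivationType.GeGLU, ActivationType.ReGLU, ActivationType.SiGLU)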
 
utils/cuda_utils.cc DELETED
@@ -1,55 +0,0 @@
1
- /*
2
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include "cuda_utils.h"
18
-
19
- namespace fastertransformer {
20
-
21
- /* ***************************** common utils ****************************** */
22
-
23
- cudaError_t getSetDevice(int i_device, int* o_device)
24
- {
25
- int current_dev_id = 0;
26
- cudaError_t err = cudaSuccess;
27
-
28
- if (o_device != NULL) {
29
- err = cudaGetDevice(&current_dev_id);
30
- if (err != cudaSuccess) {
31
- return err;
32
- }
33
- if (current_dev_id == i_device) {
34
- *o_device = i_device;
35
- }
36
- else {
37
- err = cudaSetDevice(i_device);
38
- if (err != cudaSuccess) {
39
- return err;
40
- }
41
- *o_device = current_dev_id;
42
- }
43
- }
44
- else {
45
- err = cudaSetDevice(i_device);
46
- if (err != cudaSuccess) {
47
- return err;
48
- }
49
- }
50
-
51
- return cudaSuccess;
52
- }
53
-
54
- /* ************************** end of common utils ************************** */
55
- } // namespace fastertransformer
 
utils/cuda_utils.h DELETED
@@ -1,76 +0,0 @@
1
- /*
2
- * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include "logger.h"
20
-
21
- #include <cuda_runtime.h>
22
- #include <fstream>
23
- #include <iostream>
24
- #include <string>
25
- #include <vector>
26
-
27
- namespace fastertransformer {
28
- /* **************************** debug tools ********************************* */
29
- template<typename T>
30
- void check(T result, char const* const func, const char* const file, int const line)
31
- {
32
- if (result) {
33
- throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("<unknown>") + " "
34
- + file + ":" + std::to_string(line) + " \n");
35
- }
36
- }
37
-
38
- #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
39
-
40
- [[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
41
- {
42
- throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
43
- + std::to_string(line) + " \n");
44
- }
45
-
46
- inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "")
47
- {
48
- if (!result) {
49
- throwRuntimeError(file, line, info);
50
- }
51
- }
52
-
53
- #define FT_CHECK(val) myAssert(val, __FILE__, __LINE__)
54
- #define FT_CHECK_WITH_INFO(val, info) \
55
- do { \
56
- bool is_valid_val = (val); \
57
- if (!is_valid_val) { \
58
- fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \
59
- } \
60
- } while (0)
61
-
62
- /* ***************************** common utils ****************************** */
63
- inline int getSMVersion()
64
- {
65
- int device{-1};
66
- check_cuda_error(cudaGetDevice(&device));
67
- int sm_major = 0;
68
- int sm_minor = 0;
69
- check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
70
- check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
71
- return sm_major * 10 + sm_minor;
72
- }
73
-
74
- cudaError_t getSetDevice(int i_device, int* o_device = NULL);
75
- /* ************************** end of common utils ************************** */
76
- } // namespace fastertransformer
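getSMVersion reports compute capability as major * 10 + minor, and the GEMM dispatch above only covers SM 70 through 89. The same number can be checked from Python before loading the extension; a minimal sketch:

import torch

def sm_version(device: int = 0) -> int:
    # Same convention as getSMVersion(): compute-capability major * 10 + minor.
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor

sm = sm_version()
print(sm, "supported" if 70 <= sm < 90 else "unsupported for CUTLASS mixed-type GEMM")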
 
utils/logger.cc DELETED
@@ -1,59 +0,0 @@
1
- /*
2
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include "logger.h"
18
- #include <cuda_runtime.h>
19
-
20
- namespace fastertransformer {
21
-
22
- Logger::Logger()
23
- {
24
- char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY");
25
- bool is_first_rank_only =
26
- (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false;
27
-
28
- int device_id;
29
- cudaGetDevice(&device_id);
30
-
31
- char* level_name = std::getenv("FT_LOG_LEVEL");
32
- if (level_name != nullptr) {
33
- std::map<std::string, Level> name_to_level = {
34
- {"TRACE", TRACE},
35
- {"DEBUG", DEBUG},
36
- {"INFO", INFO},
37
- {"WARNING", WARNING},
38
- {"ERROR", ERROR},
39
- };
40
- auto level = name_to_level.find(level_name);
41
- // If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
42
- if (is_first_rank_only && device_id != 0) {
43
- level = name_to_level.find("ERROR");
44
- }
45
- if (level != name_to_level.end()) {
46
- setLevel(level->second);
47
- }
48
- else {
49
- fprintf(stderr,
50
- "[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. "
51
- "Ignore the environment variable and use a default "
52
- "logging level.\n",
53
- level_name);
54
- level_name = nullptr;
55
- }
56
- }
57
- }
58
-
59
- } // namespace fastertransformer
 
utils/logger.h DELETED
@@ -1,121 +0,0 @@
1
- /*
2
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include <cstdlib>
20
- #include <map>
21
- #include <string>
22
-
23
- #include "string_utils.h"
24
-
25
- namespace fastertransformer {
26
-
27
- class Logger {
28
-
29
- public:
30
- enum Level {
31
- TRACE = 0,
32
- DEBUG = 10,
33
- INFO = 20,
34
- WARNING = 30,
35
- ERROR = 40
36
- };
37
-
38
- static Logger& getLogger()
39
- {
40
- thread_local Logger instance;
41
- return instance;
42
- }
43
- Logger(Logger const&) = delete;
44
- void operator=(Logger const&) = delete;
45
-
46
- template<typename... Args>
47
- void log(const Level level, const std::string format, const Args&... args)
48
- {
49
- if (level_ <= level) {
50
- std::string fmt = getPrefix(level) + format + "\n";
51
- FILE* out = level_ < WARNING ? stdout : stderr;
52
- std::string logstr = fmtstr(fmt, args...);
53
- fprintf(out, "%s", logstr.c_str());
54
- }
55
- }
56
-
57
- template<typename... Args>
58
- void log(const Level level, const int rank, const std::string format, const Args&... args)
59
- {
60
- if (level_ <= level) {
61
- std::string fmt = getPrefix(level, rank) + format + "\n";
62
- FILE* out = level_ < WARNING ? stdout : stderr;
63
- std::string logstr = fmtstr(fmt, args...);
64
- fprintf(out, "%s", logstr.c_str());
65
- }
66
- }
67
-
68
- void setLevel(const Level level)
69
- {
70
- level_ = level;
71
- log(INFO, "Set logger level to %s", getLevelName(level).c_str());
72
- }
73
-
74
- int getLevel() const
75
- {
76
- return level_;
77
- }
78
-
79
- private:
80
- const std::string PREFIX = "[FT]";
81
- const std::map<const Level, const std::string> level_name_ = {
82
- {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}};
83
-
84
- #ifndef NDEBUG
85
- const Level DEFAULT_LOG_LEVEL = DEBUG;
86
- #else
87
- const Level DEFAULT_LOG_LEVEL = INFO;
88
- #endif
89
- Level level_ = DEFAULT_LOG_LEVEL;
90
-
91
- Logger();
92
-
93
- inline const std::string getLevelName(const Level level)
94
- {
95
- return level_name_.at(level);
96
- }
97
-
98
- inline const std::string getPrefix(const Level level)
99
- {
100
- return PREFIX + "[" + getLevelName(level) + "] ";
101
- }
102
-
103
- inline const std::string getPrefix(const Level level, const int rank)
104
- {
105
- return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] ";
106
- }
107
- };
108
-
109
- #define FT_LOG(level, ...) \
110
- do { \
111
- if (fastertransformer::Logger::getLogger().getLevel() <= level) { \
112
- fastertransformer::Logger::getLogger().log(level, __VA_ARGS__); \
113
- } \
114
- } while (0)
115
-
116
- #define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__)
117
- #define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__)
118
- #define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__)
119
- #define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__)
120
- #define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__)
121
- } // namespace fastertransformer
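The logger is configured only through the FT_LOG_LEVEL and FT_LOG_FIRST_RANK_ONLY environment variables read in Logger's constructor (logger.cc above). A minimal sketch of setting them from Python, assuming they are exported before the extension logs anything:

import os

# Must be set before the Logger singleton is first constructed.
os.environ["FT_LOG_LEVEL"] = "DEBUG"         # TRACE, DEBUG, INFO, WARNING or ERROR
os.environ["FT_LOG_FIRST_RANK_ONLY"] = "ON"  # devices other than 0 are clamped to ERROR

import quantization_eetq  # noqa: E402  import after the environment is in place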
 
utils/string_utils.h DELETED
@@ -1,54 +0,0 @@
1
- /*
2
- * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
-
19
- #include <memory> // std::make_unique
20
- #include <sstream> // std::stringstream
21
- #include <string>
22
- #include <vector>
23
-
24
- namespace fastertransformer {
25
-
26
- template<typename... Args>
27
- inline std::string fmtstr(const std::string& format, Args... args)
28
- {
29
- // This function came from a code snippet in stackoverflow under cc-by-1.0
30
- // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
31
-
32
- // Disable format-security warning in this function.
33
- #if defined(_MSC_VER) // for visual studio
34
- #pragma warning(push)
35
- #pragma warning(warning(disable : 4996))
36
- #elif defined(__GNUC__) || defined(__clang__) // for gcc or clang
37
- #pragma GCC diagnostic push
38
- #pragma GCC diagnostic ignored "-Wformat-security"
39
- #endif
40
- int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
41
- if (size_s <= 0) {
42
- throw std::runtime_error("Error during formatting.");
43
- }
44
- auto size = static_cast<size_t>(size_s);
45
- auto buf = std::make_unique<char[]>(size);
46
- std::snprintf(buf.get(), size, format.c_str(), args...);
47
- #if defined(_MSC_VER)
48
- #pragma warning(pop)
49
- #elif defined(__GNUC__) || defined(__clang__)
50
- #pragma GCC diagnostic pop
51
- #endif
52
- return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
53
- }
54
- } // namespace fastertransformer
 
utils/torch_utils.h DELETED
@@ -1,68 +0,0 @@
1
- #pragma once
2
- #include "torch/csrc/cuda/Stream.h"
3
- #include "torch/all.h"
4
- #include <ATen/cuda/CUDAContext.h>
5
- #include <cstdio>
6
- #include <cuda_fp16.h>
7
- #include <cuda_runtime.h>
8
- #include <iostream>
9
- // Generates a conflict with CUDA 12.6 between nvtx 2 and 3. Does not
10
- // seem to be used anyway?
11
- //
12
- // #include <nvToolsExt.h>
13
- #include <torch/custom_class.h>
14
- #include <torch/script.h>
15
- #include <vector>
16
-
17
- #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
18
- #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
19
- #define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
20
- #define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)
21
- #define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
22
- #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
23
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
24
- #define CHECK_INPUT(x, st) \
25
- CHECK_TH_CUDA(x); \
26
- CHECK_CONTIGUOUS(x); \
27
- CHECK_TYPE(x, st)
28
- #define CHECK_CPU_INPUT(x, st) \
29
- CHECK_CPU(x); \
30
- CHECK_CONTIGUOUS(x); \
31
- CHECK_TYPE(x, st)
32
- #define CHECK_OPTIONAL_INPUT(x, st) \
33
- if (x.has_value()) { \
34
- CHECK_INPUT(x.value(), st); \
35
- }
36
- #define CHECK_OPTIONAL_CPU_INPUT(x, st) \
37
- if (x.has_value()) { \
38
- CHECK_CPU_INPUT(x.value(), st); \
39
- }
40
- #define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl
41
- #define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl
42
-
43
- namespace fastertransformer {
44
-
45
- template<typename T>
46
- inline T* get_ptr(torch::Tensor& t)
47
- {
48
- return reinterpret_cast<T*>(t.data_ptr());
49
- }
50
-
51
- std::vector<size_t> convert_shape(torch::Tensor tensor);
52
-
53
- size_t sizeBytes(torch::Tensor tensor);
54
-
55
- inline QuantType get_ft_quant_type(torch::ScalarType quant_type)
56
- {
57
- if (quant_type == torch::kInt8) {
58
- return QuantType::INT8_WEIGHT_ONLY;
59
- }
60
- else if (quant_type == at::ScalarType::QUInt4x2) {
61
- return QuantType::PACKED_INT4_WEIGHT_ONLY;
62
- }
63
- else {
64
- TORCH_CHECK(false, "Invalid quantization type");
65
- }
66
- }
67
-
68
- } // namespace fastertransformer
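get_ft_quant_type accepts exactly two torch dtypes, which are also the only valid quant_type arguments for quant_weights; a small hedged mirror of that mapping:

import torch

# The two quantization dtypes understood by get_ft_quant_type / quant_weights.
QUANT_TYPES = {
    torch.int8: "INT8_WEIGHT_ONLY",             # 8-bit weight-only
    torch.quint4x2: "PACKED_INT4_WEIGHT_ONLY",  # two 4-bit values packed per byte
}

def check_quant_dtype(dtype: torch.dtype) -> str:
    if dtype not in QUANT_TYPES:
        raise ValueError(f"Invalid quantization type: {dtype}")
    return QUANT_TYPES[dtype]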