diff --git a/README.md b/README.md index 5f30ed277963890fc23e47a55f372fd4ecee2750..256e495c21799c774d5e3821504919107f5acdc1 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,14 @@ --- license: apache-2.0 tags: -- kernel + - kernel --- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/quantization-eetq) ## eetq -EETQ kernels from [NetEase-FuXi/EETQ](https://github.com/NetEase-FuXi/EETQ). \ No newline at end of file +EETQ kernels from [NetEase-FuXi/EETQ](https://github.com/NetEase-FuXi/EETQ). + +Kernel source: https://github.com/huggingface/kernels-community/tree/main/quantization-eetq + diff --git a/build.toml b/build.toml deleted file mode 100644 index 86507fe154794d9cb40f023bf1b1652a27686351..0000000000000000000000000000000000000000 --- a/build.toml +++ /dev/null @@ -1,92 +0,0 @@ -[general] -name = "quantization_eetq" -universal = false - -[torch] -src = [ - "torch-ext/torch_binding.cpp", - "torch-ext/torch_binding.h", -] - -[kernel.weight_only_batched_gemv] -backend = "cuda" -depends = [ - "cutlass_2_10", - "torch", -] -include = ["cutlass_extensions/include"] -src = [ - "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h", - "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h", - "weightOnlyBatchedGemv/common.h", - "weightOnlyBatchedGemv/enabled.h", - "weightOnlyBatchedGemv/kernel.h", - "weightOnlyBatchedGemv/kernelLauncher.cu", - "weightOnlyBatchedGemv/kernelLauncher.h", - "weightOnlyBatchedGemv/utility.h", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu", - "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu", -] - -[kernel.cutlass_kernels] -backend = "cuda" -depends = [ - "cutlass_2_10", - "torch", -] -include = [ - ".", - "utils", - "cutlass_extensions/include", -] -src = [ - "cutlass_extensions/include/cutlass_extensions/arch/mma.h", - "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h", - "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h", - "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h", - "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h", - "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h", - "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h", - "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h", - "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h", - "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h", - "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h", - "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h", - 
"cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h", - "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h", - "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h", - "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h", - "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h", - "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h", - "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h", - "cutlass_kernels/cutlass_heuristic.cu", - "cutlass_kernels/cutlass_heuristic.h", - "cutlass_kernels/cutlass_preprocessors.cc", - "cutlass_kernels/cutlass_preprocessors.h", - "cutlass_kernels/fpA_intB_gemm.cu", - "cutlass_kernels/fpA_intB_gemm.h", - "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h", - "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h", - "cutlass_kernels/fpA_intB_gemm_wrapper.cu", - "cutlass_kernels/fpA_intB_gemm_wrapper.h", - "weightOnlyBatchedGemv/common.h", - "weightOnlyBatchedGemv/enabled.h", - "utils/activation_types.h", - "utils/cuda_utils.h", - "utils/logger.cc", - "utils/logger.h", - "utils/string_utils.h", - "utils/torch_utils.h", -] diff --git a/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/cutlass_extensions/include/cutlass_extensions/arch/mma.h deleted file mode 100644 index f4331bb68a0bfacdc8372ae55d8355e5e160b209..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/arch/mma.h +++ /dev/null @@ -1,46 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -/*! \file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace arch { - -// Tag which triggers MMA which will trigger -struct OpMultiplyAddDequantizeInterleavedBToA; - -} // namespace arch -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h deleted file mode 100644 index bad9b324601aeb564a4e244eff434de29f0dd176..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "cutlass/device_kernel.h" -#include "utils/cuda_utils.h" - -namespace fastertransformer { - -template -inline int compute_occupancy_for_kernel() -{ - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) { - cudaError_t status = - cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (status == cudaError::cudaErrorInvalidValue) { - // Clear the error bit since we can ignore this. - // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an - // occupancy of 0. This will cause the heuristic to ignore this configuration. - status = cudaGetLastError(); - return 0; - } - check_cuda_error(status); - } - - int max_active_blocks = -1; - check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); - - return max_active_blocks; -} - -} // namespace fastertransformer diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h b/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h deleted file mode 100644 index 3697b6748eea34ca50924c524079d38061e8d8dd..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h +++ /dev/null @@ -1,48 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { - -// define scaling mode -enum class QuantMode { - PerTensorQuant, - PerTokenQuant, - PerChannelQuant, - PerTokenChannelQuant -}; - -} // namespace epilogue -} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h b/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h deleted file mode 100644 index 6a1f7ee80b02d34af65e3e5574bacff491a6655a..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h +++ /dev/null @@ -1,148 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing linear combination with a maximum operation used by epilogues. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/functional.h" -#include "cutlass/half.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} - -__forceinline__ __device__ float tanh_opt(float x) -{ -#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - const float exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); -#else - return fast_tanh(x); -#endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// DdK: GELU_taylor ir incomplete in 2.10. Vendored fixes here. 
- -// GELU operator implemented using the Taylor series approximation -template -struct GELU_taylor_fixed { - static const bool kIsHeavy=true; - CUTLASS_HOST_DEVICE - T operator()(T const &z) const { - - T k0 = T(0.7978845608028654); - T k1 = T(0.044715); - - return T(cutlass::constants::half() * z * - (cutlass::constants::one() + fast_tanh(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_HOST_DEVICE - T operator()(T const &scalar, Params const ¶ms_) const { - return this->operator()(scalar); - } -}; - -template<> -struct GELU_taylor_fixed { - static const bool kIsHeavy = true; - CUTLASS_DEVICE - float operator()(float const& z) const - { - - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); - - return float( - cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } -}; - -template -struct GELU_taylor_fixed > { - static const bool kIsHeavy=true; - CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { - Array y; - GELU_taylor gelu_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - y[i] = gelu_op(rhs[i]); - } - - return y; - } - - using Params = LinearCombinationGenericParams; - CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); - } -}; - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h deleted file mode 100644 index 53b70e8019addf1d1a186249c329528aefc66118..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ /dev/null @@ -1,390 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. - - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h - -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "../epilogue_quant_helper.h" -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -template -class EpilogueVisitorPerRowPerCol { -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using AlphaScaleElementType = typename ScaleTileIterator::Element; - - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; - - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - - /// Argument structure - struct Arguments { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} - - Arguments(typename ElementwiseFunctor::Params elementwise_): - elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_, - int64_t batch_stride_alpha_, - int64_t batch_stride_C_, - int64_t batch_stride_D_): - elementwise(elementwise_), - batch_stride_alpha(batch_stride_alpha_), - batch_stride_C(batch_stride_C_), - batch_stride_D(batch_stride_D_) - { - } - }; - - struct Params { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args): - elementwise(args.elementwise), - batch_stride_alpha(args.batch_stride_alpha), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D) - { - } - }; - - /// Shared storage - struct SharedStorage {}; - 
-private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; - - const bool per_token_quant_; - const bool per_channel_quant_; - - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator beta_; - - int column_offset_; - - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, - SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, - int thread_idx, - int warp_idx, - int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, - typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, - QuantMode quant_mode, - AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, - typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), - int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)): - params_(params), - shared_storage_(shared_storage), - extent_(problem_size), - elementwise_(params.elementwise), - per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant), - per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant), - ptr_alpha_row_(ptr_alpha_row), - ptr_alpha_col_(ptr_alpha_col), - iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), - iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), - iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), - extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) { - iterator_C_.clear_mask(); - } - } - - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices - } - - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); - } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) { - iterator_alpha_col_.load(fragment_alpha_col_); - } - else if (ptr_alpha_col_ != nullptr) { - arch::global_load( - element_alpha_col_, ptr_alpha_col_, true); - } - - if (!per_token_quant_ && ptr_alpha_row_ != nullptr) { - arch::global_load( - element_alpha_row_, ptr_alpha_row_, true); - } - } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) - { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } - - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) { - int thread_offset_row = - iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row(); - - // element_alpha_row_ = ptr_alpha_row_[thread_offset_row]; - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } - } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // Clear accumulators for max and sum when starting a whole row - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[frag_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else { - result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); - } - - /* printf("%d %e\n", accum[0], result[0]); */ - /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */ - - /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */ - /* result = source_converter(elementwise_(result)); */ - /* } else { */ - /* result = source_converter(elementwise_(result, source_vector)); */ - /* } */ - - /* // Convert to the output */ - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) - { - - /* using ConvertSumOutput = cutlass::NumericConverter; */ - /* using ConvertNormOutput = cutlass::NumericConverter; */ - - /* ConvertSumOutput convert_sum_output; */ - /* ConvertNormOutput convert_norm_output; */ - - /* // Compute accumulate sum only in the last 
step */ - /* accum_sum_ = warp_reduce_sum_(accum_sum_); */ - - /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */ - /* bool row_guard = thread_offset_.row() < extent_.row(); */ - /* bool is_write_thread = row_guard && is_first_thread_in_tile; */ - - /* int block_batch = blockIdx.z; */ - - /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * - * params_.batch_stride_Max; */ - /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * - * params_.batch_stride_Sum; */ - - /* arch::global_store( */ - /* convert_norm_output(accum_max_), */ - /* (void *)curr_ptr_max, */ - /* is_write_thread); */ - - /* arch::global_store( */ - /* convert_sum_output(accum_sum_), */ - /* (void *)curr_ptr_sum, */ - /* is_write_thread); */ - - /* // Clear accumulators for max and sum when finishing a whole row */ - /* clear_accum_(); */ - } - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() {} - -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, - ComputeFragment const& scale_col, - AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) { - result[i] = accum[i] * (scale_col[i] * scale_row); - } - - return result; - } - - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, - AlphaScaleElementType const& scale_col, - AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) { - result[i] = accum[i] * (scale_col * scale_row); - } - - return result; - } -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h deleted file mode 100644 index 0c16c0a59622b45a526639aa48e718deea9c2c32..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ /dev/null @@ -1,285 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. - - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h - -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/platform/platform.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_clamp.h" -#include "cutlass/epilogue/thread/linear_combination_gelu.h" -#include "cutlass/epilogue/thread/linear_combination_hardswish.h" -#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_relu0.h" -#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" - -#include "cutlass/epilogue/thread/conversion_op.h" -#include "cutlass/epilogue/thread/reduction_op.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" - -#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" - -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" - -#include "cutlass/layout/permute.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -namespace detail { - -/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
-template -struct DefaultIteratorsTensorOp { - - using WarpTileIterator = - cutlass::epilogue::warp::TileIteratorTensorOp; - - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; - - static int const kFragmentsPerIteration = 1; -}; - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp { - - using WarpTileIterator = - cutlass::epilogue::warp::TileIteratorTensorOp; - - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator; - - static int const kFragmentsPerIteration = 1; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load output tile from shared memory in epilogue. -/// -/// Satisfies: ReadableTileIterator -/// -template -class SharedLoadIteratorMixed { -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = int32_t; - - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - - static int const kThreads = ThreadMap::kThreads; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; - - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; - -private: - // - // Data members - // - - /// Byte-level pointer - LoadType const* pointers_[kLoadsPerAccess]; - - /// Stride along adjacent rows in units of LoadType - int stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements)) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) { - pointers_[i] = reinterpret_cast(ref.data()); - - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - - col_idx += (bank_offset + i) % kLoadsPerAccess; - - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) { - pointers_[i] += pointer_offset / LoadType::kElements; - } - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) { - pointers_[i] += - offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < 
ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int row_ptr_offset = - row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ - + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements; - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - LoadType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) { - - int vector_idx = - (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); - - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } - } - } - } - - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h deleted file mode 100644 index 8e8190c2a72acfd5fb8a227fdddf8e85e8540e75..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * @file epilogue_helpers.h - * - * This file includes types for the epilogues. The empty structs exist so we can signal to template - * code the type of epilogue we want to run, and let the underlying code specify the details such as - * element types, accumulator type and elements per vector access. 
- * - */ - -#pragma once - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_generic.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_silu.h" -#include "cutlass_extensions/epilogue/thread/ft_fused_activations.h" - -namespace fastertransformer { - -struct EpilogueOpBiasSilu {}; - -struct EpilogueOpBiasReLU {}; - -struct EpilogueOpBiasFtGelu {}; - -struct EpilogueOpBias {}; - -struct EpilogueOpNoBias {}; - -template -struct Epilogue { -}; - -template -struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -template -struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -} // namespace fastertransformer diff --git a/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h b/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h deleted file mode 100644 index fbc01b0787c20e848658982960a468b32ccc82c1..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace fastertransformer { -// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape -// in the kernel layout details when doing weight only quantization. 
-enum class CutlassTileConfig { - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, - - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, - - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, - - // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape128x32x64 -}; - -enum class SplitKStyle { - NO_SPLIT_K, - SPLIT_K_SERIAL, - // SPLIT_K_PARALLEL // Not supported yet -}; - -struct CutlassGemmConfig { - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; -}; - -} // namespace fastertransformer \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h deleted file mode 100644 index a903254ccac4dcf5554e65d9c14c2e159b9caf95..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ /dev/null @@ -1,123 +0,0 @@ -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/bfloat16.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -namespace cutlass { -namespace gemm { -namespace kernel { - -template -struct MixedGemmArchTraits { -}; - -template -struct MixedGemmArchTraits { - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::RowMajor; - - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// ========================= Volta Traits =========================== -// Volta will always dequantize after the global memory load. -// This will instantiate any HMMA tensorcore kernels for Volta. -// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. 
-template -struct MixedGemmArchTraits< - TypeA, - TypeB, - cutlass::arch::Sm70, - typename cutlass::platform::enable_if::value - || cutlass::platform::is_same::value>::type> { -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Turing Traits ============================== -// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. -template -struct MixedGemmArchTraits< - TypeA, - TypeB, - cutlass::arch::Sm75, - typename cutlass::platform::enable_if::value - || cutlass::platform::is_same::value>::type> { -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Ampere Traits ============================== -template -struct MixedGemmArchTraits< - TypeA, - TypeB, - cutlass::arch::Sm80, - typename cutlass::platform::enable_if::value - || cutlass::platform::is_same::value>::type> { -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - - using Operator = typename LayoutDetails::Operator; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h deleted file mode 100644 index 8eb6c10ea8bb948a725a8f02089f1ac68081d3c5..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ /dev/null @@ -1,492 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmFpAIntB { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - - // Control serial split-k - int batch_count; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, - typename Mma::IteratorA::TensorRef ref_A, - typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, - int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, - int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr): - problem_size(problem_size), - ref_A(ref_A), - ref_B(ref_B), - ref_scale(ref_scale), - ref_C(ref_C), - ref_D(ref_D), - batch_count(serial_split_k_factor), - output_op(output_op), - gather_A_indices(gather_A_indices), - gather_B_indices(gather_B_indices), - scatter_D_indices(scatter_D_indices) - { - } - }; - - /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, - cutlass::gemm::GemmCoord 
const& grid_tiled_shape, - const int gemm_k_size, - void* workspace = nullptr): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.ref_A.layout()), - ref_A(args.ref_A), - params_B(args.ref_B.layout()), - ref_B(args.ref_B), - params_scale(args.ref_scale.layout()), - ref_scale(args.ref_scale), - params_C(args.ref_C.layout()), - ref_C(args.ref_C), - params_D(args.ref_D.layout()), - ref_D(args.ref_D), - output_op(args.output_op), - semaphore(static_cast(workspace)), - gemm_k_size(gemm_k_size), - gather_A_indices(args.gather_A_indices), - gather_B_indices(args.gather_B_indices), - scatter_D_indices(args.scatter_D_indices) - { - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - GemmFpAIntB() {} - - /// Determines whether kernel satisfies alignment - CUTLASS_HOST_DEVICE - static Status can_implement(Arguments const& args) - { - - static int const kAlignmentA = - (platform::is_same>::value) ? - 32 : - (platform::is_same>::value) ? - 64 : - Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = - (platform::is_same>::value) ? - 32 : - (platform::is_same>::value) ? - 64 : - Mma::IteratorB::AccessType::kElements; - - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; - - static int const kAlignmentC = (platform::is_same>::value) ? - 32 : - (platform::is_same>::value) ? - 64 : - Epilogue::OutputTileIterator::kElementsPerAccess; - - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { - return Status::kErrorMisalignedOperand; - } - - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - - // The dummy template parameter is not used and exists so that we can compile this code using - // a standard earlier than C++17. 
Prior to C++17, fully specialized templates HAD to exists in - // a namespace - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) - { - CUTLASS_NOT_IMPLEMENTED(); - } - }; - - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, - params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A, - params.gather_A_indices); - - typename Mma::IteratorB iterator_B(params.params_B, - params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, - thread_idx, - tb_offset_B, - params.gather_B_indices); - - typename Mma::IteratorScale iterator_scale(params.params_scale, - params.ref_scale.data(), - {1, params.problem_size.n()}, - thread_idx, - tb_offset_scale); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
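    // Why the shuffle: threadIdx.x / 32 is identical for every lane of a warp, but the compiler
    // cannot prove that on its own. Broadcasting lane 0's result with __shfl_sync gives every
    // lane a value the compiler can treat as warp-uniform, so the dependent iterator and
    // epilogue indexing can be compiled as warp-uniform rather than per-lane.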
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params.params_C, - params.ref_C.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, - params.ref_D.data(), - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } - }; - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. 
- */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h deleted file mode 100644 index bbbe3b053821961dcfc82b29678ab635dba606b2..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h +++ /dev/null @@ -1,447 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights - *reserved. SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or - support split-K. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmFpAIntBWithBroadcast { - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = - Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = - Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; - - void const *ptr_A; - void const *ptr_B; - void const *ptr_scales; - void const *ptr_C; - void *ptr_D; - - void const *ptr_Vector; - void const *ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - int lda, ldb, ldc, ldd, ldr, ldt; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const *gather_A_indices; - int const *gather_B_indices; - int const *scatter_D_indices; - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const &problem_size, int batch_count, - typename EpilogueOutputOp::Params epilogue, void const *ptr_A, - void const *ptr_B, void const *ptr_scales, void const *ptr_C, - void *ptr_D, const void *ptr_Vector, const void *ptr_Tensor, - int64_t batch_stride_A, int64_t batch_stride_B, - int64_t batch_stride_C, int64_t batch_stride_D, - int64_t batch_stride_Vector, int64_t batch_stride_Tensor, - int lda, int ldb, int ldc, int ldd, int ldr, int ldt, - typename EpilogueOutputOp::Params output_op = - typename EpilogueOutputOp::Params()) - : problem_size(problem_size), batch_count(batch_count), - epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), - ptr_scales(ptr_scales), ptr_C(ptr_C), 
ptr_D(ptr_D), - ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - batch_stride_Vector(batch_stride_Vector), - batch_stride_Tensor(batch_stride_Tensor), lda(lda), ldb(ldb), - ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op), - gather_A_indices(nullptr), gather_B_indices(nullptr), - scatter_D_indices(nullptr) {} - }; - - /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorScale::Params params_scale; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; - - // GemmUniversalMode mode; todo - int batch_count; - int gemm_k_size; - void *ptr_A; - void *ptr_B; - void *ptr_C; - void *ptr_scales; - void *ptr_D; - - void *ptr_Vector; - typename LayoutC::Stride::Index ldr; - - void *ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - // For gather+scatter operations - int const *gather_A_indices; - int const *gather_B_indices; - int const *scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() : swizzle_log_tile(0), gemm_k_size(0) {} - - CUTLASS_HOST_DEVICE - Params(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape, - const int gemm_k_size, void *workspace = nullptr) - : problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.lda), params_B(args.ldb), params_C(args.ldc), - params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue), - batch_count(args.batch_count), gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_scales(const_cast(args.ptr_scales)), - ptr_C(const_cast(args.ptr_C)), ptr_D(args.ptr_D), - ptr_Vector(const_cast(args.ptr_Vector)), ldr(args.ldr), - ptr_Tensor(const_cast(args.ptr_Tensor)), batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - batch_stride_Vector(args.batch_stride_Vector), - batch_stride_Tensor(args.batch_stride_Tensor), - gather_A_indices(args.gather_A_indices), - gather_B_indices(args.gather_B_indices), - scatter_D_indices(args.scatter_D_indices) {} - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - GemmFpAIntBWithBroadcast() {} - - CUTLASS_HOST_DEVICE - static Status can_implement(Arguments const &args) { - // todo - return Status::kSuccess; - } - - static size_t - get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - - // The dummy template parameter is not used and exists so that we can compile - // this code using a standard earlier than C++17. 
Prior to C++17, fully - // specialized templates HAD to exists in a namespace - template struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const ¶ms, - SharedStorage &shared_storage) { - CUTLASS_NOT_IMPLEMENTED(); - } - }; - - template struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const ¶ms, - SharedStorage &shared_storage) { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert( - platform::is_same::value && - kInterleave == 1 || - platform::is_same::value && - kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{ - threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * - Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = - min(params.problem_size.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = - (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / - Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, static_cast(params.ptr_A), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, - params.gather_A_indices); - - typename Mma::IteratorB iterator_B( - params.params_B, static_cast(params.ptr_B), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, - thread_idx, tb_offset_B, params.gather_B_indices); - - typename Mma::IteratorScale iterator_scale( - params.params_scale, static_cast(params.ptr_scales), - {1, params.problem_size.n()}, thread_idx, tb_offset_scale); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code is - // compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (gemm_k_iterations > 0) { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, - iterator_scale, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + - threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC *ptr_C = static_cast(params.ptr_C); - ElementC *ptr_D = static_cast(params.ptr_D); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, ptr_C, params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, ptr_D, params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - typename Epilogue::ElementTensor *ptr_Tensor = - static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector *ptr_Vector = - static_cast(params.ptr_Vector); - - typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, - lane_idx); - - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + - threadblock_tile_offset.m() * params.ldr; - } - - epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C, - tensor_iterator, params.problem_size.mn(), threadblock_offset); - } - }; - - /* - To improve compilation speed, we do not compile the device operator if the - CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel - operator. 
- */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - static constexpr bool compile_needed = - platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - static constexpr bool compile_needed = - platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - static constexpr bool compile_needed = - platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h deleted file mode 100644 index 14d45f0dbce17607dc4230bbb1ae06f711dd22ff..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. - - Note that for int4, ThreadBlockK MUST be 64. - - */ - -#pragma once - -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/platform/platform.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -namespace cutlass { -namespace gemm { -namespace kernel { - -template -struct LayoutDetailsB { -}; - -// Volta specialiations. Volta will dequantize before STS, so we need a different operator -template -struct LayoutDetailsB { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. 
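// A worked example of the interleave arithmetic used by the specializations below (illustrative
// only; the concrete values are spelled out here for readability):
//
//   ElementsPerCacheLine = 128 bytes * 8 / sizeof_bits<TypeB>::value
//   ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK, with ThreadblockK = 64
//
//   uint8_t  weights: 128 * 8 / 8 = 128 elements per cache line -> 128 / 64 = 2 columns interleaved
//   uint4b_t weights: 128 * 8 / 4 = 256 elements per cache line -> 256 / 64 = 4 columns interleaved
//
// so int8 B tiles use layout::ColumnMajorTileInterleave<64, 2> and int4 B tiles use
// layout::ColumnMajorTileInterleave<64, 4>, which is the layout the preprocessing code packs the
// quantized weight matrices into.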
-template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h deleted file mode 100644 index b4b98db95278de0ea8c604050b4c6a20474a5654..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +++ /dev/null @@ -1,106 +0,0 @@ -#pragma once - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { -//////////////////////////////////////////////////////////////////////////////// - -// We need to distinguish here, since we want volta support. It is too much effort -// to write shared memory iterators that are probably needed for volta to function -// properly. As a result, we allow converters both after the LDG (for volta) and after -// the LDS for Turing+. 
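// To make the two cases concrete (a sketch assuming half_t activations and uint8_t weights; the
// element counts N differ between the two transforms in the real aliases, so they are left
// symbolic here):
//
//   dequantize after LDG (Volta, OpMultiplyAdd):
//     TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
//     TransformAfterLDS = NumericArrayConverter<half_t, half_t, N>      // identity: smem already holds fp16
//
//   dequantize after LDS (Turing+, OpMultiplyAddDequantizeInterleavedBToA):
//     TransformAfterLDG = NumericArrayConverter<uint8_t, uint8_t, N>    // identity: raw int8 goes to smem
//     TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>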
-template< - /// Iterator for B matrix in global memory - typename IteratorB, - /// Warp level Mma - typename MmaOperator, - /// Math operation perform by warp level operator - typename MathOperator> -struct SetConverters { -}; - -// Dequantize after LDG, so set transforms accordingly -template< - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters { - using TransformAfterLDG = - FastInterleavedAndBiasedNumericArrayConverter; - - using TransformAfterLDS = NumericArrayConverter; -}; - -// Dequantize after LDS, so set transforms accordingly - -template< - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters { - using TransformAfterLDG = - NumericArrayConverter; - - using TransformAfterLDS = - FastInterleavedAndBiasedNumericArrayConverter; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template< - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale_, - /// Layout for the scale operand - typename LayoutScale_, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// - typename Enable = void> -struct DqMma; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h deleted file mode 100644 index ef59e1b406c8d01cd138f81e1c7f737fe5c3e3c5..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ /dev/null @@ -1,346 +0,0 @@ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - 
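// Background on the CacheOpA / CacheOpB selection in the specialization below: cp.async can
// bypass L1 (CacheOperation::Global, i.e. cp.async.cg) only for 16-byte transfers, so the code
// picks Global exactly when sizeof_bits<Element>::value * kAlignment == 128 and otherwise falls
// back to CacheOperation::Always (cp.async.ca, cached at all levels). For example, half_t
// activations with an alignment of 8 elements give 16 * 8 = 128 bits -> Global, while uint8_t
// weights with an alignment of 8 elements give 8 * 8 = 64 bits -> Always.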
-template< - /// Type for elementA - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// - typename Operator, - /// - SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80)>::type> { - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, - LayoutB, - 0, - ThreadMapB, - AccessTypeB>; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = - transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; - - // Define iterators over tiles from the scale operand - using IteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using SmemIteratorScale = IteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -template< - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A 
matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// - typename Operator, - /// - SharedMemoryClearOption SharedMemoryClear, - /// - int RowsPerTile, - /// - int ColumnsInterleaved> -struct DqMma, - kAlignmentB, - ElementScale, - LayoutScale, - kAlignmentScale, - ElementAccumulator, - layout::RowMajor, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - kStages, - Operator, - SharedMemoryClear, - typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> { - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? 
- cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; - -private: - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape = - MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock:: - PredicatedTileAccessIterator; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = - transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; - - // Define iterators over tiles from the scale operand - using IteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using SmemIteratorScale = IteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h deleted file mode 100644 index b25405de013f2cacee9351a60de9605f17f5cace..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -template< - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - 
typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DqMma::type> { - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static constexpr bool DqAfterLDG = platform::is_same::value; - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaCoreElementA = typename platform::conditional::type; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - typename MmaCore::IteratorThreadMapA, - kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementB, - LayoutB, - 0, - typename MmaCore::IteratorThreadMapB, - kAlignmentB>; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = - transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; - - // Define iterators over tiles from the scale operand - using IteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using SmemScaleType = typename platform::conditional::type; - using SmemIteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - SmemScaleType, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -// Specialization to handle column major interleave B -template< - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to 
tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int RowsPerTile, - /// - int ColumnsInterleaved> -struct DqMma, - kAlignmentB, - ElementScale, - LayoutScale, - kAlignmentScale, - ElementAccumulator, - layout::RowMajor, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - 2, - Operator, - SharedMemoryClearOption::kNone, - typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static constexpr bool DqAfterLDG = platform::is_same::value; - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaCoreElementA = typename platform::conditional::type; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - typename MmaCore::IteratorThreadMapA, - kAlignmentA>; - -private: - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape = - MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock:: - PredicatedTileIterator; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = - transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; - - // Define iterators over tiles from the scale operand - using IteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using SmemScaleType = typename platform::conditional::type; - using SmemIteratorScale = - cutlass::transform::threadblock::PredicatedTileIterator, - SmemScaleType, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h deleted file mode 100644 index 
da51c94f8659f5f5c8d0abbf6039bc726fe95dd0..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ /dev/null @@ -1,426 +0,0 @@ -#pragma once - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - 
typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma { - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - half_t, - LayoutA, - 1, - ThreadMapA, - AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - half_t, - LayoutB, - 0, - ThreadMapB, - AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h deleted file mode 100644 index 25acf9772e7dfef6047d9b0cd19351191ca6f179..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ /dev/null @@ -1,527 +0,0 @@ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size 
(concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma { - -private: - // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaElementA = typename platform::conditional::type; - using MmaElementB = typename platform::conditional::type; - -public: - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - bfloat16_t, - LayoutA, - 1, - typename MmaCore::IteratorThreadMapA, - kAlignmentA, - GatherA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, - bfloat16_t, - LayoutB, - 0, - typename MmaCore::IteratorThreadMapB, - kAlignmentB, - GatherB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; -}; - -// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma { - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - bfloat16_t, - LayoutA, - 1, - ThreadMapA, - AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - bfloat16_t, - LayoutB, - 0, - ThreadMapB, - AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass 
TensorOp), bf16 activation & int8 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct 
DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight -template< - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma { - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h deleted file mode 100644 index ad863af970becf39de61560226126ad5a4540b75..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +++ /dev/null @@ -1,236 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// -// SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// correct warp level mma. On volta, all data is stored to shared memory as FP16. -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, - typename WarpMma::FragmentC& D, - typename WarpMma::FragmentA const& A, - typename WarpMma::FragmentB const& B, - typename WarpMma::FragmentC const& C, - const int warp_tileB_k_offset) -{ - warp_mma(D, A, B, C); -} - -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, - typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, - typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, - const int warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); -} -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template< - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// The type of the scales - typename ElementScale_, - /// Number of stages, - int Stages, - /// Used for partial specialization - typename Enable = bool> -class DqMmaBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Type of the scale to be loaded - using ElementScale = ElementScale_; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. 
- using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - static constexpr int kNumKIterationsPerWarpBLoad = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = - MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_scale; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h deleted file mode 100644 index c232264826233680ef5c2c5ae2cf330e9dffab80..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ /dev/null @@ 
-1,599 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template< - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class DqMmaMultistage: public DqMmaBase { -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = - warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. 
- struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, - lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE 
- void - copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. 
- FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector - / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector - / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. 
- cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. 
- arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h deleted file mode 100644 index 4441e795c02cfb3bd7ab851d00b0732ce38f4614..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ /dev/null @@ -1,385 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/ft_gemm_configs.h" -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// Used for partial specialization - typename Enable = bool> -class DqMmaPipelined: public DqMmaBase { -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using IteratorScale = 
IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, - lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int 
warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA = - NumericArrayConverter; - - using TransformScale = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. - TransformA transformA; - TransformScale transformScale; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - FragmentScale tb_frag_scales; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - WarpFragmentScale warp_frag_scales; - - tb_frag_A.clear(); - tb_frag_B.clear(); - tb_frag_scales.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - iterator_scale.load(tb_frag_scales); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - warp_dequantizer_.load(warp_frag_scales); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. 
- - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h deleted file mode 100644 index 2a42f5785ab17afd968894f07c96ff97bf8aea5a..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ /dev/null @@ -1,127 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template< - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements, - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp { - -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; - - // Chosen so we get K=16 for int8 and K=32 for int4. 
- static constexpr int LoadInstructionK = 8 * sizeof_bits::value / sizeof_bits::value; - - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; - -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, - cutlass::MatrixShape<1, 1>>; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h deleted file mode 100644 index 7cc255a6017b1f0eff9b479a0fc3a0904d8b69ce..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ /dev/null @@ -1,313 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. 
-*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Instruction shape to override shared memory iterators with - typename SharedMemoryInstructionShape_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. 
- bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool> -class MmaTensorOpComputeBWithF16 { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports underlying HMMA"); - - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - 
typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } - } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h deleted file mode 100644 index 0b48d28219d819ef97d6b0b49fe4a440bfe9e5b3..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ /dev/null @@ -1,469 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
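Stepping back to the `operator()` of `MmaTensorOpComputeBWithF16` deleted just above: the serpentine traversal simply reverses the direction of the inner loop on every other outer iteration so that MMA issues adjacent in time share an operand fragment. A tiny host-side sketch of the Ampere-path index pattern, with made-up tile counts:

```cpp
#include <cstdio>

int main() {
    // Stand-ins for MmaIterations::kRow / MmaIterations::kColumn.
    const int kRow = 3, kCol = 4;

    // Ampere-path order from the deleted operator(): the inner n loop snakes
    // back and forth, so the last (m, n) of one row and the first (m+1, n) of
    // the next row reuse the same B fragment index.
    for (int m = 0; m < kRow; ++m) {
        for (int n = 0; n < kCol; ++n) {
            const int n_serpentine = (m % 2) ? (kCol - 1 - n) : n;
            printf("(m=%d, n=%d) ", m, n_serpentine);
        }
    }
    printf("\n");
    return 0;
}
```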
-*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" - -#include "cutlass/functional.h" -#include "cutlass/platform/platform.h" - - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace warp { - -//////////////////////////////////////////////////////////////////////////////// - -template< - /// Matrix multiply operator - typename MmaOperator_, - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of Scale elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Number of threads participating in one matrix operation - int Threads, - /// - typename Enable = void> -class MmaTensorOpDequantizer; - -//////////////////////////////////////////////////////////////////////////////// -// Bfloat specialization for Ampere -template< - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_> -class MmaTensorOpDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - bfloat16_t, - layout::RowMajor, - 32, - typename platform::enable_if< - MmaOperator_::ArchTag::kMinComputeCapability >= 80 - && platform::is_same::value>::type> { - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. 
- static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = bfloat16_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) - { - const int warp_offset = warp_idx_n * Shape::kN; - const int quad = lane_idx / 4; - const int thread_offset = warp_offset + quad; - pointer_ = smem_scales.data() + thread_offset; - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { - scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - const __nv_bfloat16* scale_ptr = reinterpret_cast(&scale_frag); - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. 
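A host-side sketch of what the `dequantize()` above does, with floats standing in for bf16 and made-up fragment sizes: one scale per column iteration, broadcast across that iteration's slice of the B fragment. The CUDA version applies it two elements at a time with `__hmul2` on `__nv_bfloat162` pairs.

```cpp
#include <cstdio>

int main() {
    // Stand-ins for MmaOperator::MmaIterations::kColumn and the number of B
    // elements each thread holds per column iteration (ExpandedMmaOperandB).
    const int kColumnIters = 2;
    const int kEltsPerIter = 4;

    float operand[kColumnIters * kEltsPerIter] = {1, 2, 3, 4, 5, 6, 7, 8};
    float scales[kColumnIters] = {0.5f, 0.25f};

    // Same structure as the deleted dequantize(): broadcast one scale per n iteration.
    for (int n = 0; n < kColumnIters; ++n)
        for (int e = 0; e < kEltsPerIter; ++e)
            operand[n * kEltsPerIter + e] *= scales[n];

    for (float v : operand) printf("%g ", v);
    printf("\n");
    return 0;
}
```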
- arch::device_breakpoint(); -#endif - } - -private: - ElementScale const* pointer_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Specialization for Turing & Ampere -template< - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_> -class MmaTensorOpDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - half_t, - layout::RowMajor, - 32, - typename platform::enable_if< - MmaOperator_::ArchTag::kMinComputeCapability >= 75 - && platform::is_same::value>::type> { - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) - { - const int warp_offset = warp_idx_n * Shape::kN; - const int quad = lane_idx / 4; - const int thread_offset = warp_offset + quad; - pointer_ = smem_scales.data() + thread_offset; - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { - scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - -private: - ElementScale const* pointer_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm -template< - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_> -class MmaTensorOpDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - half_t, - layout::RowMajor, - 32, - typename platform::enable_if< - platform::is_same::value - && platform::is_same::value>::type> { - -public: - static_assert(platform::is_same>::value, 
""); - - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - /// Warp mma shape - using Shape = Shape_; - - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. - static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; - using AccessType = Array; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) - { - const int warp_offset = warp_idx_n * Shape::kN; - const int base_col = lane_idx & 0xF8; - const int thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); - - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. - scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) - { - static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); - - multiplies mul_op; - operand_frag = mul_op(operand_frag, scale_frag); - } - -private: - ElementScale const* pointer_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm -template< - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_> -class MmaTensorOpDequantizer< - MmaOperator_, - Shape_, - Operand::kB, - half_t, - layout::RowMajor, - 32, - typename platform::enable_if< - platform::is_same::value - && platform::is_same::value>::type> { - -public: - static_assert(platform::is_same>::value, ""); - - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - /// Warp mma shape - using Shape = Shape_; - - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. 
- static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) - { - const int warp_offset = warp_idx_n * Shape::kN; - const int base_col = lane_idx & 0xF8 + lane_idx % 4; - const int thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. - // For col major B, each thread will jump 4 cols to get its next value inside - // of the super mma. - CUTLASS_PRAGMA_UNROLL - for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { - scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; - } - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) - { - using MmaOperandB = typename ArchMmaOperator::FragmentB; - static constexpr int total_n_mmas = 2 * TileNIterations; - static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); - - multiplies mul_op; - - MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - -private: - ElementScale const* pointer_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h deleted file mode 100644 index fd200e0d4bc93131930f6203df060396b814070d..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ /dev/null @@ -1,429 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register -*/ - -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/numeric_types.h" - -namespace cutlass { - -// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low -// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally -// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. -// This converter will uninterleave the data and subtract the bias while converting to the result type. -template -struct FastInterleavedAndBiasedNumericArrayConverter { -}; - -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. 
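A host-side check (not CUDA) of the trick used by the converter above: `prmt.b32` drops each biased byte into the low half of the fp16 pattern `0x64XX`, which decodes to `1024 + b`, and the `sub.f16x2` against `0x6480` (1152) then yields the signed value `b - 128`. The tiny fp16 decoder below handles normal numbers only, which is all this needs.

```cpp
#include <cstdio>
#include <cstdint>
#include <cmath>

// Decode an IEEE fp16 bit pattern to float (normal numbers only).
float half_bits_to_float(uint16_t h) {
    const int sign = (h >> 15) & 1;
    const int exp  = (h >> 10) & 0x1F;
    const int man  = h & 0x3FF;
    const float v  = (1.0f + man / 1024.0f) * std::ldexp(1.0f, exp - 15);
    return sign ? -v : v;
}

int main() {
    const float magic = half_bits_to_float(0x6480);  // 1152 = 1024 + 128
    for (int b : {0, 1, 127, 128, 255}) {
        // 0x6400 | b sits at an exponent where the mantissa lsb is worth 1,
        // so it decodes to exactly 1024 + b.
        const float constructed = half_bits_to_float(static_cast<uint16_t>(0x6400 | b));
        printf("biased byte %3d -> %g\n", b, constructed - magic);  // b - 128
    }
    return 0;
}
```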
- static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - - // Subtract out fp32_base + 128 to make the unsigned integer signed. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) { - fp32_intermediates[ii] -= 8388736.f; - } - - // Truncate the fp32 representation and pack up as bfloat16s. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) { - bf16_result_ptr[ii] = - __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. 
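The bf16 branch above works through fp32 instead, because bf16 lacks the mantissa bits for the direct trick: `__byte_perm` places each biased byte into the low mantissa byte of `0x4B000000` (2^23), subtracting 8388736 (2^23 + 128) recovers `b - 128`, and the final `__byte_perm` with selector `0x7632` keeps the upper halves of two fp32 words, i.e. truncates them to a packed pair of bf16. A host-side check of the arithmetic:

```cpp
#include <cstdio>
#include <cstdint>
#include <cstring>

int main() {
    for (int b : {0, 1, 128, 255}) {
        // 0x4B000000 is 2^23; at that exponent the fp32 mantissa lsb is worth
        // exactly 1, so OR-ing the byte in gives the value 8388608 + b.
        const uint32_t bits = 0x4B000000u | static_cast<uint32_t>(b);
        float constructed;
        std::memcpy(&constructed, &bits, sizeof(constructed));
        printf("biased byte %3d -> %g\n", b, constructed - 8388736.0f);  // b - 128
    }
    return 0;
}
```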
- result.clear(); // Suppress compiler warning - arch::device_breakpoint(); -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. 
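The int4 constants above make more sense with the arithmetic written out. A low nibble is built as `0x6400 | v` (1024 + v) and fixed up by subtracting 1032 (`0x6408`); a high nibble is left in place as `0x6400 | (v << 4)` (1024 + 16·v), and the `fma` with 1/16 and −72 (`0xd480`) folds the shift into the fixup, since (1024 + 16·v)/16 − 72 = v − 8. A plain-float check:

```cpp
#include <cstdio>

int main() {
    // v is the bias-encoded nibble (original int4 plus 8), so v is in [0, 15].
    for (int v : {0, 7, 8, 15}) {
        const float low  = (1024.0f + v) - 1032.0f;                         // sub.f16x2 path
        const float high = (1024.0f + 16.0f * v) * (1.0f / 16.0f) - 72.0f;  // fma.rn.f16x2 path
        printf("v=%2d -> low %g, high %g (expected %d)\n", v, low, high, v - 8);
    }
    return 0;
}
```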
- static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. 
If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h deleted file mode 100644 index bb0808522b19aba3dd488caba34cccacc6d7f269..0000000000000000000000000000000000000000 --- a/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +++ /dev/null @@ -1,61 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines new layouts needed for MoE -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/pitch_linear_coord.h" - -namespace cutlass { -namespace layout { - -template -class ColumnMajorTileInterleave { - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; -}; - -template -struct IsColumnMajorTileInterleave { - static constexpr bool value = false; -}; - -template -struct IsColumnMajorTileInterleave> { - static constexpr bool value = true; -}; - -} // namespace layout -} // namespace cutlass diff --git a/cutlass_kernels/cutlass_heuristic.cu b/cutlass_kernels/cutlass_heuristic.cu deleted file mode 100644 index 62735ce30deb92d57284fffb31c007c8d8f11a37..0000000000000000000000000000000000000000 --- a/cutlass_kernels/cutlass_heuristic.cu +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cutlass_heuristic.h" -#include "cutlass/gemm/gemm.h" -#include - -#include -#include - -namespace fastertransformer { - -struct TileShape { - int m; - int n; -}; - -TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) -{ - switch (tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - return TileShape{32, 128}; - case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - return TileShape{64, 128}; - case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - return TileShape{128, 128}; - default: - throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config"); - } -} - -bool is_valid_split_k_factor(const int64_t m, - const int64_t n, - const int64_t k, - const TileShape tile_shape, - const int split_k_factor, - const size_t workspace_bytes, - const bool is_weight_only) -{ - - // All tile sizes have a k_tile of 64. 
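A standalone sketch of the checks that follow (the helper name and the numbers are illustrative, not part of the deleted file): for weight-only GEMMs, k, the split factor, and the per-split extent must all line up with the 64-wide k-tile, and serial split-k additionally needs one `int` of workspace per output tile.

```cpp
#include <cstdio>
#include <cstdint>

// Hypothetical helper mirroring the is_valid_split_k_factor logic above,
// with the weight-only divisibility rules applied unconditionally.
bool split_k_ok(int64_t m, int64_t n, int64_t k, int tile_m, int tile_n,
                int split_k, size_t workspace_bytes) {
    const int64_t k_tile = 64;
    if (k % k_tile != 0 || k % split_k != 0 || (k / split_k) % k_tile != 0)
        return false;
    const int64_t ctas_m = (m + tile_m - 1) / tile_m;
    const int64_t ctas_n = (n + tile_n - 1) / tile_n;
    const size_t required = split_k == 1 ? 0 : sizeof(int) * ctas_m * ctas_n;
    return required <= workspace_bytes;
}

int main() {
    printf("%d\n", split_k_ok(4096, 4096, 4096, 128, 128, 4, 1 << 20));  // 1
    printf("%d\n", split_k_ok(4096, 4096, 4096, 128, 128, 3, 1 << 20));  // 0: 4096 is not divisible by 3
    return 0;
}
```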
- static constexpr int k_tile = 64; - - // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k - if (is_weight_only) { - if ((k % k_tile) != 0) { - return false; - } - - if ((k % split_k_factor) != 0) { - return false; - } - - const int k_elements_per_split = k / split_k_factor; - if ((k_elements_per_split % k_tile) != 0) { - return false; - } - } - - // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; - - if (required_ws_bytes > workspace_bytes) { - return false; - } - - return true; -} - -std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) -{ - - std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; - - std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; - - std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; - - const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; - return simt_configs_only ? simt_configs : allowed_configs; -} - -std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) -{ - std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); - - std::vector candidate_configs; - const int min_stages = 2; - const int max_stages = sm >= 80 ? 4 : 2; - - for (const auto& tile_config : tiles) { - for (int stages = min_stages; stages <= max_stages; ++stages) { - CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; - candidate_configs.push_back(config); - } - } - - return candidate_configs; -} - -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, - const std::vector& occupancies, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t num_experts, - const int split_k_limit, - const size_t workspace_bytes, - const int multi_processor_count, - const int is_weight_only) -{ - - if (occupancies.size() != candidate_configs.size()) { - throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and " - "candidate configs vectors must have equal length."); - } - - CutlassGemmConfig best_config; - // Score will be [0, 1]. The objective is to minimize this score. - // It represents the fraction of SM resources unused in the last wave. - float config_score = 1.0f; - int config_waves = INT_MAX; - int current_m_tile = 0; - - const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - CutlassGemmConfig candidate_config = candidate_configs[ii]; - TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); - int occupancy = occupancies[ii]; - - if (occupancy == 0) { - continue; - } - - // Keep small tile sizes when possible. 
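The score described above, and computed in the selection loop that follows, can be sketched on its own: a config needing `ctas_for_problem` thread blocks at a given occupancy runs in a whole number of waves, and the gap between that integer wave count and the fractional one is the share of the last wave (hence of the GPU) left idle. Problem sizes, occupancies, and the SM count below are made-up inputs.

```cpp
#include <cstdio>
#include <cstdint>

// Fraction of the last wave left idle for a given tile shape / occupancy.
float wave_score(int64_t m, int64_t n, int tile_m, int tile_n,
                 int split_k, int occupancy, int sm_count) {
    const int64_t ctas_m = (m + tile_m - 1) / tile_m;
    const int64_t ctas_n = (n + tile_n - 1) / tile_n;
    const int64_t ctas_for_problem = ctas_m * ctas_n * split_k;
    const int64_t ctas_per_wave = int64_t(occupancy) * sm_count;
    const int64_t waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
    return float(waves_total) - ctas_for_problem / float(ctas_per_wave);
}

int main() {
    printf("128x128 tile: %.3f\n", wave_score(5120, 5120, 128, 128, 1, 2, 108));
    printf(" 64x128 tile: %.3f\n", wave_score(5120, 5120, 64, 128, 1, 3, 108));
    return 0;
}
```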
- if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile - && current_m_tile < tile_shape.m) { - continue; - } - - const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - - for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { - if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { - const int ctas_per_wave = occupancy * multi_processor_count; - const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; - - const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; - const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); - const float current_score = float(num_waves_total) - num_waves_fractional; - - const float score_slack = 0.1f; - if (current_score < config_score - || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { - config_score = current_score; - config_waves = num_waves_total; - SplitKStyle split_style = - split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig{ - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; - current_m_tile = tile_shape.m; - } - else if (current_score == config_score - && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor - || current_m_tile < tile_shape.m)) { - // Prefer deeper pipeline or smaller split-k - SplitKStyle split_style = - split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig{ - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; - current_m_tile = tile_shape.m; - config_waves = num_waves_total; - } - } - } - } - - if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { - throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config."); - } - - return best_config; -} - -} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_heuristic.h b/cutlass_kernels/cutlass_heuristic.h deleted file mode 100644 index 691d7ea36f16bb6235296dabec8f79233bbce557..0000000000000000000000000000000000000000 --- a/cutlass_kernels/cutlass_heuristic.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include "cutlass_extensions/ft_gemm_configs.h" - -namespace fastertransformer { - -std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); - -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, - const std::vector& occupancies, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t num_experts, - const int split_k_limit, - const size_t workspace_bytes, - const int multi_processor_count, - const int is_weight_only); - -} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_preprocessors.cc b/cutlass_kernels/cutlass_preprocessors.cc deleted file mode 100644 index 2556a0bdffbe502d2d7fafaf647b4b30e83604fe..0000000000000000000000000000000000000000 --- a/cutlass_kernels/cutlass_preprocessors.cc +++ /dev/null @@ -1,703 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cutlass_preprocessors.h" -#include "cuda_utils.h" -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -#include - -namespace fastertransformer { - -int get_bits_in_quant_type(QuantType quant_type) { - switch (quant_type) { - case QuantType::INT8_WEIGHT_ONLY: - return 8; - case QuantType::PACKED_INT4_WEIGHT_ONLY: - return 4; - default: - return -1; - } -} - -struct LayoutDetails { - enum class Layout { - UNKNOWN, - ROW_MAJOR, - COLUMN_MAJOR - }; - - Layout layoutB = Layout::UNKNOWN; - int rows_per_column_tile = 1; - int columns_interleaved = 1; - - bool uses_imma_ldsm = false; -}; - -template -struct getLayoutDetails { -}; - -template<> -struct getLayoutDetails { - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; - return layout_details; - } -}; - -template<> -struct getLayoutDetails { - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - return layout_details; - } -}; - -template -struct getLayoutDetails> { - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - layout_details.rows_per_column_tile = RowsPerTile; - layout_details.columns_interleaved = ColumnsInterleaved; - return layout_details; - } -}; - -template -LayoutDetails getLayoutDetailsForArchAndQuantType() -{ - - using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; - using LayoutB = typename CompileTraits::Layout; - using MmaOperator = typename CompileTraits::Operator; - LayoutDetails details = getLayoutDetails()(); - details.uses_imma_ldsm = std::is_same::value; - return details; -} - -template -LayoutDetails getLayoutDetailsForArch(QuantType quant_type) -{ - LayoutDetails details; - if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - details = getLayoutDetailsForArchAndQuantType(); - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { - 
details = getLayoutDetailsForArchAndQuantType(); - } - else { - FT_CHECK_WITH_INFO(false, "Unsupported quantization type"); - } - return details; -} - -LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) -{ - if (arch >= 70 && arch < 75) { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 75 && arch < 80) { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 80 && arch < 90) { - return getLayoutDetailsForArch(quant_type); - } - else { - FT_CHECK_WITH_INFO(false, "Unsupported Arch"); - return LayoutDetails(); - } -} - -// Permutes the rows of B for Turing and Ampere. Throws an error for other -// architectures. The data is permuted such that: For int8, each group of 16 -// rows is permuted using the map below: -// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 -// For int4, each group of 32 rows is permuted using the map below: -// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 -// 23 30 31 -void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor, - const int8_t *quantized_tensor, - const std::vector &shape, - QuantType quant_type, - const int64_t arch_version) { - const size_t num_rows = shape[0]; - const size_t num_cols = shape[1]; - - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; - - const uint32_t *input_byte_ptr = - reinterpret_cast(quantized_tensor); - uint32_t *output_byte_ptr = - reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int num_vec_cols = num_cols / elts_in_int32; - - FT_CHECK_WITH_INFO(arch_version >= 75, - "Unsupported Arch. Pre-volta not supported. Column " - "interleave not needed on Volta."); - - FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0, - fmtstr("Invalid shape for quantized tensor. Number of " - "rows of quantized matrix must be a multiple of %d", - B_ROWS_PER_MMA)); - - FT_CHECK_WITH_INFO( - num_cols % MMA_SHAPE_N == 0, - fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number " - "of cols must be a multiple of %d.", - MMA_SHAPE_N)); - - // The code is written as below so it works for both int8 - // and packed int4. - for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) + - tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - const int read_row = base_row + tile_read_row; - const int read_col = write_col; - - const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = - int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } - } -} - -// We need to use this transpose to correctly handle packed int4 and int8 data -// The reason this code is relatively complex is that the "trivial" loops took a -// substantial amount of time to transpose leading to long preprocessing times. -// This seemed to be a big issue for relatively large models. 
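The index expression in the permute loop above reproduces exactly the permutation maps documented in its comment; a quick standalone check:

```cpp
#include <cstdio>

// Prints, for each destination row inside one MMA row tile, which source row
// of B is read, matching the maps documented above permute_B_rows_for_mixed_gemm.
void print_row_map(int bits_per_elt) {
    const int elts_per_reg = 32 / bits_per_elt;
    const int rows_per_tile = 8 * (16 / bits_per_elt);  // B_ROWS_PER_MMA
    printf("int%d:", bits_per_elt);
    for (int t = 0; t < rows_per_tile; ++t) {
        const int read_row = 8 * ((t % elts_per_reg) / 2) + t % 2 + 2 * (t / elts_per_reg);
        printf(" %d", read_row);
    }
    printf("\n");
}

int main() {
    print_row_map(8);  // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
    print_row_map(4);  // 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 ...
    return 0;
}
```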
-template -void subbyte_transpose_impl(int8_t *transposed_quantized_tensor, - const int8_t *quantized_tensor, - const std::vector &shape) { - const int bits_per_elt = get_bits_in_quant_type(quant_type); - const size_t num_rows = shape[0]; - const size_t num_cols = shape[1]; - - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - - const uint8_t *input_byte_ptr = - reinterpret_cast(quantized_tensor); - uint8_t *output_byte_ptr = - reinterpret_cast(transposed_quantized_tensor); - - static constexpr int ELTS_PER_BYTE = - quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2; - - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - - // We assume the dims are a multiple of vector width. Our kernels only handle - // dims which are multiples of 64 for weight-only quantization. As a result, - // this seemed like a reasonable tradeoff because it allows GCC to emit vector - // instructions. - FT_CHECK_WITH_INFO( - !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), - fmtstr("Number of bytes for rows and cols must be a multiple of %d. " - "However, num_rows_bytes = %ld and num_col_bytes = %d.", - VECTOR_WIDTH, col_bytes_trans, col_bytes)); - - for (size_t row_tile_start = 0; row_tile_start < num_rows; - row_tile_start += M_TILE_L1) { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; - col_tile_start_byte += N_TILE_L1) { - - const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - const int col_limit = - std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start + ii; - - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte + jj; - - const size_t logical_src_offset = row * col_bytes + col; - - if (row < row_limit && col < col_limit) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - for (int ii = 0; ii < M_TILE_L1; ++ii) { - for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { - std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); - } - } - } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - // Using M_TILE_L1 here is deliberate since we assume that the cache - // tile is square in the number of elements (not necessarily the - // number of bytes). 
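The fiddly part of the packed-int4 branch below is swapping two 4-bit elements that live at nibble offsets inside different bytes. A self-contained host version of that swap, applied to a tiny 4x4 tile; the packing (two elements per byte, low nibble first) matches the deleted code, while the tile contents are made up:

```cpp
#include <cstdio>
#include <cstdint>

// Swap packed-int4 elements (ii, jj) and (jj, ii) of a tile stored two
// elements per byte (low nibble = even element), as in the loop below.
void swap_int4(uint8_t* tile, int row_bytes, int ii, int jj) {
    uint8_t& a = tile[ii * row_bytes + jj / 2];
    uint8_t& b = tile[jj * row_bytes + ii / 2];
    const int a_shift = 4 * (jj % 2);
    const int b_shift = 4 * (ii % 2);
    const uint8_t a_elt = 0xF & (a >> a_shift);
    const uint8_t b_elt = 0xF & (b >> b_shift);
    a = static_cast<uint8_t>((a & ~(0xF << a_shift)) | (b_elt << a_shift));
    b = static_cast<uint8_t>((b & ~(0xF << b_shift)) | (a_elt << b_shift));
}

int main() {
    // 4x4 int4 tile holding 0..15 row-major, packed two per byte.
    uint8_t tile[4][2] = {{0x10, 0x32}, {0x54, 0x76}, {0x98, 0xBA}, {0xDC, 0xFE}};
    for (int i = 0; i < 4; ++i)
        for (int j = i + 1; j < 4; ++j)
            swap_int4(&tile[0][0], 2, i, j);
    for (auto& row : tile)
        printf("%02X %02X\n", row[0], row[1]);  // transposed: 40 C8 / 51 D9 / 62 EA / 73 FB
    return 0;
}
```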
- for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { - const int ii_byte = ii / ELTS_PER_BYTE; - const int ii_bit_offset = ii % ELTS_PER_BYTE; - - const int jj_byte = jj / ELTS_PER_BYTE; - const int jj_bit_offset = jj % ELTS_PER_BYTE; - - uint8_t src_elt = - 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = - 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } - } else { - FT_CHECK_WITH_INFO(false, "Unsupported quantization type."); - } - - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - - const int row_limit_trans = - std::min(row_tile_start_trans + M_TILE_L1, num_cols); - const int col_limit_trans = - std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte_trans + jj; - - const size_t logical_tgt_offset = row * col_bytes_trans + col; - - if (row < row_limit_trans && col < col_limit_trans) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } - } - } -} - -void subbyte_transpose(int8_t *transposed_quantized_tensor, - const int8_t *quantized_tensor, - const std::vector &shape, QuantType quant_type) { - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - subbyte_transpose_impl( - transposed_quantized_tensor, quantized_tensor, shape); - } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { - subbyte_transpose_impl( - transposed_quantized_tensor, quantized_tensor, shape); - } else { - FT_CHECK_WITH_INFO(false, "Invalid quant_tye"); - } -} - -void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor, - const size_t num_elts) { - for (size_t ii = 0; ii < num_elts; ++ii) { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to - // match the int4 layout. This has no performance benefit and is purely so - // that int4 and int8 have the same layout. Pictorially, this does the - // following: bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a " - "multiple of 4 for register relayout"); - for (size_t base = 0; base < num_elts; base += 4) { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } -} - -void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor, - const size_t num_elts) { - const size_t num_bytes = num_elts / 2; - - // Step 1 will be to transform all the int4s to unsigned in order to make the - // dequantize take as little instructions as possible in the CUDA code. 
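The loop that follows relies on a double-shift trick to sign-extend each packed 4-bit value before adding the +8 bias. A hypothetical equivalent written without the trick makes the intent explicit:

```cpp
// Hypothetical equivalent of the unpacking done in the loop below: extract a
// nibble, interpret it as a signed 4-bit value in [-8, 7], then bias by +8 so
// the stored value lands in [0, 15], which is cheaper to dequantize on the GPU.
#include <cstdint>

inline int unpack_signed_int4(uint8_t byte, int idx) {  // idx: 0 = low nibble, 1 = high nibble
    const int nibble = (byte >> (4 * idx)) & 0xF;
    return nibble >= 8 ? nibble - 16 : nibble;          // two's-complement sign extension
}

inline int bias_to_unsigned(int signed_val) {
    return signed_val + 8;                              // [-8, 7] -> [0, 15]
}
```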
- for (size_t ii = 0; ii < num_bytes; ++ii) { - int8_t transformed_packed_int4s = 0; - int8_t transformed_first_elt = - (int8_t(packed_int4_tensor[ii] << 4) >> 4) + - 8; // The double shift here is to ensure sign extension - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; - - FT_CHECK_WITH_INFO(transformed_first_elt >= 0 && - transformed_first_elt <= 15, - "Illegal result for int4 transform (first elt)"); - FT_CHECK_WITH_INFO(transformed_second_elt >= 0 && - transformed_second_elt <= 15, - "Illegal result for int4 transform (second elt)"); - - // We don't need to mask in these ops since everything should be in the - // range 0-15 - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to - // minimize the number of shift & logical instructions That are needed to - // extract the int4s in the GEMM main loop. Pictorially, the loop below will - // do the following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt - // occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt - // occupies 4 bits) - - FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a " - "multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t *register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { - const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - const int src_shift = 4 * src_idx; - const int dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - } - register_ptr[ii] = transformed_register; - } -} - -void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor, - const size_t num_elts, - QuantType quant_type) { - if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - add_bias_and_interleave_int8s_inplace(tensor, num_elts); - } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { - add_bias_and_interleave_int4s_inplace(tensor, num_elts); - } else { - FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); - } -} - -void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor, - const int8_t *quantized_tensor, - const std::vector &shape, - QuantType quant_type, - LayoutDetails details) { - // We only want to run this step for weight only quant. 
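The int4 register relayout described above can be verified with a short standalone check (hypothetical, not part of the sources): a register holding nibbles `elt_7 ... elt_0` is rearranged to `elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0`.

```cpp
// Hypothetical check of the int4 register relayout: 0x76543210 (elt_i stored
// in nibble i) becomes 0x75316420 after the src_idx remapping used above.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t in = 0x76543210u;
    uint32_t out = 0;
    for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
        const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
        out |= ((in >> (4 * src_idx)) & 0xF) << (4 * dest_idx);
    }
    std::printf("0x%08X\n", out);  // prints 0x75316420
    return 0;
}
```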
- FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || - quant_type == QuantType::INT8_WEIGHT_ONLY); - FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D"); - - const size_t num_rows = shape[0]; - const size_t num_cols = shape[1]; - - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int rows_per_tile = details.rows_per_column_tile; - - FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32), - fmtstr("The number of rows must be a multiple of %d but " - "the number of rows is %d.", - elts_in_int32, num_rows)); - - FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), - fmtstr("The number of columns must be a multiple of %d " - "but the number of columns is %ld", - rows_per_tile, num_cols)); - - const uint32_t *input_byte_ptr = - reinterpret_cast(quantized_tensor); - uint32_t *output_byte_ptr = - reinterpret_cast(interleaved_quantized_tensor); - - FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), - fmtstr("The number of columns must be a multiple of %d " - "but the number of columns is %d.", - rows_per_tile, num_cols)); - - const int num_vec_rows = num_rows / elts_in_int32; - const int vec_rows_per_tile = rows_per_tile / elts_in_int32; - const int interleave = details.columns_interleaved; - - for (size_t read_col = 0; read_col < num_cols; ++read_col) { - const auto write_col = read_col / interleave; - for (int base_vec_row = 0; base_vec_row < num_vec_rows; - base_vec_row += vec_rows_per_tile) { - for (int vec_read_row = base_vec_row; - vec_read_row < - std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); - ++vec_read_row) { - const int64_t vec_write_row = - interleave * base_vec_row + - vec_rows_per_tile * (read_col % interleave) + - vec_read_row % vec_rows_per_tile; - - const int64_t read_offset = - int64_t(read_col) * num_vec_rows + vec_read_row; - const int64_t write_offset = - int64_t(write_col) * num_vec_rows * interleave + vec_write_row; - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } - } -} - -void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight, - const int8_t *row_major_quantized_weight, - const std::vector &shape, - QuantType quant_type, int arch) { - LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); - - FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D"); - - size_t num_elts = 1; - for (const auto &dim : shape) { - num_elts *= dim; - } - - const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8; - - std::vector src_buf(num_bytes); - std::vector dst_buf(num_bytes); - std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); - - // Works on row major data, so issue this permutation first. 
- if (details.uses_imma_ldsm) { - permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); - src_buf.swap(dst_buf); - } - - if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { - subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); - src_buf.swap(dst_buf); - } - - if (details.columns_interleaved > 1) { - interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); - src_buf.swap(dst_buf); - } - - add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); - std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); -} - -void preprocess_weights(int8_t *preprocessed_quantized_weight, - const int8_t *row_major_quantized_weight, size_t rows, - size_t cols, bool is_int4, int arch) { - QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY - : QuantType::INT8_WEIGHT_ONLY; - preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight, - row_major_quantized_weight, {rows, cols}, - qtype, arch); -} - -/* - Arguments: - input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. - - quant_type - the type of the output quantization weight. - - This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the - zero-point is zero and will automatically construct the scales. - - It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is - viewed as a stack of matrices and a scale is produced for each column of every matrix. - -Outputs - processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM - unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. - scale_ptr - scales for the quantized weight. - - Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data - layout may not make sense if printed. - - Shapes: - quant_type == int8: - If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] - quant_type == int4: - If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape - [b,n] - - The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the - reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind - of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors - must have a dimension of 1, which breaks the semantics we need for batched weights. 
- */ - -template -void symmetric_quantize(int8_t* processed_quantized_weight, - int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, - const WeightType* input_weight_ptr, - const std::vector& shape, - QuantType quant_type) -{ - - FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); - FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL"); - FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL"); - - FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const int bits_in_type = get_bits_in_quant_type(quant_type); - const int bytes_per_out_col = num_cols * bits_in_type / 8; - - std::vector weight_buf; - if (unprocessed_quantized_weight == nullptr) { - weight_buf.resize(num_experts * num_rows * num_cols); - unprocessed_quantized_weight = weight_buf.data(); - } - - const int input_mat_size = num_rows * num_cols; - const int quantized_mat_size = num_rows * bytes_per_out_col; - const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); - - std::vector per_col_max(num_cols); - - for (int expert = 0; expert < num_experts; ++expert) { - const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; - int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; - - // First we find the per column max for this expert weight. - for (int jj = 0; jj < num_cols; ++jj) { - per_col_max[jj] = 0.f; - } - - for (int ii = 0; ii < num_rows; ++ii) { - const WeightType* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < num_cols; ++jj) { - per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); - } - } - - // Then, we construct the scales - ComputeType* current_scales = scale_ptr + expert * num_cols; - for (int jj = 0; jj < num_cols; ++jj) { - per_col_max[jj] *= quant_range_scale; - current_scales[jj] = ComputeType(per_col_max[jj]); - } - - // Finally, construct the weights. - for (int ii = 0; ii < num_rows; ++ii) { - int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; - const WeightType* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < bytes_per_out_col; ++jj) { - - if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - const float col_scale = per_col_max[jj]; - const float weight_elt = float(current_weight_row[jj]); - const float scaled_weight = round(weight_elt / col_scale); - const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); - current_quantized_weight_row[jj] = clipped_weight; - } - else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { - - // We will pack two int4 elements per iteration of the inner loop. - int8_t packed_int4s = 0; - for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { - const int input_idx = 2 * jj + packed_idx; - if (input_idx < num_cols) { - const float col_scale = per_col_max[input_idx]; - const float weight_elt = float(current_weight_row[input_idx]); - const float scaled_weight = round(weight_elt / col_scale); - int int_weight = int(scaled_weight); - const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); - - // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits - // if packing the second int4 and or the bits into the final result. 
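For reference, the per-column scale and clipping applied in this function reduce to a simple formula; a hypothetical scalar version (same math as above, without the packing) is:

```cpp
// Hypothetical scalar version of the per-column symmetric quantization above:
// the scale is the column's max |w| divided by the largest representable
// magnitude 2^(bits-1), and each weight becomes round(w / scale) clipped to
// the signed range. The zero-point is assumed to be zero throughout.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t quantize_symmetric(float w, float col_abs_max, int bits) {
    const float scale = col_abs_max / float(1 << (bits - 1));  // /128 for int8, /8 for int4
    const float q     = std::round(w / scale);
    const float lo    = -float(1 << (bits - 1));               // -128 or -8
    const float hi    = float((1 << (bits - 1)) - 1);          // 127 or 7
    return int8_t(std::max(lo, std::min(hi, q)));
}
```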
- packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); - } - } - current_quantized_weight_row[jj] = packed_int4s; - } - else { - FT_CHECK_WITH_INFO(false, "Unsupported quantization type"); - } - } - } - } - const int arch = fastertransformer::getSMVersion(); - preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch); -} - -template void -symmetric_quantize(int8_t*, int8_t*, half*, const float*, const std::vector&, QuantType); - -template void -symmetric_quantize(int8_t*, int8_t*, half*, const half*, const std::vector&, QuantType); - - -template -void symmetric_quantize(int8_t* processed_quantized_weight, - ComputeType* scale_ptr, - const WeightType* input_weight_ptr, - const std::vector& shape, - QuantType quant_type) -{ - symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); -} - -template void symmetric_quantize(int8_t*, float*, const float*, const std::vector&, QuantType); - -template void symmetric_quantize(int8_t*, half*, const float*, const std::vector&, QuantType); - -template void symmetric_quantize(int8_t*, half*, const half*, const std::vector&, QuantType); - -} // namespace fastertransformer diff --git a/cutlass_kernels/cutlass_preprocessors.h b/cutlass_kernels/cutlass_preprocessors.h deleted file mode 100644 index cd37d352025c0cd86e893afbff03343be209f127..0000000000000000000000000000000000000000 --- a/cutlass_kernels/cutlass_preprocessors.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#include -#include -#include - -namespace fastertransformer { - -enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY }; - -int get_bits_in_quant_type(QuantType quant_type); - -void preprocess_weights(int8_t *preprocessed_quantized_weight, - const int8_t *row_major_quantized_weight, size_t rows, - size_t cols, bool is_int4, int arch); - -template -void symmetric_quantize(int8_t* processed_quantized_weight, - ComputeType* scale_ptr, - const WeightType* input_weight_ptr, - const std::vector& shape, - QuantType quant_type); - - -template -void symmetric_quantize(int8_t* processed_quantized_weight, - int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, - const WeightType* input_weight_ptr, - const std::vector& shape, - QuantType quant_type); -} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm.cu b/cutlass_kernels/fpA_intB_gemm.cu deleted file mode 100644 index 5e75a8d45bbb81878c5a0877f7219bb2367bf4d5..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm.cu +++ /dev/null @@ -1,99 +0,0 @@ -#include "fpA_intB_gemm.h" -#include "fpA_intB_gemm/fpA_intB_gemm_template.h" - -namespace fastertransformer -{ - - ActivationType get_activation(const std::string &activation_name) - { - if (activation_name == "identity") - return ActivationType::Identity; - if (activation_name == "relu") - return ActivationType::Relu; - if (activation_name == "silu") - return ActivationType::Silu; - if (activation_name == "gelu") - return ActivationType::Gelu; - // todo: more - return ActivationType::InvalidType; - } - - void gemm_fp16_int(const half *A, - const uint8_t *B, - const half *weight_scales, - half *C, - int m, int n, int k, - char *workspace_ptr, - size_t workspace_bytes, - cudaStream_t stream) - { - CutlassFpAIntBGemmRunner runner; - runner.gemm(A, B, weight_scales, - C, m, n, k, workspace_ptr, workspace_bytes, stream); - } - - template - void gemm_fp16_int_bias_act(const half 
*A, - const WeightType *B, - const half *weight_scales, - const half *bias, - half *C, - std::optional activation, - int m, int n, int k, int bias_stride, char *workspace_ptr, - size_t workspace_bytes, cudaStream_t stream) - { - CutlassFpAIntBGemmRunner runner; - - if (!activation && bias == nullptr) - { - runner.gemm(A, B, weight_scales, - C, m, n, k, workspace_ptr, workspace_bytes, stream); - } - else if (!activation) - { - runner.gemm_bias_act(A, B, weight_scales, bias, - C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream); - } - else - { - runner.gemm_bias_act(A, B, weight_scales, bias, - C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream); - } - } - - template - void gemm_fp16_int_bias_act_residual( - const half *A, const WeightType *B, const half *weight_scales, - const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, - const std::string &unary_op, int m, int n, - int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream) - { - CutlassFpAIntBGemmRunner runner; - - runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual, - C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream); - } - - template void gemm_fp16_int_bias_act(const half *A, const uint4b_t *B, - const half *weight_scales, const half *bias, - half *C, std::optional activation, int m, - int n, int k, int bias_stride, char *workspace_ptr, - size_t workspace_bytes, cudaStream_t stream); - - template void gemm_fp16_int_bias_act_residual( - const half *A, const uint4b_t *B, const half *weight_scales, - const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, - const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); - - template void gemm_fp16_int_bias_act(const half *A, const uint8_t *B, - const half *weight_scales, const half *bias, - half *C, std::optional activation, int m, - int n, int k, int bias_stride, char *workspace_ptr, - size_t workspace_bytes, cudaStream_t stream); - - template void gemm_fp16_int_bias_act_residual( - const half *A, const uint8_t *B, const half *weight_scales, - const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op, - const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); - -} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm.h b/cutlass_kernels/fpA_intB_gemm.h deleted file mode 100644 index fcc24e3ea7958fea47e628a973c3a811a82b26a3..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -#include -#include "cutlass/numeric_types.h" -#include "cutlass/half.h" -#include "cutlass/integer_subbyte.h" - -namespace fastertransformer { - -using half = cutlass::half_t; -using uint4b_t = cutlass::uint4b_t; - -// TODO: Support more general bias shape - -// base gemm -void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales, - half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); - -template -void gemm_fp16_int_bias_act(const half *A, const WeightType *B, - const half *weight_scales, const half *bias, - half *C, std::optional activation, int m, - int n, int k, int bias_stride, char *workspace_ptr, - size_t 
workspace_bytes, cudaStream_t stream); - -template -void gemm_fp16_int_bias_act_residual( - const half *A, const WeightType *B, const half *weight_scales, - const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op, - const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream); - - -} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h deleted file mode 100644 index a921a4c367ab138d15680fb99ce69afbc9973a6b..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h" -#include "utils/activation_types.h" -#include - -namespace fastertransformer { - -/* - This runner only supports: - T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} - - Activations, biases, scales and outputs are all assumed to be row-major. - - However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. - In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor - will instantiate the layout and preprocess based on the instantiation, so layout changes should only require - modifications to mix_gemm_B_layout.h. -*/ - -template -class CutlassFpAIntBGemmRunner { -public: - CutlassFpAIntBGemmRunner(); - ~CutlassFpAIntBGemmRunner(); - - void gemm(const T* A, - const WeightType* B, - const T* weight_scales, - T* C, - int m, - int n, - int k, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); - - void gemm_bias_act(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - int bias_stride, - ActivationType activation_type, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); - - void gemm_bias_act_residual(const T *A, const WeightType *B, - const T *weight_scales, const T *biases, - const T *residual, T *C, int m, int n, int k, - const std::string& activation, const std::string& binary_op, - const std::string& unary_op, - char *workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); - - // Returns desired workspace size in bytes. 
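For context, a hypothetical caller of this runner might look as follows. This sketch assumes the class is templated on the activation and weight types (`<T, WeightType>`, as in the upstream FasterTransformer code) and that B has already been preprocessed into the mixed-GEMM layout; all pointers are device memory.

```cpp
// Hypothetical usage sketch of CutlassFpAIntBGemmRunner for int8 weight-only
// GEMM with fp16 activations. fastertransformer::half is cutlass::half_t per
// the header above; the include path mirrors the wrapper sources.
#include <cuda_runtime.h>
#include "fpA_intB_gemm.h"

void example_w8a16_gemm(const fastertransformer::half* A, const uint8_t* B,
                        const fastertransformer::half* scales,
                        fastertransformer::half* C, int m, int n, int k,
                        char* workspace, size_t workspace_bytes,
                        cudaStream_t stream) {
    fastertransformer::CutlassFpAIntBGemmRunner<fastertransformer::half, uint8_t> runner;
    // Workspace can be sized with runner.getWorkspaceSize(m, n, k); the wrapper
    // in this repo passes nullptr/0 when no split-k workspace is needed.
    runner.gemm(A, B, scales, C, m, n, k, workspace, workspace_bytes, stream);
}
```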
- int getWorkspaceSize(const int m, const int n, const int k); - -private: - template - void dispatch_to_arch(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy = nullptr); - - template - void run_gemm(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - int bias_stride, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); - -private: - static constexpr int split_k_limit = 7; - - int sm_; - int multi_processor_count_; -}; - -} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h deleted file mode 100644 index b7d7f0d652af70d5b7a79f61babea87806bb526f..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ /dev/null @@ -1,858 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" - -#include "cutlass/gemm/device/gemm_universal_base.h" -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" -#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -#include "cutlass_extensions/compute_occupancy.h" - -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/ft_gemm_configs.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" -#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#pragma GCC diagnostic pop - -#include "../cutlass_heuristic.h" -#include "fpA_intB_gemm.h" -#include "cuda_utils.h" - -namespace fastertransformer { - - template - void generic_mixed_gemm_kernelLauncher(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char *workspace, - size_t workspace_bytes, - cudaStream_t stream, - int *occupancy = nullptr) - { - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); - - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, - ""); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
- using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; - using ElementType = ElementType_; - - using CutlassWeightType_ = typename cutlass::platform:: - conditional::value, cutlass::half_t, WeightType>::type; - using CutlassWeightType = CutlassWeightType_; - - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = - typename Epilogue::Op; - - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - ElementType, - cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassWeightType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, - ElementType, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - Stages, - true, - typename MixedGemmArchTraits::Operator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; - - if (occupancy != nullptr) - { - *occupancy = compute_occupancy_for_kernel(); - return; - } - - using Gemm = cutlass::gemm::device::GemmUniversalBase; - - const int ldb = - cutlass::platform::is_same::value ? n : k * GemmKernel::kInterleave; - - typename Gemm::Arguments args({m, n, k}, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), 0}, - // TODO: Support more general bias shape - {reinterpret_cast(const_cast(biases)), bias_stride}, - {reinterpret_cast(C), n}, - gemm_config.split_k_factor, - {ElementAccumulator(1.f), ElementAccumulator(0.f)}); - - // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of - // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the - // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write - // our own predicated iterator in order to relax this limitation. - if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) - { - throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); - } - - Gemm gemm; - if (gemm.get_workspace_size(args) > workspace_bytes) - { - FT_LOG_WARNING( - "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); - // If requested split-k factor will require more workspace bytes, revert to standard gemm. - args.batch_count = 1; - } - - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) - { - std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } - - auto init_status = gemm.initialize(args, workspace, stream); - if (init_status != cutlass::Status::kSuccess) - { - std::string err_msg = - "Failed to initialize cutlass fpA_intB gemm. 
Error: " + std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } - - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) - { - std::string err_msg = - "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } -} - -template -struct dispatch_stages { - static void dispatch(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char *workspace, - size_t workspace_bytes, - cudaStream_t stream, - int *occupancy = nullptr) - { - - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); - } -}; - -template -struct dispatch_stages { - static void dispatch(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char *workspace, - size_t workspace_bytes, - cudaStream_t stream, - int *occupancy = nullptr) - { - - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - generic_mixed_gemm_kernelLauncher( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - } -}; - -template -struct dispatch_stages 2)>::type> { - static void dispatch(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char *workspace, - size_t workspace_bytes, - cudaStream_t stream, - int *occupancy = nullptr) - { - - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - generic_mixed_gemm_kernelLauncher( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - } -}; - -template -void dispatch_gemm_config(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char *workspace, - size_t workspace_bytes, - cudaStream_t stream, - int *occupancy = nullptr) -{ - - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - switch (gemm_config.stages) { - case 2: - using DispatcherStages2 = dispatch_stages; - DispatcherStages2::dispatch( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case 3: - using DispatcherStages3 = dispatch_stages; - DispatcherStages3::dispatch( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case 4: - using DispatcherStages4 = dispatch_stages; - DispatcherStages4::dispatch( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - default: - std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); - throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg); - break; - } -} - -template -void dispatch_gemm_to_cutlass(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - char *workspace, - size_t 
workspace_bytes, - CutlassGemmConfig gemm_config, - cudaStream_t stream, - int *occupancy = nullptr) -{ - - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - - // Note that SIMT configs are omitted here since they are not supported for fpA_intB. - // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best - // for mixed type gemms. - switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>( - A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy); - break; - case CutlassTileConfig::Undefined: - throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); - break; - } -} - -template -CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - sm_ = getSMVersion(); - check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -template -void CutlassFpAIntBGemmRunner::dispatch_to_arch(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - int bias_stride, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - if (sm_ >= 70 && sm_ < 75) { - dispatch_gemm_to_cutlass( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); - } else if (sm_ >= 75 && sm_ < 80) { - dispatch_gemm_to_cutlass( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { - dispatch_gemm_to_cutlass( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); - } - else { - throw std::runtime_error( - "[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM"); - } -} - -template -template -void CutlassFpAIntBGemmRunner::run_gemm(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - int bias_stride, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - static constexpr bool is_weight_only = !std::is_same::value; - std::vector 
candidate_configs = get_candidate_configs(sm_, is_weight_only, false); - std::vector occupancies(candidate_configs.size()); - - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - bias_stride, - candidate_configs[ii], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[ii]); - } - // Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN. - static constexpr int num_experts = 1; - CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, - occupancies, - m, - n, - k, - num_experts, - split_k_limit, - workspace_bytes, - multi_processor_count_, - is_weight_only); - - dispatch_to_arch( - A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream); -} - -template -void CutlassFpAIntBGemmRunner::gemm_bias_act(const T *A, - const WeightType *B, - const T *weight_scales, - const T *biases, - T *C, - int m, - int n, - int k, - int bias_stride, - ActivationType activation_type, - char *workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - - switch (activation_type) { - case ActivationType::Relu: - run_gemm( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); - break; - case ActivationType::Gelu: - run_gemm( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); - break; - case ActivationType::Silu: - run_gemm( - A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); - break; - case ActivationType::Identity: - run_gemm(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream); - break; - case ActivationType::InvalidType: - FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid."); - break; - default: { - if (isGatedActivation(activation_type)) { - FT_CHECK_WITH_INFO(false, "Fused gated activations not supported"); - } - else { - FT_CHECK_WITH_INFO(false, "Invalid activation type."); - } - } - } -} - -template -void CutlassFpAIntBGemmRunner::gemm(const T* A, - const WeightType* B, - const T* weight_scales, - T* C, - int m, - int n, - int k, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - run_gemm(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream); -} - -template -void dispatch_gemm_residual(const T *A, const WeightType *B, - const T *weight_scales, const T *biases, - const T *residual, T *C, int m, int n, int k, - char *workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream) { - using ElementType = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, cutlass::half_t, T>::type; - using ElementOutput = ElementType; - - using MixedGemmArchTraits = - cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename EpilogueOp::ElementAccumulator; - - using Swizzle = - typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - using InstructionShape = typename MixedGemmArchTraits::InstructionShape; - - using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast< - ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, - MixedGemmArchTraits::ElementsPerAccessA, WeightType, - typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, - 
MixedGemmArchTraits::ElementsPerAccessB, ElementType, - cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, - InstructionShape, EpilogueOp, Swizzle, stages, - typename MixedGemmArchTraits::Operator>::Epilogue; - - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - ElementType, cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, WeightType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, ElementType, - cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape, - InstructionShape, EpilogueOp, Swizzle, stages, true, - typename MixedGemmArchTraits::Operator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast< - typename GemmKernel_::Mma, Epilogue, - typename GemmKernel_::ThreadblockSwizzle, Arch>; - - using Gemm = cutlass::gemm::device::GemmUniversalBase; - - // TODO: Support batch - const int batch_count = 1; - const auto lda = k; - const int ldb = - cutlass::platform::is_same::value - ? n - : k * GemmKernel::kInterleave; - const int ldc = n; - - typename Gemm::Arguments args( - {m, n, k}, batch_count, - {ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales, - residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0); - - if (GemmKernel::kInterleave > 1 && - ((k % MixedGemmArchTraits::ThreadblockK) || - (k % MixedGemmArchTraits::ThreadblockK))) { - throw std::runtime_error( - "Temp assertion: k must be multiple of threadblockK"); - } - - Gemm gemm; - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) { - std::string err_msg = - "fpA_intB cutlass kernel will fail for params. Error: " + - std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } - - auto init_status = gemm.initialize(args, workspace_ptr, stream); - if (init_status != cutlass::Status::kSuccess) { - std::string err_msg = - "Failed to initialize cutlass fpA_intB gemm. Error: " + - std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } - - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to run cutlass fpA_intB gemm. 
Error: " + - std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg); - } -} - -template -void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A, - const WeightType *B, const T *weight_scales, - const T *biases, const T *residual, T *C, int m, - int n, int k, char *workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream) { - if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) { - dispatch_gemm_residual< - T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>, - cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>( - A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, - workspace_bytes, stream); - } else if (tile_config == - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) { - dispatch_gemm_residual< - T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>, - cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>( - A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, - workspace_bytes, stream); - } else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_residual< - T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>, - cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>( - A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr, - workspace_bytes, stream); - } -} - -template -void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, - const WeightType *B, const T *weight_scales, - const T *biases, const T *residual, T *C, int m, - int n, int k, char *workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream) { - if constexpr (std::is_same::value) { - dispatch_gemm_residual( - config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else if constexpr (std::is_same::value) { - dispatch_gemm_residual( - config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else { - if (config.stages == 3) { - dispatch_gemm_residual( - config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else if (config.stages == 4) { - dispatch_gemm_residual( - config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else { // 2 - dispatch_gemm_residual( - config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } - } -} - -template class ActivationOp, - template class BinaryOp> -inline void -dispatch_gemm_residual(CutlassGemmConfig config, const T *A, - const WeightType *B, const T *weight_scales, - const T *biases, const T *residual, T *C, int m, int n, - int k, const std::string &unary_op, char *workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream) { - using ElementOutput = T; - using MixedGemmArchTraits = - cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - if (unary_op == "identity") { - using EpilogueOp = - cutlass::epilogue::thread::LinearCombinationResidualBlock< - ElementOutput, ElementAccumulator, ElementAccumulator, - ElementOutput, 128 / cutlass::sizeof_bits::value, - ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>; - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else if (unary_op == "relu") { - using 
EpilogueOp = - cutlass::epilogue::thread::LinearCombinationResidualBlock< - ElementOutput, ElementAccumulator, ElementAccumulator, - ElementOutput, 128 / cutlass::sizeof_bits::value, - ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>; - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, - workspace_ptr, workspace_bytes, stream); - } else { - throw std::runtime_error( - "[FT Error][Unsupported unary op after residual block] " + unary_op); - } -} - -template class ActivationOp> -void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, - const WeightType *B, const T *weight_scales, - const T *biases, const T *residual, T *C, int m, - int n, int k, const std::string &binary_op, - const std::string &unary_op, char *workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream) { - if (binary_op == "plus") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, - workspace_ptr, workspace_bytes, stream); - } else if (binary_op == "multiply") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op, - workspace_ptr, workspace_bytes, stream); - } else { - throw std::runtime_error( - "[FT Error][Unsupported binary op for residual block] " + binary_op); - } -} - -template -void dispatch_gemm_residual(CutlassGemmConfig config, const T *A, - const WeightType *B, const T *weight_scales, - const T *biases, const T *residual, T *C, int m, - int n, int k, const std::string &activation, - const std::string &binary_op, - const std::string &unary_op, char *workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream) { - if (activation == "identity") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, - unary_op, workspace_ptr, workspace_bytes, stream); - } else if ("silu") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, - unary_op, workspace_ptr, workspace_bytes, stream); - } else if ("relu") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, - unary_op, workspace_ptr, workspace_bytes, stream); - } else if ("gelu") { - dispatch_gemm_residual( - config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op, - unary_op, workspace_ptr, workspace_bytes, stream); - } else { - throw std::runtime_error( - "[FT Error][Unsupported activation before residual binary op] " + - activation); - } -} - -template -void CutlassFpAIntBGemmRunner::gemm_bias_act_residual( - const T *A, const WeightType *B, const T *weight_scales, const T *biases, - const T *residual, T *C, int m, int n, int k, const std::string &activation, - const std::string &binary_op, const std::string &unary_op, - char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { - - std::vector candidate_configs = - get_candidate_configs(sm_, true, false); - std::vector occupancies(candidate_configs.size()); - - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch( - A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii], - workspace_ptr, workspace_bytes, stream, &occupancies[ii]); - } - - CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies( - candidate_configs, occupancies, m, n, k, 1, split_k_limit, - workspace_bytes, multi_processor_count_, true); - - if (sm_ >= 80 && sm_ < 90) { - dispatch_gemm_residual( - chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, - activation, 
binary_op, unary_op, workspace_ptr, workspace_bytes, - stream); - } else if (sm_ >= 75 && sm_ < 80) { - dispatch_gemm_residual( - chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, - activation, binary_op, unary_op, workspace_ptr, workspace_bytes, - stream); - } else if (sm_ == 70) { - dispatch_gemm_residual( - chosen_config, A, B, weight_scales, biases, residual, C, m, n, k, - activation, binary_op, unary_op, workspace_ptr, workspace_bytes, - stream); - } else { - throw std::runtime_error("[FT Error][Unsupported SM] " + sm_); - } -} - -template -int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, const int n, const int k) -{ - FT_LOG_DEBUG(__PRETTY_FUNCTION__); - // TODO(masahi): Shouldn't it be 0? - - // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int max_grid_m = (m + 31) / 32; - const int max_grid_n = (n + 127) / 128; - // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. - return max_grid_m * max_grid_n * split_k_limit * 4; -} - -} // namespace fastertransformer diff --git a/cutlass_kernels/fpA_intB_gemm_wrapper.cu b/cutlass_kernels/fpA_intB_gemm_wrapper.cu deleted file mode 100644 index aed1c3a02011c5ada7aad82885a2f776f32dc6f8..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm_wrapper.cu +++ /dev/null @@ -1,201 +0,0 @@ -#include -#include "cub/cub.cuh" -#include -#include -#include -#include "fpA_intB_gemm_wrapper.h" -#include "fpA_intB_gemm.h" -#include "cutlass_preprocessors.h" -#include "cuda_utils.h" -#include "weightOnlyBatchedGemv/enabled.h" -#include "weightOnlyBatchedGemv/kernelLauncher.h" -#include "torch_utils.h" - -#include - -namespace ft = fastertransformer; - -int getWorkspaceSize(const int m, const int n, const int k) -{ - // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int max_grid_m = (m + 31) / 32; - const int max_grid_n = (n + 127) / 128; - const int split_k_limit = 7; - // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. - return max_grid_m * max_grid_n * split_k_limit * 4; -} - -std::vector -symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, - at::ScalarType quant_type, - bool return_unprocessed_quantized_tensor) -{ - CHECK_CPU(weight); - CHECK_CONTIGUOUS(weight); - TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor"); - TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3"); - - auto _st = weight.scalar_type(); - TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32"); - TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization"); - ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type); - - const size_t num_experts = weight.dim() == 2 ? 
1 : weight.size(0); - const size_t num_rows = weight.size(-2); - const size_t num_cols = weight.size(-1); - - const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type); - const size_t bytes_per_out_col = num_cols * bits_in_type / 8; - - const size_t input_mat_size = num_rows * num_cols; - const size_t quantized_mat_size = num_rows * bytes_per_out_col; - - std::vector quantized_weight_shape; - std::vector scale_shape; - if (weight.dim() == 2) { - quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)}; - scale_shape = {long(num_cols)}; - } - else if (weight.dim() == 3) { - quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)}; - scale_shape = {long(num_experts), long(num_cols)}; - } - else { - TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3"); - } - - torch::Tensor unprocessed_quantized_weight = - torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false)); - - torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight); - - torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false)); - - int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast(unprocessed_quantized_weight.data_ptr()); - int8_t *processed_quantized_weight_ptr = reinterpret_cast(processed_quantized_weight.data_ptr()); - - if (weight.scalar_type() == at::ScalarType::Float) - { - ft::symmetric_quantize(processed_quantized_weight_ptr, - unprocessed_quantized_weight_ptr, - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(weight.data_ptr()), - {num_rows, num_cols}, - ft_quant_type); - } - else if (weight.scalar_type() == at::ScalarType::Half) - { - ft::symmetric_quantize(processed_quantized_weight_ptr, - unprocessed_quantized_weight_ptr, - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(weight.data_ptr()), - {num_rows, num_cols}, - ft_quant_type); - } - else - { - TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16"); - } - - if (return_unprocessed_quantized_tensor) - { - return std::vector{unprocessed_quantized_weight, processed_quantized_weight, scales}; - } - - return std::vector{processed_quantized_weight, scales}; -} - -torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight, - bool is_int4) -{ - // guarantee the weight is cpu tensor - CHECK_CPU(origin_weight); - - torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight); - int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast(preprocessed_quantized_weight.data_ptr()); - const int8_t *row_major_quantized_weight_ptr = reinterpret_cast(origin_weight.data_ptr()); - size_t rows = origin_weight.size(-2); - size_t cols = origin_weight.size(-1); - int arch = ft::getSMVersion(); - ft::preprocess_weights(preprocessed_quantized_weight_ptr, - row_major_quantized_weight_ptr, - rows, - cols, - is_int4, - arch); - return preprocessed_quantized_weight; -} - -torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, - torch::Tensor const &weight, - torch::Tensor const &scale) -{ - c10::cuda::CUDAGuard device_guard(input.device()); - // TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim()); - const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1); - const int k = input.size(-1); - const int n = weight.size(-1); - auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); - torch::Tensor output = input.dim() == 2 ? 
torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options); - const ft::half *input_ptr = reinterpret_cast(input.data_ptr()); - const uint8_t *weight_ptr = reinterpret_cast(weight.data_ptr()); - const ft::half *scale_ptr = reinterpret_cast(scale.data_ptr()); - ft::half *output_ptr = reinterpret_cast(output.data_ptr()); - // const int max_size = std::max(n, k); - // size_t workspace_size = getWorkspaceSize(m, max_size, max_size); - // void *ptr = nullptr; - // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr; - const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH; - // const bool use_cuda_kernel = false; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if(use_cuda_kernel){ - tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16; - tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b; - tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast(scale.data_ptr()), nullptr, - reinterpret_cast(input.data_ptr()), nullptr, nullptr, reinterpret_cast(output.data_ptr()), m, n, k, 0, weight_only_quant_type, - tensorrt_llm::kernels::WeightOnlyType::PerChannel, - tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; - tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream); - } - else - ft::gemm_fp16_int( - input_ptr, - weight_ptr, - scale_ptr, - output_ptr, - m, n, k, - nullptr, - 0, - stream); - return output; -} - - -torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, - torch::Tensor const &weight, - torch::Tensor const &scale, - torch::Tensor &output, - const int64_t m, - const int64_t n, - const int64_t k) -{ - c10::cuda::CUDAGuard device_guard(input.device()); - - const ft::half *input_ptr = reinterpret_cast(input.data_ptr()); - const uint8_t *weight_ptr = reinterpret_cast(weight.data_ptr()); - const ft::half *scale_ptr = reinterpret_cast(scale.data_ptr()); - ft::half *output_ptr = reinterpret_cast(output.data_ptr()); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - ft::gemm_fp16_int( - input_ptr, - weight_ptr, - scale_ptr, - output_ptr, - m, n, k, - nullptr, - 0, - stream); - return output; -} diff --git a/cutlass_kernels/fpA_intB_gemm_wrapper.h b/cutlass_kernels/fpA_intB_gemm_wrapper.h deleted file mode 100644 index a53d89e7413589be942f6dacdd5c3944526f110c..0000000000000000000000000000000000000000 --- a/cutlass_kernels/fpA_intB_gemm_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -#include -#include - -#define SMALL_M_FAST_PATH 4 -std::vector -symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, - at::ScalarType quant_type, - bool return_unprocessed_quantized_tensor); - -torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight, - bool is_int4); - -torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, - torch::Tensor const &weight, - torch::Tensor const &scale); - -torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, - torch::Tensor const &weight, - torch::Tensor const &scale, - torch::Tensor &output, - const int64_t m, - const int64_t n, - const int64_t k); diff --git a/flake.lock b/flake.lock deleted file mode 100644 index 3183af1b7105e56865a4c91101f0d6c77e65dbd4..0000000000000000000000000000000000000000 --- a/flake.lock +++ /dev/null @@ -1,169 +0,0 @@ -{ - "nodes": { - "flake-compat": { - "locked": { - 
"lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1733328505, - "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1753354560, - "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, - "kernel-builder": { - "inputs": { - "flake-compat": "flake-compat", - "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] - }, - "locked": { - "lastModified": 1753602110, - "narHash": "sha256-AEt6rSqYqSTgsKZ+2BuGezurpVC2gm+Jpjqg2D54n7E=", - "owner": "huggingface", - "repo": "kernel-builder", - "rev": "2021ea0f8d9e63ada986f189d077fc301cc1c3c9", - "type": "github" - }, - "original": { - "owner": "huggingface", - "ref": "torch-2.8", - "repo": "kernel-builder", - "type": "github" - } - }, - "nixpkgs": { - "locked": { - "lastModified": 1752785354, - "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=", - "owner": "nixos", - "repo": "nixpkgs", - "rev": "d38025438a6ee456758dc03188ca6873a415463b", - "type": "github" - }, - "original": { - "owner": "nixos", - "repo": "nixpkgs", - "rev": "d38025438a6ee456758dc03188ca6873a415463b", - "type": "github" - } - }, - "root": { - "inputs": { - "kernel-builder": "kernel-builder" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - 
"repo": "default", - "type": "github" - } - } - }, - "root": "root", - "version": 7 -} diff --git a/flake.nix b/flake.nix deleted file mode 100644 index 9f5ba1f1418aaf002bbdc429ca330add999cf488..0000000000000000000000000000000000000000 --- a/flake.nix +++ /dev/null @@ -1,17 +0,0 @@ -{ - description = "Flake for EETQ kernels"; - - inputs = { - kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8"; - }; - - outputs = - { - self, - kernel-builder, - }: - kernel-builder.lib.genFlakeOutputs { - path = ./.; - rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; - }; -} diff --git a/torch-ext/quantization_eetq/__init__.py b/torch-ext/quantization_eetq/__init__.py deleted file mode 100644 index c65d0601c655d7acf1a12e61b6549618b46a70d7..0000000000000000000000000000000000000000 --- a/torch-ext/quantization_eetq/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights - -__all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"] diff --git a/torch-ext/quantization_eetq/custom_ops.py b/torch-ext/quantization_eetq/custom_ops.py deleted file mode 100644 index 005b5a6e3cd5f7bcfd4aa5d7d80d60a5ed9fab88..0000000000000000000000000000000000000000 --- a/torch-ext/quantization_eetq/custom_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import List -import torch - -from ._ops import ops - - -def w8_a16_gemm( - input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor -) -> torch.Tensor: - return ops.w8_a16_gemm(input, weight, scale) - - -def w8_a16_gemm_( - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - output: torch.Tensor, - m: int, - n: int, - k: int, -) -> torch.Tensor: - return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k) - - -def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor: - return ops.preprocess_weights(origin_weight, is_int4) - - -def quant_weights( - origin_weight: torch.Tensor, - quant_type: torch.dtype, - return_unprocessed_quantized_tensor: bool, -) -> List[torch.Tensor]: - return ops.quant_weights( - origin_weight, quant_type, return_unprocessed_quantized_tensor - ) diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp deleted file mode 100644 index 39e3c7a3d4dbae63904cf0232a4aaf437b893b98..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include - -#include "registration.h" -#include "torch_binding.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor"); - ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda); - ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! 
output," - "int m, int n, int k) -> Tensor"); - ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_); - ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor"); - ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda); - ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type," - "bool return_unprocessed_quantized_tensor) -> Tensor[]"); - ops.impl("quant_weights", torch::kCPU, &symmetric_quantize_last_axis_of_tensor); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h deleted file mode 100644 index 0af398f24b6313742dc24ef8c5dedaa4b91fc06f..0000000000000000000000000000000000000000 --- a/torch-ext/torch_binding.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include - -#include - -std::vector -symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, - at::ScalarType quant_type, - bool return_unprocessed_quantized_tensor); - -torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight, - bool is_int4); - -torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, - torch::Tensor const&weight, - torch::Tensor const &scale); - -torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, - torch::Tensor const &weight, - torch::Tensor const &scale, - torch::Tensor &output, - const int64_t m, - const int64_t n, - const int64_t k); diff --git a/utils/activation_types.h b/utils/activation_types.h deleted file mode 100644 index cd90d71f688fe2af04c84ee0ac2328df677e7e72..0000000000000000000000000000000000000000 --- a/utils/activation_types.h +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_utils.h" - -namespace fastertransformer { - -enum class ActivationType { - Gelu, - Relu, - Silu, - GeGLU, - ReGLU, - SiGLU, - Identity, - InvalidType -}; - -inline bool isGatedActivation(ActivationType activaiton_type) -{ - return activaiton_type == ActivationType::GeGLU || activaiton_type == ActivationType::ReGLU - || activaiton_type == ActivationType::SiGLU; -} - -} // namespace fastertransformer diff --git a/utils/cuda_utils.cc b/utils/cuda_utils.cc deleted file mode 100644 index 2f36f1053914fadff4ac5b502ae8863189c46e09..0000000000000000000000000000000000000000 --- a/utils/cuda_utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cuda_utils.h" - -namespace fastertransformer { - -/* ***************************** common utils ****************************** */ - -cudaError_t getSetDevice(int i_device, int* o_device) -{ - int current_dev_id = 0; - cudaError_t err = cudaSuccess; - - if (o_device != NULL) { - err = cudaGetDevice(¤t_dev_id); - if (err != cudaSuccess) { - return err; - } - if (current_dev_id == i_device) { - *o_device = i_device; - } - else { - err = cudaSetDevice(i_device); - if (err != cudaSuccess) { - return err; - } - *o_device = current_dev_id; - } - } - else { - err = cudaSetDevice(i_device); - if (err != cudaSuccess) { - return err; - } - } - - return cudaSuccess; -} - -/* ************************** end of common utils ************************** */ -} // namespace fastertransformer diff --git a/utils/cuda_utils.h b/utils/cuda_utils.h deleted file mode 100644 index 2f75300d609c840c3ce9319c43c714682e873e02..0000000000000000000000000000000000000000 --- a/utils/cuda_utils.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "logger.h" - -#include -#include -#include -#include -#include - -namespace fastertransformer { -/* **************************** debug tools ********************************* */ -template -void check(T result, char const* const func, const char* const file, int const line) -{ - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("") + " " - + file + ":" + std::to_string(line) + " \n"); - } -} - -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) - -[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") -{ - throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":" - + std::to_string(line) + " \n"); -} - -inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "") -{ - if (!result) { - throwRuntimeError(file, line, info); - } -} - -#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) -#define FT_CHECK_WITH_INFO(val, info) \ - do { \ - bool is_valid_val = (val); \ - if (!is_valid_val) { \ - fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ - } \ - } while (0) - -/* ***************************** common utils ****************************** */ -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; -} - -cudaError_t getSetDevice(int i_device, int* o_device = NULL); -/* ************************** end of 
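A brief usage note for the helpers in this header (`getSMVersion`, `check_cuda_error`, `FT_CHECK_WITH_INFO`): the snippet below shows the typical pattern; the SM 7.0 floor is an illustrative assumption rather than a statement about every kernel in this tree.

```cpp
#include <cstdio>
#include <cuda_runtime.h>
#include "cuda_utils.h"   // utils/cuda_utils.h in this repository

using namespace fastertransformer;

void report_arch() {
    int device = -1;
    check_cuda_error(cudaGetDevice(&device));   // throws std::runtime_error on failure
    const int sm = getSMVersion();              // major * 10 + minor, e.g. 80 on A100
    FT_CHECK_WITH_INFO(sm >= 70, "these kernels assume an SM 7.0+ GPU");
    std::printf("device %d reports SM%d\n", device, sm);
}
```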
common utils ************************** */ -} // namespace fastertransformer diff --git a/utils/logger.cc b/utils/logger.cc deleted file mode 100644 index 764d245927e22d5939ca05f569ba75c50b1f49c5..0000000000000000000000000000000000000000 --- a/utils/logger.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "logger.h" -#include - -namespace fastertransformer { - -Logger::Logger() -{ - char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY"); - bool is_first_rank_only = - (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false; - - int device_id; - cudaGetDevice(&device_id); - - char* level_name = std::getenv("FT_LOG_LEVEL"); - if (level_name != nullptr) { - std::map name_to_level = { - {"TRACE", TRACE}, - {"DEBUG", DEBUG}, - {"INFO", INFO}, - {"WARNING", WARNING}, - {"ERROR", ERROR}, - }; - auto level = name_to_level.find(level_name); - // If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR - if (is_first_rank_only && device_id != 0) { - level = name_to_level.find("ERROR"); - } - if (level != name_to_level.end()) { - setLevel(level->second); - } - else { - fprintf(stderr, - "[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. " - "Ignore the environment variable and use a default " - "logging level.\n", - level_name); - level_name = nullptr; - } - } -} - -} // namespace fastertransformer diff --git a/utils/logger.h b/utils/logger.h deleted file mode 100644 index a93dc0d5fcd94b5b568a99a35d9953855df81ce4..0000000000000000000000000000000000000000 --- a/utils/logger.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include "string_utils.h" - -namespace fastertransformer { - -class Logger { - -public: - enum Level { - TRACE = 0, - DEBUG = 10, - INFO = 20, - WARNING = 30, - ERROR = 40 - }; - - static Logger& getLogger() - { - thread_local Logger instance; - return instance; - } - Logger(Logger const&) = delete; - void operator=(Logger const&) = delete; - - template - void log(const Level level, const std::string format, const Args&... args) - { - if (level_ <= level) { - std::string fmt = getPrefix(level) + format + "\n"; - FILE* out = level_ < WARNING ? 
stdout : stderr; - std::string logstr = fmtstr(fmt, args...); - fprintf(out, "%s", logstr.c_str()); - } - } - - template - void log(const Level level, const int rank, const std::string format, const Args&... args) - { - if (level_ <= level) { - std::string fmt = getPrefix(level, rank) + format + "\n"; - FILE* out = level_ < WARNING ? stdout : stderr; - std::string logstr = fmtstr(fmt, args...); - fprintf(out, "%s", logstr.c_str()); - } - } - - void setLevel(const Level level) - { - level_ = level; - log(INFO, "Set logger level by %s", getLevelName(level).c_str()); - } - - int getLevel() const - { - return level_; - } - -private: - const std::string PREFIX = "[FT]"; - const std::map level_name_ = { - {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}}; - -#ifndef NDEBUG - const Level DEFAULT_LOG_LEVEL = DEBUG; -#else - const Level DEFAULT_LOG_LEVEL = INFO; -#endif - Level level_ = DEFAULT_LOG_LEVEL; - - Logger(); - - inline const std::string getLevelName(const Level level) - { - return level_name_.at(level); - } - - inline const std::string getPrefix(const Level level) - { - return PREFIX + "[" + getLevelName(level) + "] "; - } - - inline const std::string getPrefix(const Level level, const int rank) - { - return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] "; - } -}; - -#define FT_LOG(level, ...) \ - do { \ - if (fastertransformer::Logger::getLogger().getLevel() <= level) { \ - fastertransformer::Logger::getLogger().log(level, __VA_ARGS__); \ - } \ - } while (0) - -#define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__) -#define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__) -#define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__) -#define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__) -#define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__) -} // namespace fastertransformer diff --git a/utils/string_utils.h b/utils/string_utils.h deleted file mode 100644 index ad7b5a0592f504b4a97fbadf6deae7a7bdffacd8..0000000000000000000000000000000000000000 --- a/utils/string_utils.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // std::make_unique -#include // std::stringstream -#include -#include - -namespace fastertransformer { - -template -inline std::string fmtstr(const std::string& format, Args... args) -{ - // This function came from a code snippet in stackoverflow under cc-by-1.0 - // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf - - // Disable format-security warning in this function. 
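A short, assumed usage example for the logger defined above: the level comes from the `FT_LOG_LEVEL` environment variable (`TRACE`/`DEBUG`/`INFO`/`WARNING`/`ERROR`, with `FT_LOG_FIRST_RANK_ONLY=ON` demoting non-zero devices to `ERROR`), and the format strings are printf-style via `fmtstr`.

```cpp
#include "logger.h"   // utils/logger.h in this repository

void log_example(int m, int n, int k) {
    // Same idiom used elsewhere in this diff, e.g. FT_LOG_DEBUG(__PRETTY_FUNCTION__)
    // at the top of getWorkspaceSize().
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    FT_LOG_INFO("launching fpA_intB GEMM with m=%d n=%d k=%d", m, n, k);
}
```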
-#if defined(_MSC_VER) // for visual studio -#pragma warning(push) -#pragma warning(warning(disable : 4996)) -#elif defined(__GNUC__) || defined(__clang__) // for gcc or clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wformat-security" -#endif - int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0' - if (size_s <= 0) { - throw std::runtime_error("Error during formatting."); - } - auto size = static_cast(size_s); - auto buf = std::make_unique(size); - std::snprintf(buf.get(), size, format.c_str(), args...); -#if defined(_MSC_VER) -#pragma warning(pop) -#elif defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#endif - return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside -} -} // namespace fastertransformer diff --git a/utils/torch_utils.h b/utils/torch_utils.h deleted file mode 100644 index 447cab342311980815e3d7bda5d5d40e98f68e2d..0000000000000000000000000000000000000000 --- a/utils/torch_utils.h +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once -#include "torch/csrc/cuda/Stream.h" -#include "torch/all.h" -#include -#include -#include -#include -#include -// Generates a conflict with CUDA 12.6 between nvtx 2 and 3. Does not -// seem to be used anyway? -// -// #include -#include -#include -#include - -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") -#define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x) -#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x, st) \ - CHECK_TH_CUDA(x); \ - CHECK_CONTIGUOUS(x); \ - CHECK_TYPE(x, st) -#define CHECK_CPU_INPUT(x, st) \ - CHECK_CPU(x); \ - CHECK_CONTIGUOUS(x); \ - CHECK_TYPE(x, st) -#define CHECK_OPTIONAL_INPUT(x, st) \ - if (x.has_value()) { \ - CHECK_INPUT(x.value(), st); \ - } -#define CHECK_OPTIONAL_CPU_INPUT(x, st) \ - if (x.has_value()) { \ - CHECK_CPU_INPUT(x.value(), st); \ - } -#define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl -#define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl - -namespace fastertransformer { - -template -inline T* get_ptr(torch::Tensor& t) -{ - return reinterpret_cast(t.data_ptr()); -} - -std::vector convert_shape(torch::Tensor tensor); - -size_t sizeBytes(torch::Tensor tensor); - -QuantType get_ft_quant_type(torch::ScalarType quant_type) -{ - if (quant_type == torch::kInt8) { - return QuantType::INT8_WEIGHT_ONLY; - } - else if (quant_type == at::ScalarType::QUInt4x2) { - return QuantType::PACKED_INT4_WEIGHT_ONLY; - } - else { - TORCH_CHECK(false, "Invalid quantization type"); - } -} - -} // namespace fastertransformer diff --git a/weightOnlyBatchedGemv/common.h b/weightOnlyBatchedGemv/common.h deleted file mode 100644 index 3628fdf37168baa8bd4dcaa10987d09ddde96dc9..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/common.h +++ /dev/null @@ 
-1,107 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#if defined(ENABLE_BF16) -#include -#endif -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -enum class WeightOnlyQuantType -{ - Int4b, - Int8b -}; -enum class WeightOnlyType -{ - PerChannel, - GroupWise -}; - -struct WeightOnlyPerChannel; -template -struct WeightOnlyGroupWise; - -enum class WeightOnlyActivationFunctionType -{ - Gelu, - Relu, - Identity, - InvalidType -}; - -enum class WeightOnlyActivationType -{ - FP16, - BF16 -}; - -struct WeightOnlyParams -{ - // ActType is fp16 or bf16 - using ActType = void; - using WeiType = uint8_t; - - const uint8_t* qweight; - const ActType* scales; - const ActType* zeros; - const ActType* in; - const ActType* act_scale; - const ActType* bias; - ActType* out; - const int m; - const int n; - const int k; - const int group_size; - WeightOnlyQuantType quant_type; - WeightOnlyType weight_only_type; - WeightOnlyActivationFunctionType act_func_type; - WeightOnlyActivationType act_type; - - WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in, - const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, - const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, - const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type) - : qweight(_qweight) - , scales(_scales) - , zeros(_zeros) - , in(_in) - , act_scale(_act_scale) - , bias(_bias) - , out(_out) - , m(_m) - , n(_n) - , k(_k) - , group_size(_group_size) - , quant_type(_quant_type) - , weight_only_type(_weight_only_type) - , act_func_type(_act_func_type) - , act_type(_act_type) - { - } -}; -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/enabled.h b/weightOnlyBatchedGemv/enabled.h deleted file mode 100644 index 5c77bc75785db512f84d7b9841d7b381c5463125..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/enabled.h +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once -#include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "common.h" -#include - - -inline int getSMVersion() -{ - int device{-1}; - cudaGetDevice(&device); - int sm_major = 0; - int sm_minor = 0; - cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); - cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); - return sm_major * 10 + sm_minor; -} - -namespace tensorrt_llm -{ -namespace kernels -{ -template -struct SupportedLayout -{ - static constexpr bool value = false; -}; - -template <> -struct SupportedLayout> -{ - static constexpr bool value = true; -}; - -template <> -struct SupportedLayout> -{ - static constexpr bool value = true; -}; - -template -bool isEnabled() -{ - using Layout = typename cutlass::gemm::kernel::LayoutDetailsB::Layout; - return SupportedLayout::value; -} - -template -bool isEnabledForArch(int arch) -{ - if (arch >= 70 && arch < 75) - { - return isEnabled(); - } - else if (arch >= 75 && arch < 80) - { - return isEnabled(); - } - else if (arch >= 80 && arch <= 90) - { - return isEnabled(); - } - else - { - // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); - assert(0); - return false; - } -} - -inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype) -{ - const int arch = getSMVersion(); - if (qtype == WeightOnlyQuantType::Int4b) - { - return isEnabledForArch(arch); - } - else if (qtype == WeightOnlyQuantType::Int8b) - { - return isEnabledForArch(arch); - } - else - { - assert(0); - // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType"); - return false; - } -} -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernel.h b/weightOnlyBatchedGemv/kernel.h deleted file mode 100644 index f9ec69d3b910a744b9b0fc92163aba8cb6872df0..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/kernel.h +++ /dev/null @@ -1,554 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "common.h" -#include "utility.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -template -struct ActTypeDetails; - -template <> -struct ActTypeDetails -{ - using CutlassType = cutlass::half_t; - using Vec2 = half2; - - __device__ __forceinline__ static Vec2 to_vec2(half v) - { - return __half2half2(v); - } -}; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) -template <> -struct ActTypeDetails<__nv_bfloat16> -{ - using CutlassType = cutlass::bfloat16_t; - using Vec2 = __nv_bfloat162; - - __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v) - { - return __bfloat162bfloat162(v); - } -}; -#endif - -template -struct ConverterSelector -{ - static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b); - - using WeiType = std::conditional_t; - static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 
8 : 4; - using Converter - = cutlass::FastInterleavedAndBiasedNumericArrayConverter::CutlassType, WeiType, - kConvertCount>; -}; - -template -struct WeightOnlyDetails; - -template -struct WeightOnlyDetails -{ - // Every four rows of the original weights are interleaved into a row with stride of 64, so if each thread - // processes 32 elements(for int4, we can use ldg.128 to load weights), then every group of two adjacent threads - // will alternately process four different row weights - // for example - // every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N under interleave layout, - // the first 64 are from [64*i, 64*(i+1)-1] of row 4N before interleaving, - // and the second 64 are from [64*i, 64*(i+1)-1] of row 4N+1 before interleaving, and so on. - // So if each thread loads 32 int4 elements, then the elements of each 2 adjacent threads of each 8 - // consecutive threads will come from row 4N ~ 4N+3 respectively before interleaving. - static constexpr int kElemBits = 4; - static constexpr int kInterleave = 4; - static constexpr int kStride = 64; - - // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm - // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 31 - // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 - static constexpr int kShuffleSize = 32; - static constexpr int kShuffleBasicTile = 2; - static constexpr int kShuffleContinous = 4; - static constexpr int kShuffleStrided = 4; - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - template - __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - res[i] += __shfl_xor_sync(~0, res[i], 16); - res[i] += __shfl_xor_sync(~0, res[i], 8); - res[i] += __shfl_xor_sync(~0, res[i], 1); - } - __syncthreads(); - int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; - if (lane == 0 || lane == 2 || lane == 4 || lane == 6) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - sm[warp][i * kInterleave + lane / 2] = res[i]; - } - } - __syncthreads(); - } -}; - -template -struct WeightOnlyDetails -{ - // Every two rows of the original weights are interleaved into a row with stride of 64, so if each thread - // processes 16 elements(for int8, we can use ldg.128 to load weights), then every group of four adjacent threads - // will alternately process two different row weights - // for example - // every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave layout, - // the first 64 are from [64*i, 64*(i+1)-1] of row 2N before interleaving, - // and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 before interleaving. - // So if each thread loads 16 int8 elements, then the elements of the first four and last four threads of each 8 - // consecutive threads will come from row 2N and row 2N+1 respectively before interleaving. 
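The interleaving described in the two layout comments above can be restated as a tiny host-side mapping; the helper below is for intuition only (it is not part of the kernels) and simply inverts the layout: element `c` of interleaved row `n` comes from original row `n * kInterleave + (c / kStride) % kInterleave`.

```cpp
// kInterleave = 4, kStride = 64 for int4; kInterleave = 2, kStride = 64 for int8.
struct SourceCoord { int row; int col; };

constexpr SourceCoord deinterleave(int n, int c, int kInterleave, int kStride) {
    const int tile = kInterleave * kStride;              // 256 int4 / 128 int8 elements
    return SourceCoord{
        n * kInterleave + (c / kStride) % kInterleave,   // rows 4N..4N+3 (int4) or 2N, 2N+1 (int8)
        (c / tile) * kStride + c % kStride               // [64*i, 64*(i+1)) inside that row
    };
}

// Matches the int8 comment above: the first 64 elements of interleaved row 0 come
// from row 0, the next 64 from row 1.
static_assert(deinterleave(0,  63, 2, 64).row == 0 && deinterleave(0,  63, 2, 64).col == 63, "");
static_assert(deinterleave(0, 127, 2, 64).row == 1 && deinterleave(0, 127, 2, 64).col == 63, "");
```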
- static constexpr int kElemBits = 8; - static constexpr int kInterleave = 2; - static constexpr int kStride = 64; - - // The index remapping here is to counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm - // input 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 - // weight 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 - static constexpr int kShuffleSize = 16; - static constexpr int kShuffleBasicTile = 2; - static constexpr int kShuffleContinous = 2; - static constexpr int kShuffleStrided = 4; - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - template - __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave]) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - res[i] += __shfl_xor_sync(~0, res[i], 16); - res[i] += __shfl_xor_sync(~0, res[i], 8); - res[i] += __shfl_xor_sync(~0, res[i], 2); - res[i] += __shfl_xor_sync(~0, res[i], 1); - } - __syncthreads(); - int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; - if (lane == 0 || lane == 4) - { -#pragma unroll - for (int i = 0; i < Num; ++i) - { - sm[warp][i * kInterleave + lane / 4] = res[i]; - } - } - __syncthreads(); - } -}; - -template -struct WeightOnlyKernelDetails -{ - using Layout = WeightOnlyDetails; - - static constexpr int kElemBits = Layout::kElemBits; - static constexpr int kInterleave = Layout::kInterleave; - static constexpr int kStride = Layout::kStride; - - static constexpr int kShuffleSize = Layout::kShuffleSize; - static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile; - static constexpr int kShuffleContinous = Layout::kShuffleContinous; - static constexpr int kShuffleStrided = Layout::kShuffleStrided; - - // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace - // Input int8 data layout - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - // - // Converted fp16/bf16 data layout - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - - // Input int8 data layout - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - // - // Converted fp16/bf16 data layout - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) - static constexpr int kConvertCount = ConverterSelector::kConvertCount; - using Converter = typename ConverterSelector::Converter; - - // Use ldg128 load data from global memory - static constexpr int kAccessSize = 128; - using AccessType = uint4; - - static constexpr int kElemsPerByte = 8 / kElemBits; - static constexpr int kElemsPerThread = kAccessSize / kElemBits; - static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte; - static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread; - static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave; - - static constexpr int kConvertIters = kElemsPerThread / kConvertCount; - - // Each thread loads 16(int8b)/32(int4b) quantized weight elements each time through ldg128 - // So more times of ldg128 are needed to load the same number of fp16/bf16 activation elements. 
- static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8); - static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess; -}; - -template -struct WeightOnlyProperties; - -template <> -struct WeightOnlyProperties -{ - static constexpr bool kIsFineGrained = false; - static constexpr int kGroupSize = 0; -}; - -template -struct WeightOnlyProperties> -{ - static constexpr bool kIsFineGrained = true; - static constexpr int kGroupSize = GS; -}; - -template -struct WeightOnlyScaleLoader -{ - using ElemType = ActType; - using Details = WeightOnlyKernelDetails; - static constexpr bool kIsFineGrained = WeightOnlyProperties::kIsFineGrained; - static constexpr int kGroupSize = WeightOnlyProperties::kGroupSize; - -private: - const ElemType* _scales; - const ElemType* _zeros; - int _stride; - int _offset; - -public: - __device__ __forceinline__ WeightOnlyScaleLoader( - const ElemType* scales, const ElemType* zeros, int initial_offset, int stride) - : _scales(scales) - , _zeros(zeros) - , _stride(stride) - { - _scales += initial_offset; - if constexpr (Zero) - { - _zeros += initial_offset; - } - // Calculate the k dimension index of the element processed by the current thread of layout before interleave - // Used to load scales and zeros in groupwise weight only quant - _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride - + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread; - } - - __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid) - { - int offset = nid * Details::kInterleave; - if constexpr (kIsFineGrained) - { - offset += _offset / kGroupSize * _stride; - } - scale = _scales[offset]; - if constexpr (Zero) - { - zero = _zeros[offset]; - } - else - { - zero = static_cast(0.f); - } - } - - __device__ __forceinline__ void advance() - { - _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave; - } - - __device__ __forceinline__ int offset() - { - return _offset; - } -}; - -template class ActOp, - bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros, - const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) -{ - static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0)); - using ActType2 = typename ActTypeDetails::Vec2; - using Details = WeightOnlyKernelDetails; - - using Converter = typename Details::Converter; - using AccType = typename Details::AccessType; - using CvtSrcType = typename Converter::source_type; - using CvtResType = typename Converter::result_type; - using ScaleLoader = WeightOnlyScaleLoader; - extern __shared__ uint8_t shmem[]; - constexpr int Interleave = Details::kInterleave; - constexpr int WarpSize = 32; - constexpr int Num = Batch * NPerBlock; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int n_start_id = bid * NPerBlock * Interleave; - // Calculate the n-dimensional index of the data processed by the current thread in the interleave tile - const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave; - - qweight += n_start_id * k / Details::kElemsPerByte; - ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n); - - float(*sm)[Num * Interleave] = reinterpret_cast(shmem); - - // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for accumulation - // between 
threads. - ActType accumulator[Num]; - for (int i = 0; i < Num; ++i) - { - accumulator[i] = static_cast(0.f); - } - - // Iteration in k dimensions - for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave; - local_k += BlockSize * Details::kElemsPerThread) - { - ActType weights_f16[Details::kElemsPerThread * NPerBlock]; - ActType scale[NPerBlock], zero[NPerBlock]; -#pragma unroll - for (int idx = 0; idx < NPerBlock; ++idx) - { - // Load quantized weight and scales/zeros - uint8_t weights_quantized[Details::kBytePerThread]; - load(weights_quantized, - qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte); - scale_loader.load(scale[idx], zero[idx], idx); - ActType weights_vec[Details::kElemsPerThread]; -#pragma unroll - for (int i = 0; i < Details::kConvertIters; ++i) - { - // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F type conversion - assign(weights_vec + i * Details::kConvertCount, - Converter::convert(*reinterpret_cast( - weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte))); - } -#pragma unroll - for (int i = 0; i < Details::kShuffleContinous; ++i) - { -#pragma unroll - for (int j = 0; j < Details::kShuffleStrided; ++j) - { - // Dequantize the weights and arrange the shuffled elements back to the correct order in the - // register array - ActType2 v = *reinterpret_cast(weights_vec + i * Details::kShuffleBasicTile - + j * Details::kShuffleContinous * Details::kShuffleBasicTile); - v = __hfma2( - v, ActTypeDetails::to_vec2(scale[idx]), ActTypeDetails::to_vec2(zero[idx])); - weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile - + j * Details::kShuffleBasicTile + 0) - * NPerBlock - + idx] - = v.x; - weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile - + j * Details::kShuffleBasicTile + 1) - * NPerBlock - + idx] - = v.y; - } - } - } - ActType act_scale_v[Details::kElemsPerThread]; - if constexpr (ActScale) - { -#pragma unroll - for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) - { - load(act_scale_v + idx * Details::kActivationElemNumPerAccess, - act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); - } - } -#pragma unroll - for (int b = 0; b < Batch; ++b) - { - ActType in_v[Details::kElemsPerThread]; -#pragma unroll - for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) - { - // load activation elements - load(in_v + idx * Details::kActivationElemNumPerAccess, - in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess); - if constexpr (ActScale) - { -#pragma unroll - for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2) - { - *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2( - *reinterpret_cast(in_v + idx * Details::kActivationElemNumPerAccess + i), - *reinterpret_cast(act_scale_v + idx * Details::kActivationElemNumPerAccess + i)); - } - } - } - // Perform vector inner product and accumulate - if constexpr (NPerBlock == 1) - { - ActType2 v = ActTypeDetails::to_vec2(static_cast(0.f)); -#pragma unroll - for (int y = 0; y < Details::kElemsPerThread; y += 2) - { - v = __hfma2( - *reinterpret_cast(weights_f16 + y), *reinterpret_cast(in_v + y), v); - } - accumulator[b] += __hadd(v.x, v.y); - } - else - { -#pragma unroll - for (int x = 0; x < NPerBlock / 2; ++x) - { -#pragma unroll - for (int y = 0; y < Details::kElemsPerThread; ++y) - { - *reinterpret_cast(accumulator + b * NPerBlock + x * 2) - = __hfma2(*reinterpret_cast(weights_f16 + y * 
NPerBlock + x * 2), - ActTypeDetails::to_vec2(in_v[y]), - *reinterpret_cast(accumulator + b * NPerBlock + x * 2)); - } - } - } - } - scale_loader.advance(); - } - float reses[Num]; -#pragma unroll - for (int i = 0; i < Num; ++i) - { - reses[i] = static_cast(accumulator[i]); - } - - // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the - // corresponding address in shared memory - Details::Layout::sync(reses, sm); - - // Each thread is responsible for the accumulation and store to global memory of one element - for (int i = tid; i < Num * Interleave; i += BlockSize) - { - int nid = i % (NPerBlock * Interleave); - float v = 0.f; - for (int j = 0; j < BlockSize / WarpSize; ++j) - { - v += sm[j][i]; - } - float bias_v = 0.f; - if constexpr (Bias) - { - bias_v = static_cast(bias[n_start_id + nid]); - } - int b = i / NPerBlock / Interleave; - out[b * n + n_start_id + nid] = static_cast(ActOp::apply(v + bias_v)); - } -} - -template class ActOp, - bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros, - const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) -{ - if constexpr (std::is_same_v) - { - weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); - } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - else if (std::is_same_v) - { - weight_only_batched_gemv(qweight, scales, zeros, in, act_scale, bias, out, n, k); - } -#endif -} - -template class ActOp, bool Zero, bool Bias, - int NPerBlock, int Batch, int BlockSize> -struct WeightOnlyBatchedGemvKernelLauncher -{ - static void run(const WeightOnlyParams& params, cudaStream_t stream) - { - if (params.act_type == WeightOnlyActivationType::FP16) - { - constexpr int kInterleave = WeightOnlyDetails::kInterleave; - dim3 grid(params.n / NPerBlock / kInterleave); - dim3 block(BlockSize); - int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; - if (params.act_scale != nullptr) - { - weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, - params.k); - } - else - { - weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, - params.k); - } - } -#if defined(ENABLE_BF16) - else if (params.act_type == WeightOnlyActivationType::BF16) - { - constexpr int kInterleave = WeightOnlyDetails::kInterleave; - dim3 grid(params.n / NPerBlock / kInterleave); - dim3 block(BlockSize); - int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; - if (params.act_scale != nullptr) - { - weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, - NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast(params.scales), - reinterpret_cast(params.zeros), - reinterpret_cast(params.in), - reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), - params.n, params.k); - } - else - { - weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, 
WeightOnlyFlag, ActOp, Zero, Bias, false, - NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast(params.scales), - reinterpret_cast(params.zeros), - reinterpret_cast(params.in), - reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), - params.n, params.k); - } - } -#endif - } -}; -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernelLauncher.cu b/weightOnlyBatchedGemv/kernelLauncher.cu deleted file mode 100644 index 814874ef3e34d90acecf938e338fbc898f6f20bb..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/kernelLauncher.cu +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common.h" -#include "utility.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -template class ActOp, bool Zero, bool Bias, - int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> -struct WeightOnlyBatchedGemvKernelLauncher -{ - static void run(const WeightOnlyParams& params, cudaStream_t stream); -}; - -template class ActOp, int N_PER_BLOCK, - int BATCH, int BLOCK_SIZE> -void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) -{ - if (params.zeros && params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else if (params.zeros && !params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else if (!params.zeros && params.bias) - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } - else - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - } -} - -template -void select_activation(const WeightOnlyParams& params, cudaStream_t stream) -{ - switch (params.act_func_type) - { - // Currently, activation function is not called in the plugin -#if 0 - case WeightOnlyActivationFunctionType::Gelu: - { - select_zero_bias(params, stream); - break; - } - case WeightOnlyActivationFunctionType::Relu: - { - select_zero_bias(params, stream); - break; - } -#endif - case WeightOnlyActivationFunctionType::Identity: - { - select_zero_bias(params, stream); - break; - } - default: - { - throw std::runtime_error("Use unsupported activation"); - break; - } - } -} - -template -void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) -{ - if (params.quant_type == WeightOnlyQuantType::Int4b) - { - select_activation(params, stream); - } - else if (params.quant_type == WeightOnlyQuantType::Int8b) - { - select_activation(params, stream); - } - else - { - throw std::runtime_error("Unknown QuantType"); - } -} - -template -void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream) -{ - if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) - { - select_quant_type, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); - } - else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128) - { - select_quant_type, 
N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); - } - else - { - throw std::runtime_error("Only support groupwise weight only for gs=64/128"); - } -} - -void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream) -{ - assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); - assert(params.weight_only_type == WeightOnlyType::GroupWise - || (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr - && params.zeros == nullptr)); - if (params.weight_only_type == WeightOnlyType::PerChannel) - { - if (params.quant_type == WeightOnlyQuantType::Int4b) - { - switch (params.m) - { - case 1: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 2: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 3: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 4: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } - else if (params.quant_type == WeightOnlyQuantType::Int8b) - { - switch (params.m) - { - case 1: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 2: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 3: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - case 4: - { - WeightOnlyBatchedGemvKernelLauncher::run(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } - } - else if (params.weight_only_type == WeightOnlyType::GroupWise) - { - switch (params.m) - { - case 1: - { - select_groupwise_weight_only<2, 1, 256>(params, stream); - break; - } - case 2: - { - select_groupwise_weight_only<2, 2, 256>(params, stream); - break; - } - case 3: - { - select_groupwise_weight_only<2, 3, 128>(params, stream); - break; - } - case 4: - { - select_groupwise_weight_only<2, 4, 128>(params, stream); - break; - } - default: - { - throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); - break; - } - } - } -} -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/kernelLauncher.h b/weightOnlyBatchedGemv/kernelLauncher.h deleted file mode 100644 index 9bfa7302167cfc852e269bbbf7ee7c124fcab3dc..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/kernelLauncher.h +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
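For completeness, the sketch below shows how the dispatcher above is driven from the host, mirroring the call made by `w8_a16_gemm_forward_cuda` earlier in this diff (per-channel int8 weights, fp16 activations, identity activation). The wrapper function itself is an assumption for illustration; the parameter order and the constraints noted in the comments come from this file.

```cpp
#include <cuda_fp16.h>
#include "weightOnlyBatchedGemv/common.h"
#include "weightOnlyBatchedGemv/kernelLauncher.h"

using namespace tensorrt_llm::kernels;

void launch_small_m_gemv(const uint8_t *qweight, const half *scales, const half *in,
                         half *out, int m, int n, int k, cudaStream_t stream) {
    WeightOnlyParams params{
        /*qweight=*/qweight, /*scales=*/scales, /*zeros=*/nullptr,
        /*in=*/in, /*act_scale=*/nullptr, /*bias=*/nullptr, /*out=*/out,
        m, n, k, /*group_size=*/0,
        WeightOnlyQuantType::Int8b, WeightOnlyType::PerChannel,
        WeightOnlyActivationFunctionType::Identity, WeightOnlyActivationType::FP16};
    // Per the asserts above: Identity activation only, PerChannel requires
    // bias == nullptr and zeros == nullptr, and only m in [1, 4] is dispatched.
    weight_only_batched_gemv_launcher(params, stream);
}
```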
- */ -#pragma once -#include "common.h" - -namespace tensorrt_llm -{ -namespace kernels -{ -void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream); -} -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/utility.h b/weightOnlyBatchedGemv/utility.h deleted file mode 100644 index e53814525cac02b9ea1387146c6d6ed9a3443f5d..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/utility.h +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} - -__inline__ __device__ float tanh_opt(float x) -{ -#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) - float r; - asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); - return r; -#else - const float exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); -#endif -} - -template -struct GeluActivation -{ - static __device__ __forceinline__ T apply(const T& val) - { - const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); - return val * cdf; - } -}; - -template -struct ReluActivation -{ - static __device__ __forceinline__ T apply(const T& val) - { - return val > static_cast(0.0f) ? val : static_cast(0.0f); - } -}; - -template -struct IdentityActivation -{ - static __device__ __forceinline__ T apply(const T& val) - { - return val; - } -}; - -template -__device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) -{ - *reinterpret_cast(dst) = *(reinterpret_cast(src) + offset); -} - -template -__device__ __forceinline__ void assign(T* dst, const AssignType& val) -{ - *reinterpret_cast(dst) = val; -} - -template -__device__ __forceinline__ void store(T0* src, T1* dst, size_t offset = 0) -{ - *(reinterpret_cast(dst) + offset) = *reinterpret_cast(src); -} -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu deleted file mode 100644 index 9594350267ed8529a053b2b34e1f42a920f0e8cf..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu deleted file mode 100644 index 94c83ccf78242dd6d9daf50c60efd1e796e0ea0c..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 1, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 1, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu deleted file mode 100644 index 9ba99bc8270ba365ded192099096223606edfb61..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu deleted file mode 100644 index 729d38726f30214e2fc517c91405ded7b99a4523..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 2, 256>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 2, 256>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu deleted file mode 100644 index 8e48f3e93c29cf82a2c3cef63c988c88c23b5c0a..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu deleted file mode 100644 index b73ef8df880aaeb7092c41dd30a9a11f4c2fafe6..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 3, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 3, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu deleted file mode 100644 index 2a29c8385daacb1d4244adaa5e82029f9494c6ba..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -} // namespace kernels -} // namespace tensorrt_llm diff --git a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu b/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu deleted file mode 100644 index a6f0f5fa52d33a9f359e742e253049b710feccac..0000000000000000000000000000000000000000 --- a/weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel.h" - -namespace tensorrt_llm -{ -namespace kernels -{ - -template struct WeightOnlyBatchedGemvKernelLauncher; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, true, false, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, true, 2, 4, 128>; -template struct WeightOnlyBatchedGemvKernelLauncher, - IdentityActivation, false, false, 2, 4, 128>; - -} // namespace kernels -} // namespace tensorrt_llm
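
The deleted `kernelLauncher.cu` above resolves a runtime `WeightOnlyParams` configuration into one of the explicit template instantiations provided by the `weightOnlyBatchedGemvBs{1..4}Int{4,8}b.cu` files: quant type (Int4b/Int8b), per-channel vs. groupwise layout (group size 64 or 128), presence of zero points and bias, and batch size m ≤ 4 (256-thread blocks for m = 1–2 and 128-thread blocks for m = 3–4 in the groupwise path). The sketch below illustrates only the runtime-flag-to-template-parameter dispatch pattern used by `select_zero_bias`; the `Params` and `Launcher` types here are simplified stand-ins, not the real `tensorrt_llm::kernels` API.

```cpp
// Minimal, self-contained sketch of the dispatch pattern in the deleted
// select_zero_bias helper: runtime flags pick a fully specialized launcher.
// Params/Launcher are simplified stand-ins, not the real kernel interface.
#include <cstdio>

struct Params
{
    bool has_zeros; // stand-in for params.zeros != nullptr
    bool has_bias;  // stand-in for params.bias != nullptr
};

template <bool Zero, bool Bias>
struct Launcher
{
    static void run(const Params&)
    {
        std::printf("Launcher<Zero=%d, Bias=%d>::run\n", Zero, Bias);
    }
};

// Mirrors the if/else ladder that turns runtime booleans into the
// compile-time Zero/Bias template parameters of the kernel launcher.
void select_zero_bias(const Params& p)
{
    if (p.has_zeros && p.has_bias)
        Launcher<true, true>::run(p);
    else if (p.has_zeros && !p.has_bias)
        Launcher<true, false>::run(p);
    else if (!p.has_zeros && p.has_bias)
        Launcher<false, true>::run(p);
    else
        Launcher<false, false>::run(p);
}

int main()
{
    select_zero_bias(Params{true, false}); // -> Launcher<Zero=1, Bias=0>::run
}
```

In the real code this ladder is why the Bs*.cu files exist: every Zero/Bias combination for each batch size and quant type is explicitly instantiated in its own translation unit, so the launcher only declares `run` and links against the pre-built specializations.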
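
The deleted `utility.h` also defines the activation functors (`GeluActivation`, `ReluActivation`, `IdentityActivation`), although only the identity path is reachable because `select_activation` keeps the GELU/ReLU cases behind `#if 0`. For reference, `GeluActivation` uses the tanh approximation 0.5 · (1 + tanh(√(2/π) · (x + 0.044715·x³))) · x, with a `tanh.approx.f32` PTX fast path on sm_75+. Below is a plain host-side C++ stand-in of that formula for checking values; it is an assumed reference implementation, not the device code.

```cpp
// Host-side reference for the tanh-approximation GELU used by the deleted
// GeluActivation functor: 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))) * x.
// The device version in utility.h uses a tanh.approx.f32 PTX fast path on sm_75+.
#include <cmath>
#include <cstdio>

float gelu_tanh_ref(float x)
{
    const float k = 0.7978845608028654f; // sqrt(2 / pi)
    const float cdf = 0.5f * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
    return x * cdf;
}

int main()
{
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : xs)
        std::printf("gelu(% .2f) = % .6f\n", x, gelu_tanh_ref(x));
}
```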