Spaces:
Build error
Build error
| // Copyright (c) Facebook, Inc. and its affiliates.All Rights Reserved | |
// Integer ceiling division usable from both host and device code:
// returns ceil(val / divisor) for positive `divisor`, computed as
// (val + divisor - 1) / divisor. Like that formula, the result wraps if
// val + divisor - 1 overflows T.
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    const T padded = val + divisor - 1;
    const T quotient = padded / divisor;
    return quotient;
}
// Top-p masking kernel: one thread per row of D entries.
//
// Launch layout (as configured by topp_masking_cuda): the x dimension of the
// grid covers the N rows of one batch entry, blockIdx.y selects the batch
// entry. No shared memory is used. `B` is not read inside the kernel.
//
// For its row, a thread walks sorted_weights[0..D) in order, accumulating a
// running sum, and sets output_mask[row][sorted_indices[d]] = true for every
// visited entry. It stops right after the running sum first reaches `p`, so
// the entry that crosses the threshold is included. Entries are only ever set
// to true, never cleared.
// NOTE(review): output_mask is presumably zero-initialised by the caller, and
// sorted_indices presumably holds column indices in [0, D) matching a
// descending sort of the weights — verify against the caller. The kernel does
// not range-check these indices.
//
// S is a compile-time upper bound on the trip count (the launcher guarantees
// D <= S), which lets the compiler unroll the loop.
template <uint32_t S>
__global__ void kernel_topp_masking(
    const int * __restrict__ sorted_indices,
    const float * __restrict__ sorted_weights,
    bool *output_mask,
    const float p, const uint32_t B,
    const uint32_t N, const uint32_t D) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= N) return;
    const uint32_t batch_id = blockIdx.y;

    // Locate this thread's row. The offset is computed in 64 bits because
    // (batch_id * N + b) * D can exceed 2^32 - 1 for large inputs, which
    // would silently wrap in uint32_t arithmetic.
    const size_t row_offset =
        (static_cast<size_t>(batch_id) * N + b) * static_cast<size_t>(D);
    sorted_weights += row_offset;
    sorted_indices += row_offset;
    output_mask += row_offset;

    float w_sum = 0.0f;
    for (uint32_t d = 0; d < S; d++) {
        if (d >= D) break;
        w_sum += sorted_weights[d];
        output_mask[sorted_indices[d]] = true;
        if (w_sum >= p) break;  // threshold reached; the crossing entry is kept
    }
}
// Host-side launcher for kernel_topp_masking.
//
// Grid layout: x covers the N rows in blocks of N_THREAD threads, and y
// indexes the B batch entries.
// NOTE(review): B is therefore bounded by the device's maximum gridDim.y
// (65535 on current GPUs), which is not checked here.
// The instantiation used is the smallest S in {8, 16, 32, 64, 128, 256} with
// D <= S.
//
// The kernel is launched on the default stream (no stream argument is given).
// NOTE(review): PyTorch callers usually expect work on the current stream —
// consider passing at::cuda::getCurrentCUDAStream().
//
// Throws std::runtime_error if D > 256 or if the launch reports an error.
// Empty problems (B, N or D equal to 0) launch nothing.
void topp_masking_cuda(
    const int *sorted_indices,
    const float *sorted_weights, bool *output_mask,
    const float p, const uint32_t B, const uint32_t N, const uint32_t D) {
    static constexpr uint32_t N_THREAD = 512;

    if (D > 256) throw std::runtime_error{"# of sampled points should not exceed 256"};
    // A zero grid dimension is an invalid launch configuration, and D == 0
    // has no work to do, so return early instead of launching.
    if (B == 0 || N == 0 || D == 0) return;

    const dim3 blocks = {div_round_up(N, N_THREAD), B, 1};
    // Inclusive bounds: S entries are enough for D == S (the previous strict
    // `<` rejected D == 256 despite the S = 256 instantiation handling it).
    if (D <= 8) kernel_topp_masking<8><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);
    else if (D <= 16) kernel_topp_masking<16><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);
    else if (D <= 32) kernel_topp_masking<32><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);
    else if (D <= 64) kernel_topp_masking<64><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);
    else if (D <= 128) kernel_topp_masking<128><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);
    else kernel_topp_masking<256><<<blocks, N_THREAD>>>(sorted_indices, sorted_weights, output_mask, p, B, N, D);

    // Kernel launches do not return errors directly; surface bad launch
    // configurations (e.g. B above the gridDim.y limit) instead of dropping them.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) throw std::runtime_error{cudaGetErrorString(err)};
}
// Python-facing entry point: validates the tensors, then runs top-p masking
// on the GPU via topp_masking_cuda.
//
// All three tensors must pass CHECK_CUDA and CHECK_CONTIGUOUS.
// sorted_weights must pass CHECK_IS_FLOAT and sorted_indices must pass
// CHECK_IS_INT (these macros are defined elsewhere — presumably they check
// float32 / int32; confirm).
// Their storage is read as float*, int* and bool* respectively.
// The kernel addresses each tensor as B * N rows of D entries, so each tensor
// must hold at least B * N * D elements; otherwise std::runtime_error is
// thrown. output_mask is written in place.
void topp_masking(
    at::Tensor sorted_indices, at::Tensor sorted_weights, at::Tensor output_mask,
    const float p, const uint32_t B, const uint32_t N, const uint32_t D) {
    CHECK_CUDA(sorted_indices);
    CHECK_CUDA(sorted_weights);
    CHECK_CUDA(output_mask);
    CHECK_CONTIGUOUS(sorted_indices);
    CHECK_CONTIGUOUS(sorted_weights);
    CHECK_CONTIGUOUS(output_mask);
    CHECK_IS_FLOAT(sorted_weights);
    CHECK_IS_INT(sorted_indices);

    // The kernel trusts B, N and D for its addressing; undersized tensors
    // would lead to out-of-bounds device reads and writes.
    const int64_t expected = static_cast<int64_t>(B) * N * D;
    if (sorted_indices.numel() < expected ||
        sorted_weights.numel() < expected ||
        output_mask.numel() < expected) {
        throw std::runtime_error{"topp_masking: each tensor needs at least B * N * D elements"};
    }

    topp_masking_cuda(sorted_indices.data_ptr<int>(), sorted_weights.data_ptr<float>(), output_mask.data_ptr<bool>(), p, B, N, D);
}
// Python bindings for this extension module: registers the host function
// topp_masking under the Python name "topp_masking".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("topp_masking", &topp_masking, "topp masking");
}