|
|
|
|
|
#define CUB_IGNORE_DEPRECATED_API |
|
|
|
|
|
#undef CUB_WRAPPED_NAMESPACE |
|
|
#define CUB_WRAPPED_NAMESPACE megablocks |
|
|
|
|
|
#include "new_sort.h" |
|
|
#include "gpu_backend_hip.h" |
|
|
|
|
|
#include <cstdint> |
|
|
|
|
|
namespace megablocks { |
|
|
|
|
|
template <typename T> |
|
|
void cub_radix_sort(torch::Tensor x, |
|
|
int end_bit, |
|
|
torch::Tensor x_out, |
|
|
torch::Tensor iota_out) { |
|
|
|
|
|
torch::Tensor iota = torch::arange(0, x.numel(), x.options()); |
|
|
|
|
|
|
|
|
size_t scratchpad_bytes = 0; |
|
|
GPU_CALL(cubns::DeviceRadixSort::SortPairs(nullptr, |
|
|
scratchpad_bytes, |
|
|
x.data_ptr<T>(), |
|
|
x_out.data_ptr<T>(), |
|
|
iota.data_ptr<T>(), |
|
|
iota_out.data_ptr<T>(), |
|
|
x.numel(), |
|
|
0, |
|
|
end_bit, |
|
|
megablocks::get_current_stream())); |
|
|
|
|
|
|
|
|
auto options = torch::TensorOptions() |
|
|
.dtype(torch::kInt8) |
|
|
.device(x.device()); |
|
|
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); |
|
|
|
|
|
|
|
|
GPU_CALL(cubns::DeviceRadixSort::SortPairs(scratchpad.data_ptr(), |
|
|
scratchpad_bytes, |
|
|
x.data_ptr<T>(), |
|
|
x_out.data_ptr<T>(), |
|
|
iota.data_ptr<T>(), |
|
|
iota_out.data_ptr<T>(), |
|
|
x.numel(), |
|
|
0, |
|
|
end_bit, |
|
|
megablocks::get_current_stream())); |
|
|
} |
|
|
|
|
|
// Validates the arguments and dispatches to `cub_radix_sort` instantiated for
// the key dtype.
//
// Parameters:
//   x         1-D CUDA tensor of int16, int32 or int64 keys. A non-contiguous
//             tensor is first copied into a contiguous one.
//   end_bit   forwarded unchanged to `cub_radix_sort`.
//   x_out     1-D contiguous CUDA tensor with the same dtype, device and
//             element count as `x`; passed to `cub_radix_sort` as the sorted
//             key output.
//   iota_out  1-D contiguous CUDA tensor with the same dtype, device and
//             element count as `x`; passed to `cub_radix_sort` as the index
//             output.
//
// Fails via TORCH_CHECK when any requirement above is violated. An empty `x`
// is a no-op.
void sort(torch::Tensor x,
          int end_bit,
          torch::Tensor x_out,
          torch::Tensor iota_out) {
  const auto dtype = x.scalar_type();

  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1);
  TORCH_CHECK(dtype == torch::kInt16 ||
              dtype == torch::kInt32 ||
              dtype == torch::kInt64);
  TORCH_CHECK(x_out.is_cuda());
  TORCH_CHECK(x_out.ndimension() == 1);
  TORCH_CHECK(x_out.scalar_type() == dtype);
  TORCH_CHECK(iota_out.is_cuda());
  TORCH_CHECK(iota_out.ndimension() == 1);
  TORCH_CHECK(iota_out.scalar_type() == dtype);

  // The sort writes x.numel() elements through raw pointers into both
  // outputs, so their sizes must match the input exactly; anything smaller
  // would be an out-of-bounds device write.
  TORCH_CHECK(x_out.numel() == x.numel(),
              "sort: x_out has ", x_out.numel(),
              " elements but x has ", x.numel());
  TORCH_CHECK(iota_out.numel() == x.numel(),
              "sort: iota_out has ", iota_out.numel(),
              " elements but x has ", x.numel());

  // All buffers are handed to the same kernel launch.
  TORCH_CHECK(x_out.device() == x.device(),
              "sort: x_out must be on the same device as x");
  TORCH_CHECK(iota_out.device() == x.device(),
              "sort: iota_out must be on the same device as x");

  // Raw-pointer access treats each buffer as a dense array. The outputs must
  // therefore be contiguous already (a copy would discard the results).
  TORCH_CHECK(x_out.is_contiguous(), "sort: x_out must be contiguous");
  TORCH_CHECK(iota_out.is_contiguous(), "sort: iota_out must be contiguous");

  // Nothing to sort.
  if (x.numel() == 0) return;

  // The input can be densified safely; this returns `x` itself when it is
  // already contiguous.
  x = x.contiguous();

  // Fixed-width types keep the template argument in step with the tensor
  // dtype on every platform (`long` is only 32 bits wide on LLP64 targets).
  switch (dtype) {
    case torch::kInt16:
      cub_radix_sort<std::int16_t>(x, end_bit, x_out, iota_out);
      return;
    case torch::kInt32:
      cub_radix_sort<std::int32_t>(x, end_bit, x_out, iota_out);
      return;
    case torch::kInt64:
      cub_radix_sort<std::int64_t>(x, end_bit, x_out, iota_out);
      return;
    default:
      // Unreachable: the dtype was validated above.
      TORCH_CHECK(false, "sort: unsupported dtype");
  }
}
|
|
|
|
|
} |
|
|
|
|
|
#undef CUB_WRAPPED_NAMESPACE |
|
|
|