// megablocks-hip / csrc / new_sort.hip
// Committed by leonardlin
// Commit message: "Add ROCm build artifacts and HIP backend"
// Commit: 1e407f0
// !!! This is a file automatically generated by hipify!!!
#define CUB_IGNORE_DEPRECATED_API
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks
#include "new_sort.h"
#include "gpu_backend_hip.h"
#include <cstdint>
#include <limits>
namespace megablocks {
template <typename T>
void cub_radix_sort(torch::Tensor x,
int end_bit,
torch::Tensor x_out,
torch::Tensor iota_out) {
// Get iota for values in sort.
torch::Tensor iota = torch::arange(0, x.numel(), x.options());
// Get temporary buffer size.
size_t scratchpad_bytes = 0;
GPU_CALL(cubns::DeviceRadixSort::SortPairs(nullptr,
scratchpad_bytes,
x.data_ptr<T>(),
x_out.data_ptr<T>(),
iota.data_ptr<T>(),
iota_out.data_ptr<T>(),
x.numel(),
/*begin_bit*/0,
/*end_bit=*/end_bit,
megablocks::get_current_stream()));
// Allocate scratchpad.
auto options = torch::TensorOptions()
.dtype(torch::kInt8)
.device(x.device());
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
// Run the kernel.
GPU_CALL(cubns::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
scratchpad_bytes,
x.data_ptr<T>(),
x_out.data_ptr<T>(),
iota.data_ptr<T>(),
iota_out.data_ptr<T>(),
x.numel(),
/*begin_bit=*/0,
/*end_bit=*/end_bit,
megablocks::get_current_stream()));
}
void sort(torch::Tensor x,
int end_bit,
torch::Tensor x_out,
torch::Tensor iota_out) {
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(x.ndimension() == 1);
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
x.scalar_type() == torch::kInt32 ||
x.scalar_type() == torch::kInt64);
TORCH_CHECK(x_out.is_cuda());
TORCH_CHECK(x_out.ndimension() == 1);
TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
TORCH_CHECK(iota_out.is_cuda());
TORCH_CHECK(iota_out.ndimension() == 1);
TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
// Exit early if there is not work to do.
if (x_out.numel() == 0) return;
if (x.scalar_type() == torch::kInt16) {
cub_radix_sort<short>(x, end_bit, x_out, iota_out);
return;
}
if (x.scalar_type() == torch::kInt32) {
cub_radix_sort<int>(x, end_bit, x_out, iota_out);
return;
}
TORCH_CHECK(x.scalar_type() == torch::kInt64);
cub_radix_sort<long>(x, end_bit, x_out, iota_out);
}
} // namespace megablocks
#undef CUB_WRAPPED_NAMESPACE