File size: 1,894 Bytes
1e407f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
#pragma once
#include <torch/extension.h>
#include <utility>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <c10/hip/HIPStream.h>
namespace cubns = hipcub;
#else
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
namespace cubns = cub;
#endif
namespace megablocks {

// Portability layer: every alias and wrapper below maps onto the HIP runtime
// when __HIP_PLATFORM_AMD__ is defined and onto the CUDA runtime otherwise, so
// callers can be written once against the gpu*/get_* names.
#ifdef __HIP_PLATFORM_AMD__
using gpuError_t = hipError_t;
using gpuStream_t = hipStream_t;
constexpr gpuError_t kGpuSuccess = hipSuccess;

// Returns PyTorch's current HIP stream (c10::hip::getCurrentHIPStream()),
// converted to the raw hipStream_t handle.
inline gpuStream_t get_current_stream() {
  return c10::hip::getCurrentHIPStream();
}

// Human-readable description of `status` (hipGetErrorString).
inline const char* get_error_string(gpuError_t status) {
  return hipGetErrorString(status);
}

// Returns the HIP runtime's last error for this thread and resets it
// (hipGetLastError).
inline gpuError_t get_last_error() {
  return hipGetLastError();
}

// Forwards all arguments unchanged to hipMemcpyAsync.
template <typename... Args>
inline gpuError_t gpuMemcpyAsync(Args&&... args) {
  return hipMemcpyAsync(std::forward<Args>(args)...);
}

// Forwards all arguments unchanged to hipMemsetAsync.
template <typename... Args>
inline gpuError_t gpuMemsetAsync(Args&&... args) {
  return hipMemsetAsync(std::forward<Args>(args)...);
}
#else
using gpuError_t = cudaError_t;
using gpuStream_t = cudaStream_t;
constexpr gpuError_t kGpuSuccess = cudaSuccess;

// Returns PyTorch's current CUDA stream (c10::cuda::getCurrentCUDAStream()),
// converted to the raw cudaStream_t handle.
inline gpuStream_t get_current_stream() {
  return c10::cuda::getCurrentCUDAStream();
}

// Human-readable description of `status` (cudaGetErrorString).
inline const char* get_error_string(gpuError_t status) {
  return cudaGetErrorString(status);
}

// Returns the CUDA runtime's last error for this thread and resets it
// (cudaGetLastError).
inline gpuError_t get_last_error() {
  return cudaGetLastError();
}

// Forwards all arguments unchanged to cudaMemcpyAsync.
template <typename... Args>
inline gpuError_t gpuMemcpyAsync(Args&&... args) {
  return cudaMemcpyAsync(std::forward<Args>(args)...);
}

// Forwards all arguments unchanged to cudaMemsetAsync.
template <typename... Args>
inline gpuError_t gpuMemsetAsync(Args&&... args) {
  return cudaMemsetAsync(std::forward<Args>(args)...);
}
#endif

// Validates the result of a GPU runtime call.
//
// Returns normally when `status` is kGpuSuccess. Otherwise it throws through
// TORCH_CHECK (c10::Error). The message names the failing call (`expr`,
// normally the stringified argument of GPU_CALL), the runtime's error string
// and the numeric error code.
inline void gpuCheck(gpuError_t status, const char* expr) {
  if (status == kGpuSuccess) {
    return;
  }
  // A failed runtime call is also recorded as the thread's "last error".
  // Reset it before throwing. Otherwise, if the exception is caught, a later
  // get_last_error() check (e.g. after a kernel launch) would re-report this
  // already-handled failure. Sticky errors cannot be cleared this way and
  // will still surface on subsequent calls.
  (void)get_last_error();
  TORCH_CHECK(false,
              (expr != nullptr ? expr : "GPU runtime call"),
              " failed: ",
              get_error_string(status),
              " (error code ",
              static_cast<int>(status),
              ")");
}

} // namespace megablocks
// GPU_CALL(expr)
//   Evaluates `expr` exactly once. `expr` is expected to yield a
//   ::megablocks::gpuError_t. The result is passed to
//   ::megablocks::gpuCheck, together with the stringified source text of
//   `expr` as the second argument.
#define GPU_CALL(expr) \
  ::megablocks::gpuCheck((expr), #expr)
|