|
|
#pragma once |
|
|
|
|
|
#include <torch/extension.h> |
|
|
#include <utility> |
|
|
|
|
|
#ifdef __HIP_PLATFORM_AMD__ |
|
|
#include <hip/hip_runtime.h> |
|
|
#include <hipcub/hipcub.hpp> |
|
|
#include <c10/hip/HIPStream.h> |
|
|
namespace cubns = hipcub; |
|
|
#else |
|
|
#include <cuda_runtime.h> |
|
|
#include <cub/cub.cuh> |
|
|
#include <c10/cuda/CUDAStream.h> |
|
|
namespace cubns = cub; |
|
|
#endif |
|
|
|
|
|
namespace megablocks { |
|
|
#ifdef __HIP_PLATFORM_AMD__ |
|
|
using gpuError_t = hipError_t; |
|
|
using gpuStream_t = hipStream_t; |
|
|
constexpr gpuError_t kGpuSuccess = hipSuccess; |
|
|
inline gpuStream_t get_current_stream() { |
|
|
return c10::hip::getCurrentHIPStream(); |
|
|
} |
|
|
inline const char* get_error_string(gpuError_t status) { |
|
|
return hipGetErrorString(status); |
|
|
} |
|
|
inline gpuError_t get_last_error() { |
|
|
return hipGetLastError(); |
|
|
} |
|
|
template <typename... Args> |
|
|
inline gpuError_t gpuMemcpyAsync(Args&&... args) { |
|
|
return hipMemcpyAsync(std::forward<Args>(args)...); |
|
|
} |
|
|
template <typename... Args> |
|
|
inline gpuError_t gpuMemsetAsync(Args&&... args) { |
|
|
return hipMemsetAsync(std::forward<Args>(args)...); |
|
|
} |
|
|
#else |
|
|
using gpuError_t = cudaError_t; |
|
|
using gpuStream_t = cudaStream_t; |
|
|
constexpr gpuError_t kGpuSuccess = cudaSuccess; |
|
|
inline gpuStream_t get_current_stream() { |
|
|
return c10::cuda::getCurrentCUDAStream(); |
|
|
} |
|
|
inline const char* get_error_string(gpuError_t status) { |
|
|
return cudaGetErrorString(status); |
|
|
} |
|
|
inline gpuError_t get_last_error() { |
|
|
return cudaGetLastError(); |
|
|
} |
|
|
template <typename... Args> |
|
|
inline gpuError_t gpuMemcpyAsync(Args&&... args) { |
|
|
return cudaMemcpyAsync(std::forward<Args>(args)...); |
|
|
} |
|
|
template <typename... Args> |
|
|
inline gpuError_t gpuMemsetAsync(Args&&... args) { |
|
|
return cudaMemsetAsync(std::forward<Args>(args)...); |
|
|
} |
|
|
#endif |
|
|
|
|
|
inline void gpuCheck(gpuError_t status, const char* expr) { |
|
|
TORCH_CHECK(status == kGpuSuccess, get_error_string(status)); |
|
|
} |
|
|
} |
|
|
|
|
|
// GPU_CALL(call_expr): evaluates `call_expr` exactly once and passes the
// resulting gpuError_t, together with the stringified source text of the call,
// to megablocks::gpuCheck. gpuCheck throws through TORCH_CHECK on any
// non-success status.
#define GPU_CALL(call_expr) \
  ::megablocks::gpuCheck((call_expr), #call_expr)
|
|
|
|
|
|