|  | #ifdef TEST_ON_CUDA | 
					
						
						|  | #include <mma.h> | 
					
						
						|  |  | 
					
						
						|  | #include <cuda_fp16.h> | 
					
						
						|  | #include <cuda_fp8.h> | 
					
						
						|  |  | 
					
						
						|  | namespace wmma = nvcuda::wmma; | 
					
						
						|  |  | 
					
						
						|  | #define LIB_CALL(call)                                                                                                 \ | 
					
						
						|  | do {                                                                                                                 \ | 
					
						
						|  | cudaError_t err = call;                                                                                            \ | 
					
						
						|  | if (err != cudaSuccess) {                                                                                          \ | 
					
						
						|  | abort();                                                                                                         \ | 
					
						
						|  | }                                                                                                                  \ | 
					
						
						|  | } while (0) | 
					
						
						|  |  | 
					
						
						|  | #define HOST_TYPE(x) cuda##x | 
					
						
						|  |  | 
					
						
						|  | #else | 
					
						
						|  |  | 
					
						
						|  | #ifndef HIP_HEADERS__ | 
					
						
						|  | #include <hip/hip_runtime.h> | 
					
						
						|  | #include <hip/hip_fp8.h> | 
					
						
						|  | #include <hip/hip_fp16.h> | 
					
						
						|  | #include <rocwmma/rocwmma.hpp> | 
					
						
						|  | #define HIP_HEADERS__ | 
					
						
						|  | #endif | 
					
						
						|  |  | 
					
						
						|  | namespace wmma = rocwmma; | 
					
						
						|  |  | 
					
						
						|  | #define LIB_CALL(call)                                                                                                 \ | 
					
						
						|  | do {                                                                                                                 \ | 
					
						
						|  | hipError_t err = call;                                                                                             \ | 
					
						
						|  | if (err != hipSuccess) {                                                                                           \ | 
					
						
						|  | abort();                                                                                                         \ | 
					
						
						|  | }                                                                                                                  \ | 
					
						
						|  | } while (0) | 
					
						
						|  |  | 
					
						
						|  | #define HOST_TYPE(x) hip##x | 
					
						
						|  |  | 
					
						
						|  | #endif | 
					
						
						|  |  | 
					
						
						|  |  |