| 1 | #pragma once |
| 2 | |
| 3 | #include <cuda_runtime.h> |
| 4 | #include <cuda.h> |
| 5 | #include <cublas_v2.h> |
| 6 | #include <cuda_bf16.h> |
| 7 | #include <cuda_fp16.h> |
| 8 | |
| 9 | #if CUDART_VERSION >= 12050 |
| 10 | #include <cuda_fp8.h> |
| 11 | #endif // CUDART_VERSION >= 12050 |
| 12 | |
| 13 | #if CUDART_VERSION < 11020 /* Compatibility shim: for toolkits reporting CUDART_VERSION below 11020, alias each name on the left to the identifier on its right. */ |
| 14 | #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED /* device-attribute name aliased to the VIRTUAL_ADDRESS_MANAGEMENT spelling */ |
| 15 | #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH /* TF32 math-mode name falls back to CUBLAS_TENSOR_OP_MATH */ |
| 16 | #define CUBLAS_COMPUTE_16F CUDA_R_16F /* 16F compute-type constant replaced by the CUDA_R_16F constant */ |
| 17 | #define CUBLAS_COMPUTE_32F CUDA_R_32F /* 32F compute-type constant replaced by the CUDA_R_32F constant */ |
| 18 | #define cublasComputeType_t cudaDataType_t /* type-name alias, textual replacement in all code after this header. NOTE(review): if some toolkit below 11020 already declares cublasComputeType_t natively, this alias would shadow it -- TODO confirm the threshold. */ |
| 19 | #endif // CUDART_VERSION < 11020 |
| 20 | |