| 1 | #pragma once |
| 2 | #include "common.cuh" |
| 3 | |
// Compile-time constant 256. Judging by the name this is the thread-block size
// for the dequantize kernels; those kernels are not part of this header, so
// confirm at the point of use. Kept as a macro so that it remains usable in
// preprocessor conditionals.
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
| 5 | |
// Function-pointer type for a conversion routine that writes values of type T.
//   x      - untyped pointer to the source data
//   y      - destination buffer of T
//   k      - 64-bit count (presumably the number of elements to convert; confirm
//            against the implementations)
//   stream - CUDA stream handed to the routine
template <typename T>
using to_t_cuda_t = void (*)(
    const void * x,
    T * y,
    int64_t k,
    cudaStream_t stream);

// Concrete instantiations for the three destination types.
using to_fp32_cuda_t = to_t_cuda_t<float>;
using to_fp16_cuda_t = to_t_cuda_t<half>;
using to_bf16_cuda_t = to_t_cuda_t<nv_bfloat16>;
| 12 | |
// Lookup functions: each takes a ggml_type and returns a conversion function
// pointer for one destination type. They are only declared here.
// NOTE(review): what is returned for a type without a converter (e.g. a null
// pointer) is not visible from this header; check the definitions.

// float destination.
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);

// half destination.
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);

// nv_bfloat16 destination.
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
| 18 | |
// TODO more general support for non-contiguous inputs

// Function-pointer type for the "nc" flavour of a conversion routine that
// writes values of type T.
//   x               - untyped pointer to the source data
//   y               - destination buffer of T
//   ne00 .. ne03    - four 64-bit extents (presumably the source sizes along
//                     dimensions 0..3; confirm against the implementations)
//   s01, s02, s03   - three 64-bit values (presumably source strides for
//                     dimensions 1..3; their unit is not visible here)
//   stream          - CUDA stream handed to the routine
template <typename T>
using to_t_nc_cuda_t = void (*)(
    const void * x,
    T * y,
    int64_t ne00,
    int64_t ne01,
    int64_t ne02,
    int64_t ne03,
    int64_t s01,
    int64_t s02,
    int64_t s03,
    cudaStream_t stream);

// Concrete instantiations for the three destination types.
using to_fp32_nc_cuda_t = to_t_nc_cuda_t<float>;
using to_fp16_nc_cuda_t = to_t_nc_cuda_t<half>;
using to_bf16_nc_cuda_t = to_t_nc_cuda_t<nv_bfloat16>;
| 29 | |
// Lookup functions for the "nc" converters: each takes a ggml_type and returns
// a function pointer for one destination type. They are only declared here.
// NOTE(review): the result for a type without a converter is not visible from
// this header; check the definitions.
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);  // float destination
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);  // half destination
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);  // nv_bfloat16 destination
| 33 | |
// Converts the scalar `x` from src_t to dst_t. Callable from host and device.
//
// The conversion path is selected at compile time; the first matching rule wins:
//   1. dst_t == src_t        -> `x` is returned unchanged.
//   2. dst_t == nv_bfloat16  -> `x` is converted to float, then to bf16 with
//                               __float2bfloat16.
//   3. src_t == nv_bfloat16  -> `x` is converted with __bfloat162float; the
//                               float result is implicitly converted to dst_t
//                               by the return statement.
//   4. dst_t == int32_t      -> int32_t(x).
//   5. otherwise             -> float(x), implicitly converted to dst_t by the
//                               return statement.
template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {
        return x;
    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
        // The argument previously read `a: float(x)`; the `a:` was an editor
        // parameter-name hint that leaked into the source and is not valid C++.
        return __float2bfloat16(float(x));
    } else if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
        return __bfloat162float(x);
    } else if constexpr (std::is_same_v<dst_t, int32_t>) {
        return int32_t(x);
    } else {
        return float(x);
    }
}
| 48 | |