| 1 | #include "common.cuh" |
| 2 | #include "fattn-common.cuh" |
| 3 | #include "fattn-mma-f16.cuh" |
| 4 | #include "fattn-tile.cuh" |
| 5 | #include "fattn-vec.cuh" |
| 6 | #include "fattn-wmma-f16.cuh" |
| 7 | #include "fattn.cuh" |
| 8 | |
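// Pick ncols1 for the mma kernel based on the batch size Q->ne[1]:
// the smallest instance with ncols1*ncols2 in {8, 16, 32, 64} that covers the batch is used;
// devices whose highest compiled arch is Turing are capped at 32 total columns.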
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * Q = dst->src[0];

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] <= 8/ncols2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
            return;
        }
    }

    if (Q->ne[1] <= 16/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}

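// Pick ncols2 from the GQA ratio Q->ne[2]/K->ne[2]: the largest of 8/4/2 that divides it,
// or 1 if the GQA optimization does not apply (no mask, or ALiBi with max_bias != 0).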
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f;

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    if (use_gqa_opt && gqa_ratio % 8 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 4 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 2 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}

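// Dispatch the mma kernel on the K/Q head size; V must have the same head size,
// except for the 576/512 combination used by Deepseek.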
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    switch (Q->ne[0]) {
        case 64:
            GGML_ASSERT(V->ne[0] == 64);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
            break;
        case 80:
            GGML_ASSERT(V->ne[0] == 80);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
            break;
        case 96:
            GGML_ASSERT(V->ne[0] == 96);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
            break;
        case 112:
            GGML_ASSERT(V->ne[0] == 112);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
            break;
        case 128:
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
            break;
        case 256:
            GGML_ASSERT(V->ne[0] == 256);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
            break;
        case 576: {
            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
            GGML_ASSERT(V->ne[0] == 512);
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

            const bool use_gqa_opt = mask && max_bias == 0.0f;
            GGML_ASSERT(use_gqa_opt);

            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
            const int gqa_ratio = Q->ne[2] / K->ne[2];
            GGML_ASSERT(gqa_ratio % 16 == 0);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

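// Dispatch to the vec kernel instance for head size D and K/V types type_K/type_V.
// F32 K/V tensors are routed to the F16 instances.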
#define FATTN_VEC_CASE(D, type_K, type_V) \
    { \
        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
            return; \
        } \
    } \

#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V) \
    FATTN_VEC_CASE(128, type_K, type_V) \
    FATTN_VEC_CASE(256, type_K, type_V) \

static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS

    GGML_ABORT("fatal error");
}

// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE     =   0,
    BEST_FATTN_KERNEL_TILE     = 200,
    BEST_FATTN_KERNEL_VEC      = 100,
    BEST_FATTN_KERNEL_WMMA_F16 = 300,
    BEST_FATTN_KERNEL_MMA_F16  = 400,
};

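// Pick the best FlashAttention kernel for this device and op, or BEST_FATTN_KERNEL_NONE if unsupported.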
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    GGML_UNUSED(device); GGML_UNUSED(dst);
    return BEST_FATTN_KERNEL_NONE;
#endif // FLASH_ATTN_AVAILABLE

    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    const int gqa_ratio = Q->ne[2] / K->ne[2];
    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    // The effective batch size for the kernel can be increased by gqa_ratio.
    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded.
    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;

    const int cc = ggml_cuda_info().devices[device].cc;

    switch (K->ne[0]) {
        case 40:
        case 64:
        case 72:
        case 80:
        case 96:
        case 112:
        case 128:
        case 256:
            if (V->ne[0] != K->ne[0]) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 576:
            if (V->ne[0] != 512) {
                return BEST_FATTN_KERNEL_NONE;
            }
            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

#ifndef GGML_CUDA_FA_ALL_QUANTS
    if (K->type != V->type) {
        return BEST_FATTN_KERNEL_NONE;
    }
#endif // GGML_CUDA_FA_ALL_QUANTS

    switch (K->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            break;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
#ifndef GGML_CUDA_FA_ALL_QUANTS
            return BEST_FATTN_KERNEL_NONE;
#endif // GGML_CUDA_FA_ALL_QUANTS
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

    if (mask && mask->ne[2] != 1) {
        return BEST_FATTN_KERNEL_NONE;
    }

    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // If Turing tensor cores are available, use them:
    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
        if (can_use_vector_kernel) {
            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            } else {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    if (Q->ne[1] <= 2) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                } else {
                    if (Q->ne[1] == 1) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                }
            }
            if (!gqa_opt_applies && Q->ne[1] == 1) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }

        return BEST_FATTN_KERNEL_MMA_F16;
    }

    // Use the WMMA kernel if possible:
    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
        if (can_use_vector_kernel && Q->ne[1] <= 2) {
            return BEST_FATTN_KERNEL_VEC;
        }
        return BEST_FATTN_KERNEL_WMMA_F16;
    }

    // If there are no tensor cores available, use the generic tile kernel:
    if (can_use_vector_kernel) {
        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
            if (Q->ne[1] == 1) {
                if (!gqa_opt_applies) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            }
        } else {
            if (Q->ne[1] <= 2) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }
    }
    return BEST_FATTN_KERNEL_TILE;
}

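// Launch the best available FlashAttention kernel for dst on the current device.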
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_set_device(ctx.device);
    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
        case BEST_FATTN_KERNEL_NONE:
            GGML_ABORT("fatal error");
        case BEST_FATTN_KERNEL_TILE:
            ggml_cuda_flash_attn_ext_tile(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_VEC:
            ggml_cuda_flash_attn_ext_vec(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_WMMA_F16:
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_MMA_F16:
            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
            break;
    }
}

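// Returns whether any FlashAttention kernel on the given device supports this op.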
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}