#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile.cuh"
#include "fattn-vec.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"

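// Choose the number of Q columns per CUDA block along the token dimension (ncols1) for the mma kernel:
// the smallest compiled variant whose ncols1 covers the batch size Q->ne[1] is used.
// For example, with ncols2 == 4 and Q->ne[1] == 5 the checks against 8/4 and 16/4 fail,
// so the ncols1 = 32/4 = 8 instantiation is selected.
// Builds whose highest compiled arch is Turing never use the 64/ncols2 variant.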
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const ggml_tensor * Q = dst->src[0];

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] <= 8/ncols2) {
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
            return;
        }
    }

    if (Q->ne[1] <= 16/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
}

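// Choose ncols2, the number of Q heads sharing the same K/V head that are processed together.
// This GQA optimization is only applied when a mask is present and ALiBi is not used (max_bias == 0);
// the largest of {8, 4, 2} that divides the GQA ratio is used, otherwise ncols2 = 1.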
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f;

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    if (use_gqa_opt && gqa_ratio % 8 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 4 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio % 2 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
}

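// Dispatch the mma kernel on the K/Q head size (DKQ) and V head size (DV).
// The two match for all regular models; the 576/512 case is the DeepSeek MLA layout,
// which requires the GQA optimization and a GQA ratio divisible by 16.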
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    switch (Q->ne[0]) {
        case 64:
            GGML_ASSERT(V->ne[0] == 64);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
            break;
        case 80:
            GGML_ASSERT(V->ne[0] == 80);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
            break;
        case 96:
            GGML_ASSERT(V->ne[0] == 96);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
            break;
        case 112:
            GGML_ASSERT(V->ne[0] == 112);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
            break;
        case 128:
            GGML_ASSERT(V->ne[0] == 128);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
            break;
        case 256:
            GGML_ASSERT(V->ne[0] == 256);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
            break;
        case 576: {
            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
            GGML_ASSERT(V->ne[0] == 512);
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

            const bool use_gqa_opt = mask && max_bias == 0.0f;
            GGML_ASSERT(use_gqa_opt);

            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
            const int gqa_ratio = Q->ne[2] / K->ne[2];
            GGML_ASSERT(gqa_ratio % 16 == 0);
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

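// Dispatch helper for the vec kernel: each case matches on head size D and on the K/V types
// (F32 K/V data is accepted by the F16 instantiations) and returns after launching the first match.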
#define FATTN_VEC_CASE(D, type_K, type_V)                                                                         \
    {                                                                                                             \
        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16);  \
        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16);  \
        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                      \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                       \
            return;                                                                                               \
        }                                                                                                         \
    }                                                                                                             \

#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V)       \
    FATTN_VEC_CASE(128, type_K, type_V)       \
    FATTN_VEC_CASE(256, type_K, type_V)       \

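// Dispatch for the vec kernel. Without GGML_CUDA_FA_ALL_QUANTS only same-type K/V combinations
// (F16/F16, Q4_0/Q4_0, Q8_0/Q8_0) are compiled; any other combination aborts.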
static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS

    GGML_ABORT("fatal error");
}

// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE     =   0,
    BEST_FATTN_KERNEL_TILE     = 200,
    BEST_FATTN_KERNEL_VEC      = 100,
    BEST_FATTN_KERNEL_WMMA_F16 = 300,
    BEST_FATTN_KERNEL_MMA_F16  = 400,
};

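// Select the best kernel for this op on the given device, based on head sizes, K/V types,
// mask/ALiBi usage, batch size, and compute capability.
// Returns BEST_FATTN_KERNEL_NONE if no FlashAttention kernel supports the op.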
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    GGML_UNUSED(device); GGML_UNUSED(dst);
    return BEST_FATTN_KERNEL_NONE;
#endif // FLASH_ATTN_AVAILABLE

    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    const int gqa_ratio = Q->ne[2] / K->ne[2];
    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    // The effective batch size for the kernel can be increased by gqa_ratio.
    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded.
    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;

    const int cc = ggml_cuda_info().devices[device].cc;

    switch (K->ne[0]) {
        case  40:
        case  64:
        case  72:
        case  80:
        case  96:
        case 112:
        case 128:
        case 256:
            if (V->ne[0] != K->ne[0]) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 576:
            if (V->ne[0] != 512) {
                return BEST_FATTN_KERNEL_NONE;
            }
            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

#ifndef GGML_CUDA_FA_ALL_QUANTS
    if (K->type != V->type) {
        return BEST_FATTN_KERNEL_NONE;
    }
#endif // GGML_CUDA_FA_ALL_QUANTS

    switch (K->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            break;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
#ifndef GGML_CUDA_FA_ALL_QUANTS
            return BEST_FATTN_KERNEL_NONE;
#endif // GGML_CUDA_FA_ALL_QUANTS
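            // fall through: with GGML_CUDA_FA_ALL_QUANTS these types are accepted as well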
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

    if (mask && mask->ne[2] != 1) {
        return BEST_FATTN_KERNEL_NONE;
    }

    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // If Turing tensor cores are available, use them:
    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
        if (can_use_vector_kernel) {
            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            } else {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    if (Q->ne[1] <= 2) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                } else {
                    if (Q->ne[1] == 1) {
                        return BEST_FATTN_KERNEL_VEC;
                    }
                }
            }
            if (!gqa_opt_applies && Q->ne[1] == 1) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }

        return BEST_FATTN_KERNEL_MMA_F16;
    }

    // Use the WMMA kernel if possible:
    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
        if (can_use_vector_kernel && Q->ne[1] <= 2) {
            return BEST_FATTN_KERNEL_VEC;
        }
        return BEST_FATTN_KERNEL_WMMA_F16;
    }

    // If there are no tensor cores available, use the generic tile kernel:
    if (can_use_vector_kernel) {
        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
            if (Q->ne[1] == 1) {
                if (!gqa_opt_applies) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            }
        } else {
            if (Q->ne[1] <= 2) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }
    }
    return BEST_FATTN_KERNEL_TILE;
}

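// Main FlashAttention entry point: pick the best kernel for the current device and launch it.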
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_set_device(ctx.device);
    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
        case BEST_FATTN_KERNEL_NONE:
            GGML_ABORT("fatal error");
        case BEST_FATTN_KERNEL_TILE:
            ggml_cuda_flash_attn_ext_tile(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_VEC:
            ggml_cuda_flash_attn_ext_vec(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_WMMA_F16:
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_MMA_F16:
            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
            break;
    }
}

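// Returns whether any FlashAttention kernel supports the given op on the given device.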
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}