1#pragma once
2
3#include "common.cuh"
4#include "convert.cuh"
5#include "vecdotq.cuh"
6
7#include <cstdint>
8
// KV granularity (number of K/V rows) used by the FlashAttention helper code in this header:
// flash_attn_mask_to_KV_max scans the mask in tiles of this many KV positions, and the stream-k
// work split counts ne11/FATTN_KQ_STRIDE KV iterations per output tile.
#define FATTN_KQ_STRIDE 256
// Half of the largest finite FP16 value (65504). Halved, presumably so that the difference of two such
// values (e.g. 65504/2 - (-65504/2) = 65504) is still finite in half precision.
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
// Exponent threshold for softmax rescaling factors; flash_attn_stream_k_fixup uses 0 instead of expf(x) for x below it.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
12
// Common signature of all FlashAttention kernels so that launch_fattn can launch any of them.
//   Q, K, V, mask, sinks : raw device data pointers of the corresponding ggml tensors (mask/sinks may be null).
//   KV_max               : per-tile KV upper bounds; only filled when launch_fattn runs the mask scan.
//   dst / dst_meta       : output (or partial output) and per-row metadata buffers.
//   ne0x/nb0x : Q extents/byte strides, ne1x/nb1x : K, nb2x : V, ne3x/nb3x : mask.
using fattn_kernel_t = void (*)(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
                            const int32_t nb31, const int32_t nb32, const int64_t nb33);

// Per-thread partial K·Q dot product for one K row. Q is passed either as floating point values (Q_v)
// or pre-quantized to q8_1 (Q_q8 + scales Q_ds), depending on the K type.
using vec_dot_KQ_t = float (*)(
        const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);
38
// Partial K·Q dot product for an F16 K row of D values (D/2 half2).
// The nthreads cooperating threads each load cpy_ne consecutive half2 of K per step (vectorized copy)
// and multiply-accumulate them with the matching Q values into a per-thread sum; the per-thread sums
// are presumably reduced across the nthreads threads by the caller.
// Q_v holds half2 values if FAST_FP16_AVAILABLE is defined and float2 values otherwise.
// Q_q8 and Q_ds_v are unused for F16 K.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

    constexpr int cpy_ne = ggml_cuda_get_max_cpy_bytes() / 4; // half2 elements per vectorized load
    constexpr int stride = nthreads*cpy_ne;                    // half2 elements consumed per step by all threads

    const half2 * K_row    = (const half2 *) K_c;
    const int     lane_off = (threadIdx.x % nthreads)*cpy_ne;

    float acc = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < D/2; i0 += stride) {
        half2 K_chunk[cpy_ne];
        ggml_cuda_memcpy_1<sizeof(K_chunk)>(K_chunk, K_row + i0 + lane_off);

        // Q is stored per thread, so its index advances by cpy_ne per step.
        const int iq = i0/nthreads;
#pragma unroll
        for (int i1 = 0; i1 < cpy_ne; ++i1) {
#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_mad(acc, K_chunk[i1], ((const half2 *) Q_v)[iq + i1]);
#else
            ggml_cuda_mad(acc, __half22float2(K_chunk[i1]), ((const float2 *) Q_v)[iq + i1]);
#endif // FAST_FP16_AVAILABLE
        }
    }

    return acc;
}
68
// Partial K·Q dot product for a Q4_0 K row against Q pre-quantized to q8_1.
// Work is split into 32-bit words of 4 quantized values; each step a thread handles word k = k0 + lane.
// Q_q8 holds this thread's q8_1 words and Q_ds_v the matching float2 (d, s) scales, one per step.
// Q_v is unused.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_q4_0 * K_blocks = (const block_q4_0 *) K_c;
    const float2     * Q_ds     = (const float2 *) Q_ds_v;

    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k0 = 0; k0 < int(D/sizeof(int)); k0 += nthreads) {
        const int k    = k0 + lane;
        const int step = k0/nthreads;

        const block_q4_0 & blk = K_blocks[k / QI8_1];
        // The first half of a q8_1 block maps to the low nibbles of qs, the second half to the high nibbles.
        const int nibble_shift = k & (QI8_1/2);

        int K_q4;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&K_q4, blk.qs + sizeof(int)*(k % QI4_0));
        K_q4 = (K_q4 >> nibble_shift) & 0x0F0F0F0F;

        const int    dot = ggml_cuda_dp4a(K_q4, Q_q8[step], 0);
        const float2 ds  = Q_ds[step];

        // Q4_0 values carry an offset of 8, which is removed through the q8_1 sum term ds.y.
        acc += __half2float(blk.d) * (dot*ds.x - (8/QI8_1)*ds.y);
    }

    return acc;
}
99
// Partial K·Q dot product for a Q4_1 K row against Q pre-quantized to q8_1.
// Same work split as the Q4_0 variant; Q4_1 additionally has a per-block minimum (dm.y).
// Q_v is unused.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_q4_1 * K_blocks = (const block_q4_1 *) K_c;
    const float2     * Q_ds     = (const float2 *) Q_ds_v;

    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k0 = 0; k0 < int(D/sizeof(int)); k0 += nthreads) {
        const int k    = k0 + lane;
        const int step = k0/nthreads;

        const block_q4_1 & blk = K_blocks[k / QI8_1];
        const int nibble_shift = k & (QI8_1/2);

        int K_q4;
        ggml_cuda_memcpy_1<sizeof(int)>(&K_q4, blk.qs + sizeof(int)*(k % QI4_1));
        K_q4 = (K_q4 >> nibble_shift) & 0x0F0F0F0F;

        const int    dot  = ggml_cuda_dp4a(K_q4, Q_q8[step], 0);
        const float2 K_dm = __half22float2(blk.dm);
        const float2 ds   = Q_ds[step];

        // Scale term plus minimum term; the minimum term is divided by QI8_1, presumably because
        // each of the QI8_1 threads sharing one q8_1 block adds it once.
        acc += K_dm.x*ds.x*dot + K_dm.y*ds.y/QI8_1;
    }

    return acc;
}
132
// Partial K·Q dot product for a Q5_0 K row against Q pre-quantized to q8_1.
// The low 4 bits of each value come from qs, the 5th bit from the 32-bit qh mask of the block.
// Q_v is unused.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_q5_0 * K_blocks = (const block_q5_0 *) K_c;
    const float2     * Q_ds     = (const float2 *) Q_ds_v;

    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k0 = 0; k0 < int(D/sizeof(int)); k0 += nthreads) {
        const int k    = k0 + lane;
        const int step = k0/nthreads;

        const block_q5_0 & blk = K_blocks[k / QI8_1];
        const int nibble_shift = k & (QI8_1/2);

        int K_q5;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&K_q5, blk.qs + sizeof(int)*(k % QI5_0));
        K_q5 = (K_q5 >> nibble_shift) & 0x0F0F0F0F;

        // Move the 4 high bits belonging to this word to bit 4 of each byte (bit b -> bit 8*b + 4).
        int K_qh;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&K_qh, blk.qh);
        K_qh >>= (k % QI8_1) * QI5_0;
#pragma unroll
        for (int b = 0; b < 4; ++b) {
            K_q5 |= ((K_qh >> b) & 0x00000001) << (8*b + 4);
        }

        const int    dot = ggml_cuda_dp4a(K_q5, Q_q8[step], 0);
        const float2 ds  = Q_ds[step];

        // Q5_0 values carry an offset of 16, which is removed through the q8_1 sum term ds.y.
        acc += __half2float(blk.d) * (dot*ds.x - (16/QI8_1)*ds.y);
    }

    return acc;
}
177
// Partial K·Q dot product for a Q5_1 K row against Q pre-quantized to q8_1.
// Like Q5_0 but with a per-block scale and minimum (dm) instead of a fixed offset.
// Q_v is unused.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_q5_1 * K_blocks = (const block_q5_1 *) K_c;
    const float2     * Q_ds     = (const float2 *) Q_ds_v;

    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k0 = 0; k0 < int(D/sizeof(int)); k0 += nthreads) {
        const int k    = k0 + lane;
        const int step = k0/nthreads;

        const block_q5_1 & blk = K_blocks[k / QI8_1];
        const int nibble_shift = k & (QI8_1/2);

        int K_q5;
        ggml_cuda_memcpy_1<sizeof(int)>(&K_q5, blk.qs + sizeof(int)*(k % QI5_1));
        K_q5 = (K_q5 >> nibble_shift) & 0x0F0F0F0F;

        // Move the 4 high bits belonging to this word to bit 4 of each byte (bit b -> bit 8*b + 4).
        int K_qh;
        ggml_cuda_memcpy_1<sizeof(int)>(&K_qh, blk.qh);
        K_qh >>= (k % QI8_1) * QI5_0;
#pragma unroll
        for (int b = 0; b < 4; ++b) {
            K_q5 |= ((K_qh >> b) & 0x00000001) << (8*b + 4);
        }

        const int    dot  = ggml_cuda_dp4a(K_q5, Q_q8[step], 0);
        const float2 K_dm = __half22float2(blk.dm);
        const float2 ds   = Q_ds[step];

        acc += K_dm.x*ds.x*dot + K_dm.y*ds.y/QI8_1;
    }

    return acc;
}
223
// Partial K·Q dot product for a Q8_0 K row against Q pre-quantized to q8_1.
// Each step a thread multiplies one 32-bit word (4 int8 values) of K with the matching Q word.
// Only the scale (x) of the Q float2 scales is needed; Q_v is unused.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_q8_0 * K_blocks = (const block_q8_0 *) K_c;
    const float2     * Q_ds     = (const float2 *) Q_ds_v;

    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k0 = 0; k0 < int(D/sizeof(int)); k0 += nthreads) {
        const int k    = k0 + lane;
        const int step = k0/nthreads;

        const block_q8_0 & blk = K_blocks[k / QI8_0];

        int K_q8;
        ggml_cuda_memcpy_1<sizeof(K_q8), 2>(&K_q8, blk.qs + 4*(k % QI8_0));

        acc += vec_dot_q8_0_q8_1_impl<float, 1>(&K_q8, &Q_q8[step], blk.d, Q_ds[step].x);
    }

    return acc;
}
251
// Quantizes scale*x to q8_1: each thread packs 4 consecutive values x[4*threadIdx.x .. 4*threadIdx.x + 3]
// into one int of yq32. Groups of QI8_1 consecutive lanes share one (d, sum) pair, obtained via xor shuffles;
// the first lane of each group writes it to yds as Tds (half2 or float2).
// Threads with threadIdx.x >= ni (unless ni == WARP_SIZE) contribute zeros and do not write yds.
// All 32 lanes of the warp must call this function, since the shuffles use a full mask.
template <typename Tds, int ni>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {

    constexpr int nvals  = int(sizeof(int)); // values per thread = bytes per packed int
    const bool    active = ni == WARP_SIZE || threadIdx.x < ni;

    float vals[nvals];
#pragma unroll
    for (int l = 0; l < nvals; ++l) {
        vals[l] = active ? scale * x[4*threadIdx.x + l] : 0.0f;
    }

    float amax = fabsf(vals[0]);
    float sum  = vals[0];
#pragma unroll
    for (int l = 1; l < nvals; ++l) {
        amax = fmaxf(amax, fabsf(vals[l]));
        sum += vals[l];
    }

    // Butterfly reduction over the QI8_1 lanes of the group.
#pragma unroll
    for (int offset = QI8_1/2; offset > 0; offset >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, 32));
        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  offset, 32);
    }

    const float d = amax / 127;

    int packed = 0;
    if (d != 0.0f) {
        int8_t * packed_i8 = (int8_t *) &packed;
#pragma unroll
        for (int l = 0; l < nvals; ++l) {
            packed_i8[l] = roundf(vals[l] / d);
        }
    }
    yq32[threadIdx.x] = packed;

    if (active && threadIdx.x % QI8_1 == 0) {
        if constexpr (std::is_same_v<Tds, half2>) {
            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
        } else {
            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
        }
    }
}
295
// Dequantizes a fixed number of V values starting at flat element index i0 of vx into dst.
using dequantize_V_t = void (*)(const void *, void *, const int64_t);
297
// Copies ne F16 values starting at element i0 of vx into dst, either as half or converted to float.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const half * src = (const half *) vx + i0;

    if constexpr (std::is_same_v<T, float>) {
        static_assert(ne % 2 == 0, "bad ne");
        half2 src_h2[ne/2];
        ggml_cuda_memcpy_1<ne*sizeof(half)>(src_h2, src);

        float2 * out = (float2 *) dst;
#pragma unroll
        for (int l = 0; l < ne/2; ++l) {
            out[l] = __half22float2(src_h2[l]);
        }
    } else if constexpr (std::is_same_v<T, half>) {
        ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, src);
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }
}
315
// Dequantizes ne (2 or 4) Q4_0 values starting at flat element index i0 into dst (T = half or float).
// Within a block, values [0, QK4_0/2) are the low nibbles of qs and [QK4_0/2, QK4_0) the high nibbles.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(ne == 2 || ne == 4, "bad ne");

    const block_q4_0 & blk = ((const block_q4_0 *) vx)[i0 / QK4_0];

    const int in_block     = i0 % QK4_0;
    const int byte_off     = in_block % (QK4_0/2);
    const int nibble_shift = 4*(in_block / (QK4_0/2));

    int q;
    ggml_cuda_memcpy_1<ne, 2>(&q, blk.qs + byte_off);
    // Extract the nibbles and remove the offset of 8 with a saturating per-byte subtraction.
    q = __vsubss4((q >> nibble_shift) & 0x0F0F0F0F, 0x08080808);

    const int8_t * vals = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 d = __half2half2(blk.d);
        half2 * out = (half2 *) dst;

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            out[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float d = blk.d;
        float * out = (float *) dst;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            out[l] = d * vals[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
354
// Dequantizes ne (2 or 4) Q4_1 values starting at flat element index i0 into dst (T = half or float)
// as d*q + m, with (d, m) taken from the block's dm.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(ne == 2 || ne == 4, "bad ne");

    const block_q4_1 & blk = ((const block_q4_1 *) vx)[i0 / QK4_1];

    const int in_block     = i0 % QK4_1;
    const int byte_off     = in_block % (QK4_1/2);
    const int nibble_shift = 4*(in_block / (QK4_1/2));

    int q;
    ggml_cuda_memcpy_1<ne>(&q, blk.qs + byte_off);
    q = (q >> nibble_shift) & 0x0F0F0F0F;

    const int8_t * vals = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 dm = blk.dm;
        const half2 d  = __half2half2( __low2half(dm));
        const half2 m  = __half2half2(__high2half(dm));
        half2 * out = (half2 *) dst;

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            out[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]) + m;
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float2 dm = __half22float2(blk.dm);
        float * out = (float *) dst;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            out[l] = dm.x * vals[l] + dm.y;
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
394
// Dequantizes ne (2 or 4) Q5_0 values starting at flat element index i0 into dst (T = half or float).
// The low 4 bits come from the nibbles in qs; the 5th bit of value idq within the block is bit idq of qh.
// The stored values carry an offset of 16, which is removed before scaling by d.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q5_0 * x = (const block_q5_0 *) vx;

    const int64_t ib    =  i0          /  QK5_0;
    const int     idq   =  i0          %  QK5_0;     // index of the first value within the block
    const int     iqs   =  i0          % (QK5_0/2);  // byte offset into qs
    const int     shift = (i0 % QK5_0) / (QK5_0/2);  // 0: low nibbles, 1: high nibbles

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    {
        // Bits idq .. idq + ne - 1 of qh are needed, and idq can be as large as QK5_0 - 1.
        // Always load the full 32-bit qh word: loading only `ne` bytes left the upper half of qh
        // uninitialized for ne == 2 and produced wrong values for the second half of each block.
        int qh;
        ggml_cuda_memcpy_1<sizeof(int), 2>(&qh, x[ib].qh);
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
        }
    }

    q = __vsubss4(q, 0x10101010);

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 d = __half2half2(x[ib].d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float d = x[ib].d;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = d * q8[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
444
// Dequantizes ne (2 or 4) Q5_1 values starting at flat element index i0 into dst (T = half or float)
// as d*q + m. The low 4 bits come from qs; the 5th bit of value idq within the block is bit idq of qh.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const int64_t ib    =  i0          /  QK5_1;
    const int     idq   =  i0          %  QK5_1;     // index of the first value within the block
    const int     iqs   =  i0          % (QK5_1/2);  // byte offset into qs
    const int     shift = (i0 % QK5_1) / (QK5_1/2);  // 0: low nibbles, 1: high nibbles

    int q;
    static_assert(ne == 2 || ne == 4, "bad ne");
    ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
    q >>= 4*shift;
    q &= 0x0F0F0F0F;

    {
        // Bits idq .. idq + ne - 1 of qh are needed, and idq can be as large as QK5_1 - 1.
        // Always load the full 32-bit qh word: loading only `ne` bytes left the upper half of qh
        // uninitialized for ne == 2 and produced wrong values for the second half of each block.
        int qh;
        ggml_cuda_memcpy_1<sizeof(int)>(&qh, x[ib].qh);
#pragma unroll
        for (int l = 0; l < ne; ++l) {
            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
        }
    }

    const int8_t * q8 = (const int8_t *) &q;

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        const half2 dm = x[ib].dm;
        const half2 d = __half2half2( __low2half(dm));
        const half2 m = __half2half2(__high2half(dm));

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        const float2 dm = __half22float2(x[ib].dm);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            ((float *) dst)[l] = dm.x * q8[l] + dm.y;
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}
494
// Dequantizes ne (even) Q8_0 values starting at flat element index i0 into dst (T = half or float) as d*q.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(ne % 2 == 0, "bad ne");

    const block_q8_0 & blk = ((const block_q8_0 *) vx)[i0 / QK8_0];
    const int in_block = i0 % QK8_0;

    int8_t vals[ne];
    ggml_cuda_memcpy_1<ne, 2>(vals, blk.qs + in_block);

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same<T, half>::value) {
        const half2 d = __half2half2(blk.d);
        half2 * out = (half2 *) dst;

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            out[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same<T, float>::value) {
        const float d = blk.d;
        float * out = (float *) dst;

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            out[l] = d * vals[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }
}
527
// Compile-time selection of the K·Q dot product implementation for K of type type_K.
// Unsupported types are rejected by a static_assert.
template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
    if constexpr (type_K == GGML_TYPE_Q8_0) {
        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
    } else if constexpr (type_K == GGML_TYPE_F16) {
        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
    } else {
        static_assert(type_K == -1, "bad type");
        return nullptr;
    }
}
547
// Compile-time selection of the V dequantization routine for V of type type_V,
// producing ne values of type T per call. Unsupported types are rejected by a static_assert.
template <ggml_type type_V, typename T, int ne>
constexpr __device__ dequantize_V_t get_dequantize_V() {
    if constexpr (type_V == GGML_TYPE_Q8_0) {
        return dequantize_V_q8_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
        return dequantize_V_q5_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
        return dequantize_V_q5_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
        return dequantize_V_q4_1<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
        return dequantize_V_q4_0<T, ne>;
    } else if constexpr (type_V == GGML_TYPE_F16) {
        return dequantize_V_f16<T, ne>;
    } else {
        static_assert(type_V == -1, "bad type");
        return nullptr;
    }
}
567
// Scans the mask to find, for each tile of ncols1 Q columns (blockIdx.x) and each sequence (blockIdx.y),
// the end of the last FATTN_KQ_STRIDE-sized KV tile that contains at least one non-infinite mask value.
// Launch: FATTN_KQ_STRIDE/2 threads per block, each reading one half2 per mask row, so one iteration
// covers FATTN_KQ_STRIDE KV positions.
// ne30: number of KV tiles to scan; s31 / s33: mask row / sequence strides in half2 units.
// Writes KV_max[sequence*gridDim.x + blockIdx.x]; 0 means every KV tile is fully masked.
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
    const int ne31     = gridDim.x;
    const int jt       = blockIdx.x;
    const int sequence = blockIdx.y;
    const int tid      = threadIdx.x;

    const half2 * mask_tile = mask + sequence*s33 + jt*ncols1*s31;

    // One flag per warp. Slots of non-existent warps stay 1 so they do not affect the block-wide AND.
    __shared__ int warp_flags[WARP_SIZE];
    if (tid < WARP_SIZE) {
        warp_flags[tid] = 1;
    }
    __syncthreads();

    // Walk the KV tiles from the end towards the beginning and stop at the first one that is not fully masked.
    // Every thread ends up with the same block-wide flag, so all threads leave the loop together
    // and reach the same __syncthreads() calls.
    int KV_end = ne30*FATTN_KQ_STRIDE;
    while (KV_end > 0) {
        const int KV_start = KV_end - FATTN_KQ_STRIDE;

        int all_inf = 1;
#pragma unroll
        for (int j = 0; j < ncols1; ++j) {
            const float2 vals = __half22float2(mask_tile[j*s31 + KV_start/2 + tid]);
            all_inf = all_inf && int(isinf(vals.x)) && int(isinf(vals.y));
        }

        // Block-wide AND: reduce within each warp, publish per-warp results, reduce those again.
        all_inf = warp_reduce_all(all_inf);
        if (tid % WARP_SIZE == 0) {
            warp_flags[tid / WARP_SIZE] = all_inf;
        }
        __syncthreads();
        all_inf = warp_flags[tid % WARP_SIZE];
        __syncthreads();
        all_inf = warp_reduce_all(all_inf);

        if (!all_inf) {
            break;
        }
        KV_end = KV_start;
    }

    if (tid == 0) {
        KV_max[sequence*ne31 + jt] = KV_end;
    }
}
620
// Stream-k fixup pass. The stream-k kernel splits n_work = iter_k*iter_j*(ne02/ncols2)*ne03 units of work
// (KV iterations of output tiles) evenly over gridDim.x blocks, so one output tile can be shared by several
// consecutive blocks.
//
// For a block that began inside a tile (kbc0 % iter_k != 0) and reached that tile's end, this kernel:
//   - reads the unnormalized value the block left in dst,
//   - combines it with the partial results that the preceding blocks of the same tile left in dst_fixup,
//     rescaling with the per-block (KQ max, rowsum) pairs,
//   - writes the normalized result back to dst.
//
// Grid: (gridDim.x of the stream-k kernel, ncols1, ncols2); block: D threads, one per element of a dst row.
// dst_fixup layout as read here:
//   - [0, gridDim.x*ncols) float2: (max, rowsum) for the value each block left in dst;
//   - [gridDim.x*ncols, 2*gridDim.x*ncols) float2: (max, rowsum) for each block's fixup data;
//   - then gridDim.x*ncols*D floats of unnormalized partial VKQ data.
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
    constexpr int ncols = ncols1*ncols2;

    const int bidx0 = blockIdx.x;
    const int j     = blockIdx.y;
    const int c     = blockIdx.z;
    const int jc    = j*ncols2 + c;
    const int tid   = threadIdx.x;

    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

    // Work units are ordered as [sequence][head group][Q column tile][KV iteration].
    const int iter_k   = ne11 / FATTN_KQ_STRIDE;
    const int iter_j   = (ne01 + (ncols1 - 1)) / ncols1;
    const int iter_kj  = iter_k*iter_j;
    const int iter_kjh = iter_kj*(ne02/ncols2);
    const int n_work   = iter_kjh*ne03;

    // First work unit assigned to stream-k block b.
    const auto work_begin = [=](const int b) -> int {
        return b*n_work / gridDim.x;
    };

    const int kbc0      = work_begin(bidx0 + 0);
    const int kbc0_stop = work_begin(bidx0 + 1);

    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    const int sequence = kbc0 / iter_kjh;
    const int head     = (kbc0 - iter_kjh*sequence) / iter_kj;
    const int jt       = (kbc0 - iter_kjh*sequence - iter_kj*head) / iter_k; // j index of current tile.

    if (jt*ncols1 + j >= ne01) {
        return;
    }

    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

    // Partial result this block left in dst, together with its metadata.
    float dst_val = *dst;
    const float2 meta_own = dst_fixup[bidx0*ncols + jc];
    float max_val = meta_own.x;
    float rowsum  = meta_own.y;

    // Walk backwards over the preceding blocks, skipping blocks without any work, until the block that
    // started this tile has been merged. Every block that gets here has such a predecessor.
    for (int bidx = bidx0 - 1, kbc_stop = kbc0; ; --bidx) {
        const int kbc = work_begin(bidx);

        if (kbc != kbc_stop) {
            const float  dst_add  = dst_fixup_data[bidx*ncols*D + jc*D + tid];
            const float2 meta_add = dst_fixup[(gridDim.x + bidx)*ncols + jc];

            // Rescale both accumulators to the common maximum; tiny factors are flushed to zero.
            const float max_val_new = fmaxf(max_val, meta_add.x);

            const float diff_val = max_val    - max_val_new;
            const float diff_add = meta_add.x - max_val_new;

            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

            dst_val = scale_val*dst_val + scale_add*dst_add;
            rowsum  = scale_val*rowsum  + scale_add*meta_add.y;
            max_val = max_val_new;

            // This block began at or before the start of the tile: nothing further to merge.
            if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
                break;
            }
        }

        kbc_stop = kbc;
    }

    *dst = dst_val / rowsum;
}
711
// Merges the parallel_blocks partial FlashAttention results of each output row. These are produced when
// launch_fattn splits the KV dimension over several blocks per tile (non-stream-k path).
// Grid: (ne01, ne02, ne03) = (Q columns, heads, sequences); block: D threads, one per element of a row.
// Dynamic shared memory: parallel_blocks*sizeof(float2) bytes for the per-part metadata.
// VKQ_parts: per row, parallel_blocks unnormalized partial rows of D floats.
// VKQ_meta:  per row, parallel_blocks float2 of (KQ max, rowsum).
template<int D> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst,
        const int parallel_blocks) {
    // Dimension 0: threadIdx.x
    // Dimension 1: blockIdx.x
    // Dimension 2: blockIdx.y
    // Dimension 3: blockIdx.z
    // Memory layout is permuted with [0, 2, 1, 3]

    const int ne01 = gridDim.x;
    const int ne02 = gridDim.y;

    const int col      = blockIdx.x;
    const int head     = blockIdx.y;
    const int sequence = blockIdx.z;

    // 64-bit row index: the offsets below multiply it by parallel_blocks*D, which overflows
    // 32-bit int for large outputs.
    const int64_t j_dst_unrolled = (int64_t(sequence)*ne01 + col)*ne02 + head;

    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
    VKQ_meta  += j_dst_unrolled * parallel_blocks;
    dst       += j_dst_unrolled * D;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Stage the (max, rowsum) metadata of all parts in shared memory.
    extern __shared__ float2 meta[];
    for (int i = tid; i < 2*parallel_blocks; i += D) {
        ((float *) meta)[i] = ((const float *) VKQ_meta)[i];
    }

    __syncthreads();

    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = fmaxf(kqmax, meta[l].x);
    }

    // Rescale every part relative to the overall maximum before summing numerators and denominators.
    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float KQ_max_scale = expf(meta[l].x - kqmax);

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[tid] = VKQ_numerator / VKQ_denominator;
}
764
// Host-side launcher shared by the CUDA FlashAttention kernels.
//   1. Optionally converts K and/or V to F16 (need_f16_K / need_f16_V) and adjusts their byte strides.
//   2. Optionally scans the mask (flash_attn_mask_to_KV_max) to compute per-tile KV upper bounds.
//   3. Picks the grid. With stream_k, blocks_num.x blocks split the work and flash_attn_stream_k_fixup
//      merges fractional tiles. Otherwise a tiled grid with parallel_blocks blocks per tile is used, and
//      flash_attn_combine_results merges the parts.
//   4. Launches fattn_kernel and, if needed, the fixup/combine kernel on the context's stream.
// DV is the V head size; each tile covers ncols1 Q columns times ncols2 heads.
// dst->src[0..4] are Q, K, V, mask, sinks; V may only be null in the MLA case (DV == 512).
template <int DV, int ncols1, int ncols2>
void launch_fattn(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
        const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
    constexpr int ncols = ncols1 * ncols2;

    const bool is_mla = DV == 512; // TODO better parameterization

    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(V || is_mla);

    const ggml_tensor * mask  = dst->src[3];
    const ggml_tensor * sinks = dst->src[4];

    ggml_tensor * KQV = dst;

    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || (mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"));

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id  = ggml_cuda_get_device();
    const int cc  = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;

    ggml_cuda_pool_alloc<half>   K_f16(pool);
    ggml_cuda_pool_alloc<half>   V_f16(pool);
    ggml_cuda_pool_alloc<int>    KV_max(pool);
    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

    const char * K_data = (const char *) K->data;
    size_t nb11 = K->nb[1];
    size_t nb12 = K->nb[2];
    size_t nb13 = K->nb[3];

    const char * V_data = V ? (const char *) V->data : nullptr;
    size_t nb21 = V ? V->nb[1] : nb11;
    size_t nb22 = V ? V->nb[2] : nb12;
    size_t nb23 = V ? V->nb[3] : nb13;

    if (need_f16_K && K->type != GGML_TYPE_F16) {
        const size_t bs = ggml_blck_size(K->type);
        const size_t ts = ggml_type_size(K->type);

        K_f16.alloc(ggml_nelements(K));
        if (ggml_is_contiguously_allocated(K)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);

            // Same layout, but each block of bs values now takes bs*sizeof(half) bytes instead of ts.
            nb11 = nb11*bs*sizeof(half)/ts;
            nb12 = nb12*bs*sizeof(half)/ts;
            nb13 = nb13*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(K->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
            const int64_t s01 = nb11 / ts;
            const int64_t s02 = nb12 / ts;
            const int64_t s03 = nb13 / ts;
            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);

            // The non-contiguous conversion writes a densely packed F16 tensor.
            nb11 = K->ne[0] * sizeof(half);
            nb12 = K->ne[1] * nb11;
            nb13 = K->ne[2] * nb12;
        }
        K_data = (char *) K_f16.ptr;
    }

    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
        const size_t bs = ggml_blck_size(V->type);
        const size_t ts = ggml_type_size(V->type);

        V_f16.alloc(ggml_nelements(V));
        if (ggml_is_contiguously_allocated(V)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);

            nb21 = nb21*bs*sizeof(half)/ts;
            nb22 = nb22*bs*sizeof(half)/ts;
            nb23 = nb23*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(V->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
            const int64_t s01 = nb21 / ts;
            const int64_t s02 = nb22 / ts;
            const int64_t s03 = nb23 / ts;
            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);

            nb21 = V->ne[0] * sizeof(half);
            nb22 = V->ne[1] * nb21;
            nb23 = V->ne[2] * nb22;
        }
        V_data = (char *) V_f16.ptr;
    }

    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
    //     multiple sequences of possibly different lengths.
    // flash_attn_mask_to_KV_max only reads the mask of head 0 and offsets sequences by sequence*s33 without
    //     broadcasting. Only use it if the mask is shared by all heads and either shared by all sequences
    //     (read with s33 = 0) or present once per sequence. Otherwise a KV bound derived from head 0 could cut off
    //     other heads, and a sequence-broadcast mask would be read out of bounds. Skipping the scan is always
    //     correct because KV_max then stays unallocated.
    const bool mask_scan_supported = mask && mask->ne[2] == 1 && (mask->ne[3] == 1 || mask->ne[3] == Q->ne[3]);
    if (mask_scan_supported && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
        const int s31 = mask->nb[1] / sizeof(half2);
        const int s33 = mask->ne[3] == 1 ? 0 : int(mask->nb[3] / sizeof(half2));

        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);

        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;

        KV_max.alloc(ne_KV_max);
        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
        CUDA_CHECK(cudaGetLastError());
    }

    const dim3 block_dim(warp_size, nwarps, 1);
    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
    GGML_ASSERT(max_blocks_per_sm > 0);
    int parallel_blocks = max_blocks_per_sm;

    dim3 blocks_num;
    if (stream_k) {
        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
        const int max_blocks = max_blocks_per_sm*nsm;
        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

        const int nblocks_stream_k = max_blocks;

        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;

        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
        blocks_num.y = 1;
        blocks_num.z = 1;

        // NOTE(review): flash_attn_stream_k_fixup reads blocks_num.x*ncols*(2*2 + DV) floats from this buffer, but the
        // size below is in float2 elements and additionally multiplied by sizeof(float). That over-allocates; confirm
        // against the kernels' write pattern before shrinking it.
        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
    } else {
        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

        // parallel_blocks must not be larger than what the tensor size allows:
        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
        // Test whether parallel_blocks can be set to a higher value for better efficiency.
        const int blocks_per_wave = nsm * max_blocks_per_sm;
        int nwaves_best = 0;
        int efficiency_percent_best = 0;
        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
            const int nblocks_total = ntiles_total * parallel_blocks_test;
            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
                break;
            }

            if (efficiency_percent > efficiency_percent_best) {
                nwaves_best = nwaves;
                efficiency_percent_best = efficiency_percent;
                parallel_blocks = parallel_blocks_test;
            }
        }

        blocks_num.x = ntiles_x;
        blocks_num.y = parallel_blocks;
        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];

        if (parallel_blocks > 1) {
            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
        }
    }

    float scale         = 1.0f;
    float max_bias      = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }

    // ALiBi slope parameters.
    const uint32_t n_head      = Q->ne[2];
    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    GGML_ASSERT(block_dim.x % warp_size == 0);
    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
        (const char *) Q->data,
        K_data,
        V_data,
        mask  ? ((const char *) mask->data)  : nullptr,
        sinks ? ((const char *) sinks->data) : nullptr,
        KV_max.ptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
    );
    CUDA_CHECK(cudaGetLastError());

    if (stream_k) {
        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
            const dim3 block_dim_combine(DV, 1, 1);
            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
        }
    } else if (parallel_blocks > 1) {
        const dim3 block_dim_combine(DV, 1, 1);
        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

        flash_attn_combine_results<DV>
            <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
    }
    CUDA_CHECK(cudaGetLastError());
}
1011