#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"

// nbatch_fa == number of KQ rows to process per iteration
// nbatch_K  == number of K columns to load in parallel for KQ calculation

// TODO optimize kernel parameters for FP16 NVIDIA (P100)
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112

// The ROCm compiler cannot handle templating in __launch_bounds__.
// As a workaround, define a macro to package the kernel parameters as uint32_t:
#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {  \
        static_assert((nthreads)  <= 512, "bad nthreads");    \
        static_assert((occupancy) <=   8, "bad occupancy");   \
        static_assert((nbatch_fa) <= 256, "bad nbatch_fa");   \
        static_assert((nbatch_K)  <= 256, "bad nbatch_K");    \
        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \
    } \

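// Packed layout: bits [0, 10) = nthreads, [10, 14) = occupancy, [14, 23) = nbatch_fa, [23, 32) = nbatch_K.
// For example, a case with nthreads=256, occupancy=2, nbatch_fa=64, nbatch_K=64 packs to
//     (256 << 0) | (2 << 10) | (64 << 14) | (64 << 23) == 0x20100900,
// which the getter functions further below unpack again with the matching shifts and masks.
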
static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)

    return 0;
}

static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)

    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)

    return 0;
}

static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA(cc)) {
            return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
        }
        return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
    }
    if (fast_fp16_available(cc)) {
        return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
    }
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
}

static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
#ifdef GGML_USE_HIP
#ifdef RDNA
    return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
    return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
#else
    return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
}

static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
}

static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
}

static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
    return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
}
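
// Example (host side): for DKQ == DV == 128 and ncols == 32 on an NVIDIA GPU with fast FP16,
// the table above packs nthreads=256, occupancy=2, nbatch_fa=64, nbatch_K=64, so
// ggml_cuda_fattn_tile_get_nthreads(128, 128, 32, cc) returns 256 and
// ggml_cuda_fattn_tile_get_nbatch_fa(128, 128, 32, cc) returns 64.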

// TODO: deduplicate with mma-f16
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;

                    const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                    ggml_cuda_memcpy_1<cpy_nb>(
                        tile_KV + i*(J/2 + J_padding) + j,
                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
                }
            }
        }
    };
    // 1: max 64*16=1024 bytes, 512 half
    // 2: max 32*16= 512 bytes, 256 half
    // 3: max 16*16= 256 bytes, 128 half
    // 4: max  8*16= 128 bytes,  64 half
    // 5: max  4*16=  64 bytes,  32 half
    // 6: max  2*16=  32 bytes,  16 half
    // 7: max  1*16=  16 bytes,   8 half
    static_assert(J % 8 == 0, "bad J");
    static_assert((J/2) % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<7>{}(load);
}

template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
        const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    auto load = [&] __device__ (const int n) {
        const int stride_j = warp_size >> n;

        if (stride_j == 0) {
            return;
        }

        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
        const int stride_i = warp_size / stride_j;

        if (j0_start == j0_stop) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
            const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);

            if (i0 + nwarps*stride_i <= I || i < I) {
#pragma unroll
                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                    const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);

                    const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
                    half2 tmp_h2[cpy_ne/2];
                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);

                    float2 tmp_f2[cpy_ne/2];
#pragma unroll
                    for (int l = 0; l < cpy_ne/2; ++l) {
                        tmp_f2[l] = __half22float2(tmp_h2[l]);
                    }
                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
                }
            }
        }
    };
    // 1: max 32*16= 512 bytes, 128 float
    // 2: max 16*16= 256 bytes,  64 float
    // 3: max  8*16= 128 bytes,  32 float
    // 4: max  4*16=  64 bytes,  16 float
    // 5: max  2*16=  32 bytes,   8 float
    static_assert(J % 8 == 0, "bad J");
    static_assert(J % cpy_ne == 0, "bad J");
    ggml_cuda_unroll<5>{}(load);
}

// Function that performs a single iteration of the KQ matrix multiplication:
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
    bool use_logit_softcap, bool oob_check, typename T_vec_dot>
static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
        T_vec_dot * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        T_vec_dot * const KV_tmp,
        const int stride_K2,
        const int k_VKQ_0,
        const int k_VKQ_sup,
        const int k_KQ_0,
        float * KQ_acc) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
    __syncthreads();

#ifdef FAST_FP16_AVAILABLE
    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        half2 Q_k[cpw][cpy_ne];
#else
    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
#pragma unroll
    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
        float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y / np)*cpw;

#ifdef FAST_FP16_AVAILABLE
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
#else
            ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
#pragma unroll
            for (int jc0 = 0; jc0 < cpw; ++jc0) {
#pragma unroll
                for (int k = 0; k < cpy_ne; ++k) {
                    ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
                }
            }
        }
    }

    if (k_KQ_0 + nbatch_K < DKQ) {
        __syncthreads(); // Sync not needed on last iteration.
    }
}

// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
    bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
static __device__ __forceinline__ void flash_attn_tile_iter(
        T_vec_dot * const Q_tmp,
        const half2 * const __restrict__ K_h2,
        const half2 * const __restrict__ V_h2,
        const half  * const __restrict__ mask,
        const float logit_softcap,
        const float slope,
        T_KQ      * const KQ,
        T_vec_dot * const KV_tmp,
        const int stride_K2,
        const int stride_V2,
        const int stride_mask,
        float * const KQ_max,
        float * const KQ_sum,
        T_acc * const VKQ,
        const int k_VKQ_0,
        const int k_VKQ_max) {
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int ncols = ncols1*ncols2;
    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column

    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.

    // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
    // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][nbatch_fa][KQ_cs].
#ifdef FAST_FP16_AVAILABLE
    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
#else
    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
#endif // FAST_FP16_AVAILABLE
    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
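    // With this layout the KQ value for column jc and row i is stored at
    //     KQ[(jc/KQ_cs)*(nbatch_fa*KQ_cs) + i*KQ_cs + jc % KQ_cs],
    // so the KQ_cs values of one row chunk are contiguous in memory.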
    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data

    float KQ_max_new[cpw];
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_max_new[jc0] = KQ_max[jc0];
    }

    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.

    // KQ = K @ Q matrix multiplication:
    constexpr int nbatch_K_last = DKQ % nbatch_K;
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }
    if (nbatch_K_last > 0) {
        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
    }

    // Apply logit softcap + mask, update KQ_max:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;

#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
            const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;

            if (use_logit_softcap) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
            }

            if (!oob_check || i_KQ < k_VKQ_sup) {
                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;

                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
            }
        }

        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
    }

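    // Combine the KQ maxima of the np warps that share a Q column; for np == 1 there is nothing to
    // combine and only a barrier is needed before the shared buffers are reused.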
    if constexpr (np == 1) {
        __syncthreads();
    } else {
        static_assert(cpw == 1, "bad cpw");
        __shared__ float KQ_max_new_shared[nwarps];
        if (threadIdx.x == 0) {
            KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
        }
        __syncthreads();
        KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
    }

    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
#ifdef FAST_FP16_AVAILABLE
        half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#else
        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
#endif // FAST_FP16_AVAILABLE

#pragma unroll
        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
            const int jc = jc0 + jc1;

            const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);
            KQ_max[jc] = KQ_max_new[jc];

            float KQ_sum_add = 0.0f;
#pragma unroll
            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
                const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
                    expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
                KQ_sum_add += val;
                tmp[i0/(np*warp_size)][jc1] = val;
            }
            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }

#pragma unroll
        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
            const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;

            ggml_cuda_memcpy_1<sizeof(tmp[0])>(
                KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
                tmp[i0/(np*warp_size)]);
        }
    }

    // VKQ = V @ KQ matrix multiplication:
    static_assert(DV <= DKQ, "bad DV");
    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
    static_assert(nbatch_V % np        == 0, "bad nbatch_V");
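    // Example: for DKQ == 576, DV == 512, nbatch_fa == 64, nbatch_K == 64 this gives nbatch_V == 64*64/512 == 8,
    // i.e. 8 V rows are loaded per inner iteration so that the V tile reuses the SRAM sized for the K tile.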
#pragma unroll
    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
        __syncthreads();

#ifdef FAST_FP16_AVAILABLE
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            half2 V_k[(DVp/2)/warp_size];
            half2 KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                half tmp[KQ_cs];
                ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
                    &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
#pragma unroll
                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
                    KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
                }
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
                }
            }
        }
#else
#pragma unroll
        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
            float2 V_k[(DVp/2)/warp_size];
            float  KQ_k[cpw];

            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
            }
#pragma unroll
            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);

                ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
            }

#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
#pragma unroll
                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
                }
            }
        }
#endif // FAST_FP16_AVAILABLE

        __syncthreads();
    }
}

template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
static __global__ void flash_attn_tile(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int  * __restrict__ KV_max,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:

    if (
#ifdef GGML_USE_WMMA_FATTN
        (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
        (use_logit_softcap && !(DV == 128 || DV == 256))
    ) {
        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
            max_bias, m0, m1, n_head_log2, logit_softcap,
            ne00, ne01, ne02, ne03,
            nb01, nb02, nb03,
            ne10, ne11, ne12, ne13,
            nb11, nb12, nb13,
            nb21, nb22, nb23,
            ne31, ne32, ne33,
            nb31, nb32, nb33);
        NO_DEVICE_CODE;
        return;
    }

    static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");

    constexpr int ncols     = ncols1*ncols2;
    constexpr int warp_size = 32;
    constexpr int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
    constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
    constexpr int nbatch_K  = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);

    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.

    const int sequence  = blockIdx.z / (ne02/ncols2);
    const int head0     = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0              + nb01*col_Q_0);
    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
    const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape

    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;

    const int stride_K2   = nb11 / sizeof(half2);
    const int stride_V2   = nb21 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
    static_assert(cpw == 1 || np == 1, "bad cpw / np");
    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
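    // Example: ncols == 32 with nwarps == 8 gives cpw == 4, np == 1; ncols == 2 with nwarps == 4 gives cpw == 1, np == 2.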

    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.

    // Q_tmp  == SRAM buffer to hold Q data for the entire lifetime of the kernel.
    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
    // KQ     == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
    // VKQ    == Accumulators in registers for the final VKQ result.
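    // Rough sizing example (assuming cpy_nb == 16, i.e. cpy_ne == 4) for DKQ == DV == 128, ncols == 32,
    // nbatch_fa == 64, nbatch_K == 64 in the FP16 path: Q_tmp = 8 KiB, KV_tmp = 9 KiB, KQ = 4 KiB of SRAM.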
#ifdef FAST_FP16_AVAILABLE
    __shared__ half2 Q_tmp[ncols * DKQ/2];
    __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
    __shared__ half  KQ[ncols * nbatch_fa];
    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#else
    __shared__ float Q_tmp[ncols * DKQ];
    __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
    __shared__ float KQ[ncols * nbatch_fa];
    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
#endif // FAST_FP16_AVAILABLE

    float KQ_max[cpw];
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
    }
    float KQ_sum[cpw] = {0.0f};

    // Load Q data, convert to FP16 if fast:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y / np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;

#pragma unroll
        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
            if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
                float tmp_f[cpy_ne_D] = {0.0f};
                if (ncols1 == 1 || col_Q_0 + j < ne01) {
                    ggml_cuda_memcpy_1<sizeof(tmp_f)>
                        (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
                            + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
                }

#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    tmp_f[i1] *= scale;
                }

#ifdef FAST_FP16_AVAILABLE
                half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
                }
                ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                    &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
                    tmp_h2);
#else
                ggml_cuda_memcpy_1<sizeof(tmp_f)>(
                    &Q_tmp[jc* DKQ    + i0   + (threadIdx.y % np)*(warp_size*cpy_ne_D)   + threadIdx.x* cpy_ne_D],
                    tmp_f);
#endif // FAST_FP16_AVAILABLE
            }
        }
    }

    __syncthreads();

    // Main loop over KV cache:
    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    if (ncols2 == 1) {
        // Branch with out-of-bounds checks.
        int k_VKQ_0 = blockIdx.y*nbatch_fa;
        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
            k_VKQ_0 += gridDim.y*nbatch_fa;
        }
        if (k_VKQ_0 < k_VKQ_max) {
            constexpr bool oob_check = true;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
        }
    } else {
        // Branch without out-of-bounds checks.
        for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
            constexpr bool oob_check = false;
            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
                 stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
        }
    }

#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
    }

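    // For np > 1 the partial VKQ accumulators and KQ sums of the np warps sharing a Q column still need to be
    // combined: all warps except the first in each group write their data to shared memory and exit, the first
    // warp then adds the partial results to its own.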
    if constexpr (np > 1) {
        static_assert(cpw == 1, "bad cpw");
        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");

#ifdef FAST_FP16_AVAILABLE
        half2 * VKQ_combine    = (half2 *) KV_tmp;
#else
        float * VKQ_combine    = (float *) KV_tmp;
#endif // FAST_FP16_AVAILABLE
        float * KQ_sum_combine = (float *) Q_tmp;

        if (threadIdx.y % np != 0) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
            }
#endif // FAST_FP16_AVAILABLE

            if (threadIdx.x == 0) {
                KQ_sum_combine[threadIdx.y] = KQ_sum[0];
            }

            return;
        }

        __syncthreads();

#pragma unroll
        for (int ip = 1; ip < np; ++ip) {
#ifdef FAST_FP16_AVAILABLE
            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
                half2 tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    VKQ[i0/warp_size + i1] += tmp[i1];
                }
            }
#else
            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
                float tmp[cpy_ne_D];
                ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
                }
            }
#endif // FAST_FP16_AVAILABLE

            KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
        }
    }

    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
    if (sinks && blockIdx.y == 0) {
#pragma unroll
        for (int jc0 = 0; jc0 < cpw; ++jc0) {
            const int jc = jc0 + (threadIdx.y/np)*cpw;
            const float sink = ((const float *) sinks)[head0 + jc % ncols2];

            float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
            const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
            KQ_max[jc0] = KQ_max_new_j;

            const float val = expf(sink - KQ_max[jc0]);
            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;

#ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
            }
#else
#pragma unroll
            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
            }
#endif // FAST_FP16_AVAILABLE
        }
    }

    // Write back results:
#pragma unroll
    for (int jc0 = 0; jc0 < cpw; ++jc0) {
        const int jc = jc0 + (threadIdx.y/np)*cpw;

        const int j = jc / ncols2;
        const int c = jc % ncols2;

        if (ncols1 > 1 && col_Q_0 + j >= ne01) {
            return;
        }

        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;

        const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;

#ifdef FAST_FP16_AVAILABLE
        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
            float2 tmp[cpy_ne_D];
#pragma unroll
            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
                tmp[i1].x *= scale;
                tmp[i1].y *= scale;
            }
            if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
            }
        }
#else
        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
#pragma unroll
        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
            if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
                }
                ggml_cuda_memcpy_1<cpy_ne_D*4>(
                    &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
            }
        }
#endif // FAST_FP16_AVAILABLE

        if (gridDim.y != 1 && threadIdx.x == 0) {
            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
        }
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    const int id        = ggml_cuda_get_device();
    const int cc        = ggml_cuda_info().devices[id].cc;
    const int warp_size = 32;

    constexpr size_t nbytes_shared = 0;
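    // The kernel above allocates its SRAM buffers statically (__shared__ arrays sized at compile time),
    // so no dynamic shared memory is requested at launch.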

#ifdef GGML_USE_HIP
    if constexpr (DV <= 128) {
        if (Q->ne[1] > 32/ncols2) {
            constexpr int cols_per_block = 64;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }
#endif // GGML_USE_HIP

#ifndef GGML_USE_HIP
    if constexpr (DV <= 256)
#endif // GGML_USE_HIP
    {
        if (Q->ne[1] > 16/ncols2) {
            constexpr int cols_per_block = 32;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if (Q->ne[1] > 8/ncols2) {
        constexpr int cols_per_block = 16;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    if constexpr (ncols2 <= 8) {
        if (Q->ne[1] > 4/ncols2) {
            constexpr int cols_per_block = 8;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 4) {
        if (Q->ne[1] > 2/ncols2) {
            constexpr int cols_per_block = 4;
            const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
            const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
            fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
            launch_fattn<DV, cols_per_block/ncols2, ncols2>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
            return;
        }
    }

    if constexpr (ncols2 <= 2) {
        constexpr int cols_per_block = 2;
        const int nwarps    = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
        const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
        fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
        launch_fattn<DV, cols_per_block/ncols2, ncols2>
            (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
        return;
    }

    GGML_ABORT("fatal error");
}

template <int DKQ, int DV, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

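    // use_gqa_opt below enables the grouped query attention optimization (ncols2 > 1), which packs multiple
    // Q heads sharing the same K/V head into one CUDA block. It requires a mask, no ALiBi (max_bias == 0),
    // a KV length that is a multiple of FATTN_KQ_STRIDE and, on NVIDIA with a small GQA ratio, a small enough Q batch.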
    const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;

    if constexpr (DV == 512) {
        if (use_gqa_opt && gqa_ratio % 16 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
            return;
        }
    }

    if constexpr (DV <= 256) {
        if (use_gqa_opt && gqa_ratio % 8 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 4 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
            return;
        }

        if (use_gqa_opt && gqa_ratio % 2 == 0) {
            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
            return;
        }

        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
        return;
    }
    GGML_ABORT("fatal error");
}

template <int DKQ, int DV>
void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
    }
}

void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
    template void ggml_cuda_flash_attn_ext_tile_case              \
    <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

extern DECL_FATTN_TILE_CASE( 40,  40);
extern DECL_FATTN_TILE_CASE( 64,  64);
extern DECL_FATTN_TILE_CASE( 72,  72);
extern DECL_FATTN_TILE_CASE( 80,  80);
extern DECL_FATTN_TILE_CASE( 96,  96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(576, 512);
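
// The explicit template instantiations for the head-size cases declared above are expected to be compiled
// in separate translation units (one per DKQ/DV pair) so that they can be built in parallel.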