1#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
2#define USE_CUB
3#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
4
5#ifdef USE_CUB
6#include <cub/cub.cuh>
7using namespace cub;
8#endif // USE_CUB
9
10#include "ssm-scan.cuh"
11
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
// Mamba-1 selective scan over L tokens for one sequence and splitD inner channels.
//
// Expected launch (see ssm_scan_f32_cuda): gridDim = (n_seq, d_inner / splitD), blockDim = splitD.
// Each thread owns one inner channel and keeps that channel's N state values and N values of A
// in registers for the whole scan.
//
// Template parameters:
//   splitD     - threads per block (= inner channels handled per block)
//   N          - d_state, the number of state values per channel
//   L_template - number of tokens baked in at compile time, or 0 to use L_param at runtime
//
// src0 = s (initial states), src1 = x, src2 = dt, src3 = A, src4 = B, src5 = C,
// src6 = per-sequence index selecting which initial state of src0 to read.
// All *_nb* arguments are byte strides. y is written at the start of dst; the final states are
// written starting at byte offset s_off of dst.
// Only statically sized shared memory is used (B and C of the current token, CUB temp storage).
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
    const size_t L = L_template == 0 ? L_param : L_template;

    // Sequence indices widened to 64 bits so per-sequence byte offsets cannot overflow
    // 32-bit arithmetic for large batches.
    const int64_t seq_in  = src6[blockIdx.x]; // initial state read by this sequence
    const int64_t seq_out = blockIdx.x;       // sequence written by this block

    const float * s0_block = (const float *) ((const char *) src0 + seq_in * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float * x_block  = (const float *) ((const char *) src1 + seq_out * src1_nb3 + blockIdx.y * splitD * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + seq_out * src2_nb2 + blockIdx.y * splitD * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + blockIdx.y * splitD * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + seq_out * src4_nb3);
    const float * C_block  = (const float *) ((const char *) src5 + seq_out * src5_nb3);
    float       * y_block  = (float *) ((char *) dst + (seq_out * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float       * s_block  = (float *) ((char *) dst + s_off + seq_out * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    // strides across tokens, in elements
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = d_inner;

    float regA[N];
    float regs0[N];

    __shared__ float smemB[N];
    __shared__ float smemC[N];

#ifdef USE_CUB
    // BlockLoad/BlockStore move splitD * N consecutive floats, so this path assumes that the rows
    // of A and s0 are densely packed (N floats apart); the fallback path honors src3_nb1 / src0_nb2.
    using BlockLoad  = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    union CubTempStorage {
        typename BlockLoad::TempStorage  load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    // CUB requires a block barrier before its temp storage is reused by another collective.
    __syncthreads();
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A  = src3_nb1 / sizeof(float);
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        regA[n]  = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

#pragma unroll
    for (size_t i = 0; i < L; i++)
    {
        // the first N threads stage this token's B and C for the whole block
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();

        // softplus(dt), using the identity for large inputs to avoid overflow in expf
        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;

        // smemB/smemC are overwritten at the next token; every warp must be done reading them.
        __syncthreads();
    }

#ifdef USE_CUB
    // store_temp aliases load_temp: make sure all loads are done (matters when L == 0).
    __syncthreads();
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
115
// Mamba-2 selective scan (one scalar A per head; B and C shared by the heads of a group).
//
// Assumes as many threads as d_state: thread t owns state column t for each of the block's
// splitH head elements. Expected launch (see ssm_scan_f32_cuda):
// gridDim = (n_head * d_head / splitH, n_seq), blockDim = d_state.
// Each block processes splitH consecutive elements of a single head (the launcher asserts
// d_head % splitH == 0) and, per token, reduces state * C over d_state through the static
// shared array stateC (splitH * d_state floats).
//
// src0 = s (initial states), src1 = x, src2 = dt, src3 = A, src4 = B, src5 = C,
// src6 = per-sequence index selecting which initial state of src0 to read.
// All *_nb* arguments are byte strides. y is written at the start of dst; the final states are
// written starting at byte offset s_off of dst.
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

    const int head_idx = (blockIdx.x * splitH) / d_head;
    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
    const int seq_idx  = blockIdx.y;

    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    // Sequence indices widened to 64 bits so per-sequence byte offsets cannot overflow
    // 32-bit arithmetic for large batches.
    const int64_t seq_in  = src6[seq_idx]; // initial state read by this sequence
    const int64_t seq_out = seq_idx;       // sequence written by this block

    const float * s0_block = (const float *) ((const char *) src0 + seq_in * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_block  = (const float *) ((const char *) src1 + seq_out * src1_nb3 + blockIdx.x * splitH * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + seq_out * src2_nb2 + head_idx * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + seq_out * src4_nb3 + group_off);
    const float * C_block  = (const float *) ((const char *) src5 + seq_out * src5_nb3 + group_off);
    float       * y_block  = dst + (seq_out * n_tok * n_head * d_head) + blockIdx.x * splitH;
    float       * s_block  = (float *) ((char *) dst + s_off + seq_out * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];

#pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B  = B_block[i * stride_B + threadIdx.x];
        const float C  = C_block[i * stride_C + threadIdx.x];

        // across d_head
#pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        }

        __syncthreads();

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
#pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
#pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];
                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

#pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        }

        // stateC is overwritten at the start of the next token: wait until every warp has read
        // its partial sums above, otherwise a fast warp can clobber rows another warp still needs.
        __syncthreads();
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    }
}
226
// Host-side dispatcher for the SSM_SCAN kernels.
//
// When A has a single value per head (src3_nb1 == sizeof(float)) the Mamba-2 kernel
// ssm_scan_f32_group is used (d_state must be 128 or 256, head_dim a multiple of 16).
// Otherwise the Mamba-1 kernel ssm_scan_f32 is used (d_state must be 16, head_dim == 1,
// n_group == 1, and n_head -- which carries d_inner for Mamba-1 -- a multiple of 128).
// Unsupported shapes abort. All *_nb* arguments are byte strides; s_off is the byte offset of
// the final states within dst. Kernels are only enqueued on `stream`; nothing is synchronized here.
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state != 128 && d_state != 256) {
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
        // NOTE: splitH can be any power of two between 4 and 64 for d_state == 128,
        //       and between 8 and 64 for d_state == 256
        constexpr int splitH = 16;
        GGML_ASSERT(head_dim % splitH == 0);
        const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);

        // the group kernel requires exactly d_state threads per block
        if (d_state == 128) {
            ssm_scan_f32_group<splitH, 128><<<blocks, 128, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else { // d_state == 256, Falcon-H1
            ssm_scan_f32_group<splitH, 256><<<blocks, 256, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        }
    } else {
        // Mamba-1: head_dim == 1, so n_head is d_inner; one thread per inner channel
        constexpr int threads = 128;
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        if (d_state != 16) {
            GGML_ABORT("doesn't support d_state!=16.");
        }
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);

        // All instantiations share one signature: token counts 1..8 get a compile-time trip count
        // (fully unrolled loop), anything else uses the runtime count (L_template == 0).
        using kernel_t = decltype(&ssm_scan_f32<threads, 16, 0>);
        kernel_t kernel;
        switch (n_tok) {
            case 1:  kernel = ssm_scan_f32<threads, 16, 1>; break;
            case 2:  kernel = ssm_scan_f32<threads, 16, 2>; break;
            case 3:  kernel = ssm_scan_f32<threads, 16, 3>; break;
            case 4:  kernel = ssm_scan_f32<threads, 16, 4>; break;
            case 5:  kernel = ssm_scan_f32<threads, 16, 5>; break;
            case 6:  kernel = ssm_scan_f32<threads, 16, 6>; break;
            case 7:  kernel = ssm_scan_f32<threads, 16, 7>; break;
            case 8:  kernel = ssm_scan_f32<threads, 16, 8>; break;
            default: kernel = ssm_scan_f32<threads, 16, 0>; break;
        }

        // The kernel only uses statically sized shared memory, so no dynamic shared memory is
        // requested; reserving unused dynamic shared memory would only reduce occupancy.
        kernel<<<blocks, threads, 0, stream>>>(
            src0, src1, src2, src3, src4, src5, src6, dst,
            src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
            src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
    }
}
331
// CUDA implementation of GGML_OP_SSM_SCAN.
//
// Inputs (dst->src): s (initial states), x, dt, A, B, C, ids (per-sequence state index).
// dst holds y (same number of elements as x) followed by the final states
// (nc * nr * nh * n_s elements), which start at byte offset s_off.
// Shape support is decided in ssm_scan_f32_cuda; this function validates types and element
// strides, then forwards the byte strides.
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C
    const struct ggml_tensor * src6 = dst->src[6]; // ids

    // every float input is reinterpreted as `const float *` below, so all of them must be F32
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_F32);
    GGML_ASSERT(src3->type == GGML_TYPE_F32);
    GGML_ASSERT(src4->type == GGML_TYPE_F32);
    GGML_ASSERT(src5->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // head_dim or 1
    const int64_t nh  = src1->ne[1]; // n_head
    const int64_t ng  = src4->ne[1]; // n_group
    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
    const int64_t n_s = src1->ne[3]; // number of sequences in the batch

    // y occupies the first ggml_nelements(src1) floats of dst; the states follow it
    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
    // the kernels assume contiguous innermost dimensions
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    const float   * src0_d = (const float *) src0->data;
    const float   * src1_d = (const float *) src1->data;
    const float   * src2_d = (const float *) src2->data;
    const float   * src3_d = (const float *) src3->data;
    const float   * src4_d = (const float *) src4->data;
    const float   * src5_d = (const float *) src5->data;
    const int32_t * src6_d = (const int32_t *) src6->data;
    float         * dst_d  = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
378