#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB
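// NOTE: when CUB is available, BlockLoad/BlockStore are used below for coalesced, warp-transposed
// loads of A and the initial state into per-thread registers, and for writing the final state back.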

#include "ssm-scan.cuh"

// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
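// One thread block handles a splitD-wide slice of d_inner channels for one sequence:
// blockIdx.x indexes the sequence, blockIdx.y the splitD-sized chunk, and each thread owns one
// channel with its N-element state kept in registers. When L_template != 0 the token count is a
// compile-time constant so the scan loop over L can be fully unrolled.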
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
    const size_t L = L_template == 0 ? L_param : L_template;
    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float *x_block  = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
    const float *A_block  = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
    const float *B_block  = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
    const float *C_block  = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = d_inner;

    float regA[N];
    float regs0[N];

    __shared__ float smemB[N];
    __shared__ float smemC[N];
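    // B and C are shared by all channels within a token, so one copy per block is staged in shared
    // memory each scan step instead of every thread re-reading them from global memory.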

#ifdef USE_CUB
    using BlockLoad  = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    union CubTempStorage {
        typename BlockLoad::TempStorage  load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;
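    // The load and store temporary storage can share the same shared memory since the block-wide
    // loads happen only at the start of the kernel and the store only at the very end.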

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A  = src3_nb1 / sizeof(float);
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        regA[n]  = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

#pragma unroll
    for (size_t i = 0; i < L; i++)
    {
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();

        // softplus(dt) = log(1 + exp(dt)); for large dt this saturates to dt itself, so the
        // transform is skipped above the threshold to avoid overflow in expf.
        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

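        // Selective-scan recurrence for this channel: for each state element n,
        //   h[n] = h[n] * exp(dt * A[n]) + B[n] * (x * dt),   y = sum_n h[n] * C[n]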
        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf    += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;
    }

#ifdef USE_CUB
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

// assumes as many threads as d_state
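// Each block handles a splitH-wide slice of the flattened (n_head, d_head) dimension for one
// sequence (blockIdx.x selects the slice, blockIdx.y the sequence); each of the d_state threads
// owns one state column across those splitH head elements.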
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

    const int head_idx = (blockIdx.x * splitH) / d_head;
    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
    const int seq_idx  = blockIdx.y;

    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float * y_block        = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
    float * s_block        = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];
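    // stateC holds the per-element products state * C so that the whole block can tree-reduce
    // them across d_state below to produce one y value per head element.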

#pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B = B_block[i * stride_B + threadIdx.x];
        const float C = C_block[i * stride_C + threadIdx.x];

        // across d_head
#pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        }

        __syncthreads();

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
#pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
#pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];
                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

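            // At this point each y value is spread over at most WARP_SIZE partial sums per row,
            // so the remaining reduction can be finished with warp shuffles in warp_reduce_sum.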
#pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        }
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    }
}

static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else if (d_state == 256) { // Falcon-H1
            const int threads = 256;
            // NOTE: can be any power of two between 8 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else {
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
    } else {
        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
        if (d_state == 16) {
            switch (n_tok)
            {
                case 1:
                    ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 2:
                    ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 3:
                    ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 4:
                    ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 5:
                    ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 6:
                    ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 7:
                    ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 8:
                    ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                default:
                    ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
            }
        } else {
            GGML_ABORT("doesn't support d_state!=16.");
        }
    }
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // s
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // dt
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C
    const struct ggml_tensor * src6 = dst->src[6]; // ids

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // head_dim or 1
    const int64_t nh  = src1->ne[1]; // n_head
    const int64_t ng  = src4->ne[1]; // n_group
    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
    const int64_t n_s = src1->ne[3]; // number of sequences in the batch

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
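    // the final states are written into dst right after the y output, hence the byte offset s_off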

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    const float * src0_d   = (const float *) src0->data;
    const float * src1_d   = (const float *) src1->data;
    const float * src2_d   = (const float *) src2->data;
    const float * src3_d   = (const float *) src3->data;
    const float * src4_d   = (const float *) src4->data;
    const float * src5_d   = (const float *) src5->data;
    const int32_t * src6_d = (const int32_t *) src6->data;
    float * dst_d          = (float *) dst->data;
    cudaStream_t stream    = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
}