| 1 | #include "ssm-conv.cuh" |
| 2 | |
// 1-D convolution along the token axis for short sequences; one thread handles one row.
//
// Launch layout (see ssm_conv_f32_cuda):
//   blockIdx.x  - slice selected through the nb2 byte stride of src0/dst (grid.x is n_s)
//   blockIdx.y  - group of split_d_inner consecutive rows
//   threadIdx.x - row inside that group (the launcher uses blockDim.x == split_d_inner)
// No shared memory is used and there is no tail guard on the row index, so the number of
// rows has to be a multiple of the block size.
//
// For every token i in [0, n_t) the thread writes
//   dst[i][row] = sum over j < d_conv of src0[row][i + j] * src1[row][j]
// which means each src0 row is read at float indices 0 .. n_t + d_conv - 2.
//
// All *_nb* arguments are byte strides. Elements inside a row are addressed in float units,
// so the innermost stride of src0, src1 and dst is expected to be sizeof(float); src0_nb0 is
// accepted only to keep the signature parallel with ssm_conv_long_token_f32.
template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                    const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                    float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                    const int64_t n_t) {
    GGML_UNUSED(src0_nb0);
    const int tid  = threadIdx.x;
    const int bidx = blockIdx.x;
    const int bidy = blockIdx.y;

    // Widen before multiplying: bidx * nb2 is a byte offset, and doing the product in
    // 32-bit int overflows (undefined behaviour) once it exceeds INT_MAX.
    const int64_t x_slice_offset = (int64_t) bidx * src0_nb2;
    const int64_t y_slice_offset = (int64_t) bidx * dst_nb2;

    // The bidy terms are already evaluated in size_t because split_d_inner is a size_t.
    const float * x_block = (const float *) ((const char *) src0 + x_slice_offset + bidy * split_d_inner * src0_nb1);
    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
    float *       y_block = (float *) ((char *) dst + y_slice_offset + bidy * split_d_inner * dst_nb0);

    // Row strides converted from bytes to float elements.
    const int stride_x = src0_nb1 / sizeof(float);
    const int stride_w = src1_nb1 / sizeof(float);
    const int stride_y = dst_nb1 / sizeof(float);

    float x[d_conv] = { 0.0f };  // sliding input window, used as a ring buffer
    float w[d_conv] = { 0.0f };  // this row's filter taps

#pragma unroll
    for (size_t j = 0; j < d_conv; j++) {
        w[j] = w_block[tid * stride_w + j];
    }

    for (int64_t i = 0; i < n_t; i++) {
        float sumf = 0.0f;

        if (i == 0) {
            // First token: fill the whole window.
            for (size_t j = 0; j < d_conv; j++) {
                x[j] = x_block[tid * stride_x + j];
            }
        } else {
            // Later tokens: overwrite the oldest slot with the single new input sample.
            x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
        }

#pragma unroll
        for (size_t j = 0; j < d_conv; j++) {
            // (i + j) % d_conv walks the ring buffer from the oldest to the newest sample.
            sumf += x[(i + j) % d_conv] * w[j];
        }
        y_block[i * stride_y + tid] = sumf;
    }
}
| 47 | |
// 1-D convolution along the token axis for long sequences: the token range is cut into
// chunks of split_n_t tokens and each chunk is processed by its own block along grid.z.
//
// Launch layout (see ssm_conv_f32_cuda):
//   blockIdx.x  - slice selected through the nb2 byte stride of src0/dst (grid.x is n_s)
//   blockIdx.y  - group of split_d_inner consecutive rows
//   blockIdx.z  - chunk of split_n_t consecutive tokens
//   threadIdx.x - row inside the group (the launcher uses blockDim.x == split_d_inner)
// No shared memory is used and there is no tail guard on the row index, so the number of
// rows has to be a multiple of the block size. The token tail IS guarded against n_t.
//
// For every in-range token t = blockIdx.z * split_n_t + i the thread writes
//   dst[t][row] = sum over j < d_conv of src0[row][t + j] * src1[row][j]
//
// All *_nb* arguments are byte strides. Elements inside a row are addressed in float units,
// so the innermost stride of src0, src1 and dst is expected to be sizeof(float).
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                               const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                               const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
                                               const int dst_nb1, const int dst_nb2, const int64_t n_t) {
    const int tid  = threadIdx.x;
    const int bidx = blockIdx.x;
    const int bidy = blockIdx.y;
    const int bidz = blockIdx.z;

    // Widen before multiplying: bidx * nb2 is a byte offset, and doing the product in
    // 32-bit int overflows (undefined behaviour) once it exceeds INT_MAX. nb2 grows with
    // n_t, so this kernel is the more exposed of the two.
    const int64_t x_slice_offset = (int64_t) bidx * src0_nb2;
    const int64_t y_slice_offset = (int64_t) bidx * dst_nb2;

    // The bidy / bidz terms are already 64-bit because split_d_inner is a size_t and
    // split_n_t is an int64_t.
    const float * x_block = (const float *) ((const char *) src0 + x_slice_offset + bidy * split_d_inner * src0_nb1 +
                                             bidz * split_n_t * src0_nb0);
    const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
    float *       y_block =
        (float *) ((char *) dst + y_slice_offset + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

    // Row strides converted from bytes to float elements.
    const int stride_x = src0_nb1 / sizeof(float);
    const int stride_w = src1_nb1 / sizeof(float);
    const int stride_y = dst_nb1 / sizeof(float);

    float x[d_conv] = { 0.0f };  // sliding input window, used as a ring buffer
    float w[d_conv] = { 0.0f };  // this row's filter taps

#pragma unroll
    for (size_t j = 0; j < d_conv; j++) {
        w[j] = w_block[tid * stride_w + j];
    }

#pragma unroll
    for (int64_t i = 0; i < split_n_t; i++) {
        // The last chunk may be partial. Once this test fails it fails for every later i
        // as well, so the ring buffer never skips an update it would need afterwards.
        if (bidz * split_n_t + i < n_t) {
            float sumf = 0.0f;

            if (i == 0) {
                // First token of the chunk: fill the whole window.
                for (size_t j = 0; j < d_conv; j++) {
                    x[j] = x_block[tid * stride_x + j];
                }
            } else {
                // Later tokens: overwrite the oldest slot with the single new input sample.
                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
            }

#pragma unroll
            for (size_t j = 0; j < d_conv; j++) {
                // (i + j) % d_conv walks the ring buffer from the oldest to the newest sample.
                sumf += x[(i + j) % d_conv] * w[j];
            }
            y_block[i * stride_y + tid] = sumf;
        }
    }
}
| 97 | |
// Host-side launcher: picks the kernel variant and template instantiation for the given
// problem size and enqueues it on `stream`.
//
//   src0, src1, dst - device pointers forwarded to the kernel (input, filter taps, output)
//   *_nb*           - byte strides, forwarded unchanged
//   nc              - filter width; only 3 and 4 are instantiated, anything else aborts
//   nr              - number of rows; must be a multiple of the block size (128)
//   n_t             - tokens per row; <= 32 uses ssm_conv_f32, otherwise the token axis is
//                     split into chunks of 32 handled by ssm_conv_long_token_f32
//   n_s             - becomes grid.x
//   stream          - CUDA stream the launch is issued on
//
// The launch is asynchronous and its result is not checked here.
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                              const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                              const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
                              const int64_t n_s, cudaStream_t stream) {
    const int threads = 128;
    // The kernels have no guard on the row index, so every thread must map to a real row.
    GGML_ASSERT(nr % threads == 0);

    // Number of row groups along grid.y (exact, given the assertion above).
    const int64_t n_row_blocks = (nr + threads - 1) / threads;

    if (n_t <= 32) {
        const dim3 blocks(n_s, n_row_blocks, 1);
        if (nc == 4) {
            ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else if (nc == 3) {
            ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
        }
    } else {
        // Long sequences: grid.z enumerates chunks of split_n_t tokens.
        const int64_t split_n_t = 32;
        const dim3    blocks(n_s, n_row_blocks, (n_t + split_n_t - 1) / split_n_t);
        if (nc == 4) {
            ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else if (nc == 3) {
            ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
        } else {
            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
        }
    }
}
| 132 | |
// Runs the SSM convolution for `dst` on the CUDA backend: reads the two source tensors
// attached to `dst`, checks the layout and type requirements of the F32 kernels, and
// enqueues the work on the context's stream.
//
//   ctx - backend context; only its stream is used here
//   dst - output tensor; dst->src[0] is the input (conv_x) and dst->src[1] the filter
//         (conv1d.weight)
//
// Any violated requirement aborts through GGML_ASSERT.
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];  // conv_x
    const ggml_tensor * src1 = dst->src[1];  // conv1d.weight

    const int64_t nr = src0->ne[1];  // d_inner

    // Layout requirements: the output has one column per row of src0, both sources have
    // float-contiguous elements, and src0 rows are densely packed.
    GGML_ASSERT(dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));

    // Only the F32 path is implemented.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t d_conv   = src1->ne[0];  // filter width
    const int64_t n_tokens = dst->ne[1];   // tokens per sequence
    const int64_t n_seqs   = dst->ne[2];   // number of sequences in the batch

    ssm_conv_f32_cuda((const float *) src0->data, (const float *) src1->data,
                      src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1],
                      (float *) dst->data, dst->nb[0], dst->nb[1], dst->nb[2],
                      d_conv, nr, n_tokens, n_seqs, ctx.stream());
}
| 157 | |