#include "ggml.h"
#include "common.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"

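// Matrix-vector (and skinny-matrix) multiplication kernel for F32/F16/BF16 weights:
// each block computes the dot products of one row of x with up to ncols_dst columns of y.
// When instantiated with has_fusion, an optional bias add and a GLU gate (SWIGLU/GEGLU/SWIGLU_OAI)
// computed from a second weight matrix are applied before the result is stored.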
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
static __global__ void mul_mat_vec_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
    const int row         = blockIdx.x;
    const int channel_dst = blockIdx.y;
    const int channel_x   = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
    const int sample_dst  = blockIdx.z;
    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
    const int sample_y    = sample_dst;
    const int tid         = threadIdx.x;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;

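    // Optional fused operands; these are only resolved when the kernel is instantiated with has_fusion.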
    bool use_gate      = false;
    bool use_bias      = false;
    bool use_gate_bias = false;
    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
    const T * gate_x        = nullptr;
    const float * x_bias    = nullptr;
    const float * gate_bias = nullptr;

    if constexpr (has_fusion) {
        use_gate      = fusion.gate      != nullptr;
        use_bias      = fusion.x_bias    != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr;
        glu_op        = fusion.glu_op;

        if (use_gate) {
            gate_x = static_cast<const T *>(fusion.gate);
        }
        if (use_bias) {
            x_bias = static_cast<const float *>(fusion.x_bias);
        }
        if (use_gate_bias) {
            gate_bias = static_cast<const float *>(fusion.gate_bias);
            use_gate_bias = use_gate;
        } else {
            use_gate_bias = false;
        }
    }

    if (use_gate) {
        gate_x += int64_t(sample_x)*stride_sample_x + channel_x*stride_channel_x + row*stride_row;
    }
    if constexpr (has_fusion) {
        const int channel_bias = ids ? channel_x : channel_dst;
        if (use_bias) {
            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
        if (use_gate_bias) {
            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
    }

    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;
    float * buf_iw_gate = nullptr;
    if constexpr (has_fusion) {
        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
    }

    if (block_size > warp_size) {
        if (tid < warp_size) {
            buf_iw[tid] = 0.0f;
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid] = 0.0f;
                }
            }
        }
        __syncthreads();
    }

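    // Per-thread partial dot products, one accumulator per destination column.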
    float sumf[ncols_dst] = {0.0f};
    float sumf_gate[ncols_dst];
    if constexpr (has_fusion) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sumf_gate[j] = 0.0f;
        }
    }

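    // Main dot-product loop: each thread walks the row two elements at a time (float2/half2/bf16x2 loads);
    // the accumulation path depends on the storage type T.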
    if constexpr (std::is_same_v<T, float>) {
        const float2 * x2 = (const float2 *) x;
        const float2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const float2 *) gate_x;
            }
        }

        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = x2[col2];
            float2 tmpx_gate = make_float2(0.0f, 0.0f);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
    } else if constexpr (std::is_same_v<T, half>) {
        const half2 * x2 = (const half2 *) x;
        const half2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const half2 *) gate_x;
            }
        }

        if (std::is_same_v<type_acc, float>) {
            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);
                float2 tmpx_gate = make_float2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = __half22float2(gate_x2[col2]);
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                        }
                    }
                }
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2[ncols_dst]      = {{0.0f, 0.0f}};
            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};

            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const half2 tmpx = x2[col2];
                half2 tmpx_gate = make_half2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = gate_x2[col2];
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
                        }
                    }
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
            }

            if constexpr (has_fusion) {
                if (use_gate) {
#pragma unroll
                    for (int j = 0; j < ncols_dst; ++j) {
                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
                    }
                }
            }
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
        // TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const int *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const int tmpx = x2[col2];
            int tmpx_gate = 0;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
                    }
                }
            }
        }
#else
        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
        const nv_bfloat162 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const nv_bfloat162 *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const nv_bfloat162 tmpx = x2[col2];
            nv_bfloat162 tmpx_gate;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
#endif // defined(GGML_USE_HIP)
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }

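    // Reduce the per-thread partial sums: first within each warp, then across warps via the
    // shared-memory buffers when the block consists of more than one warp.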
#pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);

        if constexpr (has_fusion) {
            if (use_gate) {
                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
            }
        }

        if (block_size > warp_size) {
            buf_iw[tid/warp_size] = sumf[j];
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
                }
            }
            __syncthreads();
            if (tid < warp_size) {
                sumf[j] = buf_iw[tid];
                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        sumf_gate[j] = buf_iw_gate[tid];
                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
                    }
                }
            }

            if (j < ncols_dst) {
                __syncthreads();
            }
        }
    }

    if (tid >= ncols_dst) {
        return;
    }

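    // The first ncols_dst threads each write one output element, applying the optional fused bias and GLU gate.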
    float value = sumf[tid];

    if constexpr (has_fusion) {
        if (use_bias) {
            value += x_bias[tid*stride_col_dst + row];
        }

        if (use_gate) {
            float gate_value = sumf_gate[tid];
            if (use_gate_bias) {
                gate_value += gate_bias[tid*stride_col_dst + row];
            }
            switch (glu_op) {
                case GGML_GLU_OP_SWIGLU:
                    value *= ggml_cuda_op_silu_single(gate_value);
                    break;
                case GGML_GLU_OP_GEGLU:
                    value *= ggml_cuda_op_gelu_single(gate_value);
                    break;
                case GGML_GLU_OP_SWIGLU_OAI: {
                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
                    break;
                }
                default:
                    break;
            }
        }
    }

    dst[tid*stride_col_dst + row] = value;

    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
    }
}

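// Launches the fused kernel variant when any fusion operand is present (fusion is only supported
// for ncols_dst == 1), otherwise the plain kernel.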
template<typename T, typename type_acc, int ncols_dst, int block_size>
static void mul_mat_vec_f_switch_fusion(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}

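// Computes the fastdiv constants and launch configuration for a fixed ncols_dst, then launches
// the kernel through the fusion dispatch helper above.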
template <typename T, typename type_acc, int ncols_dst>
void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
    GGML_ASSERT(ncols        % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);

    const int device    = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

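    // Pick the block size that minimizes the number of loop iterations over the 2-element packed columns.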
    int64_t block_size_best = warp_size;
    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
    int64_t max_block_size  = 256;
    if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
        max_block_size = 128;
    }
    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case 32: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 64: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 96: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 128: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 160: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 192: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 224: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 256: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

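// Dispatches to the kernel instantiation that matches the runtime number of destination columns (1..8).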
template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream) {
    switch (ncols_dst) {
        case 1:
            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 2:
            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 3:
            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 4:
            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 5:
            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 6:
            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 7:
            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 8:
            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

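// Selects the accumulator type: half accumulation is only used for F16 data with GGML_PREC_DEFAULT,
// everything else accumulates in float.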
template<typename T>
static void mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        enum ggml_prec prec, cudaStream_t stream) {

    if constexpr (std::is_same_v<T, half>) {
        if (prec == GGML_PREC_DEFAULT) {
            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            return;
        }
    }
    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}

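// Backend entry point for MUL_MAT/MUL_MAT_ID with F32/F16/BF16 src0, optionally fusing a bias add
// and a GLU gate supplied through ggml_cuda_mm_fusion_args_host.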
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
                             const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
    GGML_ASSERT(         dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
    GGML_ASSERT(ne13 == ne3);

    GGML_ASSERT(        nb00       == ts_src0);
    GGML_ASSERT(        nb10       == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT(        nb0        == ts_dst);

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    const float   * src1_d = (const float   *) src1->data;
    const int32_t * ids_d  = ids ? (const int32_t *) ids->data : nullptr;
    float         * dst_d  = (float         *) dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};

    if (fusion) {
        GGML_ASSERT(!ids || dst->ne[2] == 1);
        GGML_ASSERT( ids || dst->ne[1] == 1);
        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1  =  dst->nb[1] / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s12 = src1->nb[2] / ts_src1;
    const int64_t s2  =  dst->nb[2] / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s13 = src1->nb[3] / ts_src1;
    const int64_t s3  =  dst->nb[3] / ts_dst;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    const int64_t ncols_dst          = ids ? ne2  : ne1;
    const int64_t nchannels_y        = ids ? ne11 : ne12;
    const int64_t nchannels_dst      = ids ? ne1  : ne2;
    const int64_t stride_channel_dst = ids ? s1   : s2;
    const int64_t stride_channel_y   = ids ? s11  : s12;

    GGML_ASSERT(!ids || ncols_dst == 1);

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}

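// Entry point for the ggml_cuda_op_mul_mat split path: each call processes a single, contiguous 2D slice
// (no channel/sample batching, no fusion).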
void ggml_cuda_op_mul_mat_vec_f(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne0  =  dst->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row         = ne00;
    const int64_t stride_col_y       = ne10;
    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
    const int64_t nchannels_x        = 1;
    const int64_t nchannels_y        = 1;
    const int64_t nchannels_dst      = 1;
    const int64_t stride_channel_x   = 0;
    const int64_t stride_channel_y   = 0;
    const int64_t stride_channel_dst = 0;
    const int64_t nsamples_x         = 1;
    const int64_t nsamples_dst       = 1;
    const int64_t stride_sample_x    = 0;
    const int64_t stride_sample_y    = 0;
    const int64_t stride_sample_dst  = 0;

    ggml_cuda_mm_fusion_args_device empty{};
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }

    GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
}

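// Heuristic: returns true if the mul_mat_vec_f kernels are expected to be faster than the general
// matrix multiplication path for this data type, compute capability and src1 column count (ne11).
// Layouts that are not aligned to 2-element vectors are rejected.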
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
    if (src0_ne[0] % 2 != 0) {
        return false;
    }

    const size_t ts = ggml_type_size(type);
    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                if (ampere_mma_available(cc)) {
                    return ne11 <= 3;
                }
                if (cc >= GGML_CUDA_CC_TURING) {
                    return ne11 <= 4;
                }
                return ne11 <= 3;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp32_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_F16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (fp16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (fp16_mma_hardware_available(cc)) {
                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
                        return ne11 <= 5;
                    }
                    return ne11 <= 2;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        case GGML_TYPE_BF16:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
                if (ampere_mma_available(cc)) {
                    return src0_small && ne11 == 1;
                }
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
                    return src0_small && ne11 <= 4;
                }
                if (bf16_mma_hardware_available(cc)) {
                    return src0_small && ne11 <= 3;
                }
                return ne11 <= 8;
            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
                if (bf16_mma_hardware_available(cc)) {
                    return ne11 <= 3;
                }
                return ne11 <= 8;
            }
            return ne11 <= 8;
        default:
            return false;
    }
}