#pragma once

#include "common.cuh"

#include <cstdint>

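// Helpers for loading one int (4 bytes) of quantized data from buffers that
// guarantee only 1, 2, or 4 byte alignment, respectively.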
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
    const uint8_t * x8 = (const uint8_t *) x;

    int x32 = x8[4*i32 + 0] << 0;
    x32 |= x8[4*i32 + 1] << 8;
    x32 |= x8[4*i32 + 2] << 16;
    x32 |= x8[4*i32 + 3] << 24;

    return x32;
}

static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

    int x32 = x16[2*i32 + 0] << 0;
    x32 |= x16[2*i32 + 1] << 16;

    return x32;
}

static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32]; // assume at least 4 byte alignment
}

// q4 contains 8 indices with 4 bits each.
// This function selects those bytes from table that are at those indices and returns them as int2.
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
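// For example, q4 == 0x76543210 yields x with bytes {table[0], table[2], table[4], table[6]}
// (lowest byte first) and y with bytes {table[1], table[3], table[5], table[7]}.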
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
#if defined(GGML_USE_HIP)
    // Load the 16-byte table into four 32-bit unsigned integers.
    const uint32_t *values = (const uint32_t *)table;

    const uint32_t q_even = q4;
    const uint32_t q_odd = (q4 >> 4);

    // Perform lookups in the lower half of the table (indices 0-7).
    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);

    // Perform lookups in the upper half of the table (indices 8-15).
    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);

    // Select between the low and high results based on the MSB of each index nibble.
    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);

    return make_int2(res_x, res_y);
#elif !defined(GGML_USE_MUSA)
    // CUDA does not have an instruction for selecting bytes with 4 bit indices.
    // However, the __byte_perm instruction selects bytes using 3 bit indices and can be used instead.
    const uint32_t * table32 = (const uint32_t *) table;

    // __byte_perm selects bytes based on the lower 16 bits in its third argument.
    // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
    // To handle the fourth bit, first call __byte_perm for both the low and the high 64 bit of table, using the low 3 bits.
    // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
    uint32_t tmp[2];
    const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
#pragma unroll
    for (uint32_t i = 0; i < 2; ++i) {
        const uint32_t shift = 16 * i;

        const uint32_t low  = __byte_perm(table32[0], table32[1], q4 >> shift);
        const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
        tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
    }

    // tmp contains the bytes from table in the same order as the 4 bit indices in q4.
    // However, for the result we need ints with all even/odd 4 bit indices in q4.
    // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
    return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
#else
    // Generic implementation.
    const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
    const int8_t * q0_8 = (const int8_t *) &q0_32;
    const char4 val0_8 = make_char4(
        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);

    const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
    const int8_t * q1_8 = (const int8_t *) &q1_32;
    const char4 val1_8 = make_char4(
        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);

    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
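// For example, VDR_Q4_0_Q8_1_MMVQ == 2 means that vec_dot_q4_0_q8_1 loads 2 ints of q4_0 data
// (and the matching 4 ints of q8_1 data) per thread before calling vec_dot_q4_0_q8_1_impl.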

#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
    const int * v, const int * u, const float & d4, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 8 from each quant value
    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
}

#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4

template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
    }

#ifdef FAST_FP16_AVAILABLE
    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
    const float d4d8 = tmp.x;
    const float m4s8 = tmp.y;
#else
    const float2 dm4f = __half22float2(dm4);
    const float2 ds8f = __half22float2(ds8);
    const float d4d8 = dm4f.x * ds8f.x;
    const float m4s8 = dm4f.y * ds8f.y;
#endif // FAST_FP16_AVAILABLE

    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
}

#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ 4

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
        vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
        vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
        vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
        vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

    const float2 ds8f = __half22float2(ds8);

    // second part effectively subtracts 16 from each quant value
    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
}

#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ 4

template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
        vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
        vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
        vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
        vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

        int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
        vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
        vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
        vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
        vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
    }

#ifdef FAST_FP16_AVAILABLE
    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
    const float d5d8 = tmp.x;
    const float m5s8 = tmp.y;
#else
    const float2 dm5f = __half22float2(dm5);
    const float2 ds8f = __half22float2(ds8);
    const float d5d8 = dm5f.x * ds8f.x;
    const float m5s8 = dm5f.y * ds8f.y;
#endif // FAST_FP16_AVAILABLE

    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
}

#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8

template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
    const int * v, const int * u, const T & d8_0, const T & d8_1) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
    }

    return d8_0*d8_1 * ((T) sumi);
}

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        // SIMD dot product of quantized values
        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
    }

#ifdef FAST_FP16_AVAILABLE
    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
    const float d8d8 = tmp.x;
    const float m8s8 = tmp.y;
#else
    const float2 dm8f = __half22float2(dm8);
    const float2 ds8f = __half22float2(ds8);
    const float d8d8 = dm8f.x * ds8f.x;
    const float m8s8 = dm8f.y * ds8f.y;
#endif // FAST_FP16_AVAILABLE

    // scale second part of sum by QI8_1 / vdr to compensate for multiple threads adding it
    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
}

template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
    const int * v, const int * u, const float * d8_0, const float & d8_1) {

    float sumf = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
        int sumi = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_0/2; ++i) {
            // SIMD dot product of quantized values
            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
        }

        sumf += d8_0[i0/(QI8_0/2)]*sumi;
    }

    return d8_1*sumf;
}

#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4

static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;

    const int * q8 = (const int *) bq8_1->qs + iqs;

    int sumi = 0;
#pragma unroll
    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);

        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
    }

    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
    return d * sumi;
}

#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const half2 & dm2, const float * __restrict__ d8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        const int sc = scales[2*i];

        const int vi = (v >> (2*i)) & 0x03030303;

        sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product

        // fill int with 4x m
        int m = sc >> 4;
        m |= m << 8;
        m |= m << 16;
        sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
    }

    const float2 dm2f = __half22float2(dm2);

    return dm2f.x*sumf_d - dm2f.y*sumf_m;
}

// contiguous v/x + u/y values
template <int ns8>
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {

    float sumf = 0.0f;
    float sumf_d8 = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
        int sumi_d0 = 0;

        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
        int sumi_d1 = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
        }
        sumf_d8 += dm2f0.x * sumi_d0;

#pragma unroll
        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
        }
        sumf_d8 += dm2f1.x * sumi_d1;

        if (i0/QI8_1 < ns8) {
            const float2 s8f = __half22float2(s8[i0/QI8_1]);
            sumf -= dm2f0.y*s8f.x;
            sumf -= dm2f1.y*s8f.y;
        } else {
            int sumi_m0 = 0;
#pragma unroll
            for (int i = i0; i < i0 + QI8_1/2; ++i) {
                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
            }
            sumf_d8 -= dm2f0.y * sumi_m0;

            int sumi_m1 = 0;
#pragma unroll
            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
            }
            sumf_d8 -= dm2f1.y * sumi_m1;
        }
    }

    return sumf + d8*sumf_d8;
}

#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ 2

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
    const int & scale_offset, const float & d3, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        const int isc = scale_offset + 2*i;

        const int isc_low = isc % (QK_K/32);
        const int sc_shift_low = 4 * (isc / (QK_K/32));
        const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;

        const int isc_high = isc % (QK_K/64);
        const int sc_shift_high = 2 * (isc / (QK_K/64));
        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;

        const int sc = (sc_low | sc_high) - 32;

        const int vil = (vl >> (2*i)) & 0x03030303;

        const int vih = ((vh >> i) << 2) & 0x04040404;

        const int vi = __vsubss4(vil, vih);

        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d3 * sumf;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d3, const float & d8) {

    int sumi = 0;

#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0;

#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }

        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
    }

    return d3*d8 * sumi;
}

#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ 8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K; ++i) {
        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;

        const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ 8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;

        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;

        const int v0i = vl0i | vh0i;
        const int v1i = vl1i | vh1i;

        const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u

        sumf_d += d8[i] * (dot1 * sc[i]);
        sumf_m += d8[i] * (dot2 * m[i]);
    }

    const float2 dm5f = __half22float2(dm5);

    return dm5f.x*sumf_d - dm5f.y*sumf_m;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {

    float sumf_d = 0.0f;
    float sumf_m = 0.0f;

#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;

#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
        }

        const float2 ds8f = __half22float2(ds8[i]);

        sumf_d += ds8f.x * (sc[i] * sumi_d);
        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q5_K min val
    }

    const float2 dm4f = __half22float2(dm4);

    return dm4f.x*sumf_d - dm4f.y*sumf_m;
}

#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ 8

// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
    const float & d, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        const int sc = scales[4*i];

        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;

        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32

        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
    }

    return d*sumf;
}

// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
    const float & d6, const float * __restrict__ d8) {

    float sumf_d = 0.0f;

    const int sc_packed = get_int_b4(sc, 0);
    const int8_t * sc_reg = (const int8_t *) &sc_packed;

#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale

#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
            sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product

            sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
            sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
        }

        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
    }

    return d6 * sumf_d;
}

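// The vec_dot_*_q8_1 functions below are the per-thread entry points used by the mul_mat_vec_q (MMVQ) kernels:
// vbq points to the quantized weight blocks, bq8_1 to the quantized activation blocks,
// kbx is the block index and iqs the int offset within the block.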
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;

    int v[VDR_Q4_0_Q8_1_MMVQ];
    int u[2*VDR_Q4_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(bq4_0->qs, iqs + i);
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);
    }

    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;

    int v[VDR_Q4_1_Q8_1_MMVQ];
    int u[2*VDR_Q4_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b4(bq4_1->qs, iqs + i);
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);
    }

    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;

    int vl[VDR_Q5_0_Q8_1_MMVQ];
    int vh[VDR_Q5_0_Q8_1_MMVQ];
    int u[2*VDR_Q5_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
        vl[i] = get_int_b2(bq5_0->qs, iqs + i);
        vh[i] = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);
    }

    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;

    int vl[VDR_Q5_1_Q8_1_MMVQ];
    int vh[VDR_Q5_1_Q8_1_MMVQ];
    int u[2*VDR_Q5_1_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
        vl[i] = get_int_b4(bq5_1->qs, iqs + i);
        vh[i] = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));
        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);
    }

    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;

    int v[VDR_Q8_0_Q8_1_MMVQ];
    int u[VDR_Q8_0_Q8_1_MMVQ];

#pragma unroll
    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(bq8_0->qs, iqs + i);
        u[i] = get_int_b4(bq8_1->qs, iqs + i);
    }

    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
}

static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;

    const int bq8_offset = QR2_K * (iqs / QI8_1);
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const uint8_t * scales = bq2_K->scales + scale_offset;

    const int v = get_int_b4(bq2_K->qs, iqs);
    int u[QR2_K];
    float d8[QR2_K];

#pragma unroll
    for (int i = 0; i < QR2_K; ++i) {
        u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;

    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);

    const float d = bq3_K->d;

    const int vl = get_int_b2(bq3_K->qs, iqs);

    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
    const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;

    int u[QR3_K];
    float d8[QR3_K];

#pragma unroll
    for (int i = 0; i < QR3_K; ++i) {
        u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
    }

    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;

    int v[2];
    int u[2*QR4_K];
    float d8[QR4_K];

    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));

    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108

    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    v[0] = q4[0];
    v[1] = q4[4];

    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m = sc + 2;

    for (int i = 0; i < QR4_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;

    int vl[2];
    int vh[2];
    int u[2*QR5_K];
    float d8[QR5_K];

    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));

    vl[0] = ql[0];
    vl[1] = ql[4];

    vh[0] = qh[0] >> bq8_offset;
    vh[1] = qh[4] >> bq8_offset;

    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
    uint16_t aux[2];
    const int j = bq8_offset/2;
    if (j < 2) {
        aux[0] = scales[j+0] & 0x3f3f;
        aux[1] = scales[j+2] & 0x3f3f;
    } else {
        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
    }
    const uint8_t * sc = (const uint8_t *)aux;
    const uint8_t * m = sc + 2;

#pragma unroll
    for (int i = 0; i < QR5_K; ++i) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        d8[i] = __low2float(bq8i->ds);

        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
        u[2*i+0] = q8[0];
        u[2*i+1] = q8[4];
    }

    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
}

static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;

    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));

    const int vl = get_int_b2(bq6_K->ql, iqs);
    const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;

    const int8_t * scales = bq6_K->scales + scale_offset;

    int u[QR6_K];
    float d8[QR6_K];

#pragma unroll
    for (int i = 0; i < QR6_K; ++i) {
        u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
    }

    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

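// The vec_dot_iq*_q8_1 functions below reconstruct 8 quantized values at a time from precomputed
// lookup tables (iq*_grid, ksigns*, kvalues_iq4nl) before taking the SIMD dot product with the q8_1 data.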
#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ 2

static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;

    const int q2 = get_int_b2(bq2->qs, iqs);
    const uint8_t * aux8 = (const uint8_t *) &q2;
    const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);

    int sumi = 0;
#pragma unroll
    for (int k0 = 0; k0 < 8; k0 += 2) {
        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];

        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
        sumi = ggml_cuda_dp4a(grid0, u0, sumi);

        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
    }

    const int ls = aux32 >> 28;
    sumi = (ls*sumi + sumi/2)/4;
    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ2_XS_Q8_1_MMVQ 2
#define VDR_IQ2_XS_Q8_1_MMQ 2

static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;

    const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));
    const uint16_t * q2 = (const uint16_t *) &q2_packed;
    const int ls0 = bq2->scales[iqs/2] & 0x0F;
    const int ls1 = bq2->scales[iqs/2] >> 4;

    int sumi0 = 0;
    int sumi1 = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));

        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        if (l0 < 4) {
            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
        } else {
            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
        }
    }
    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ2_S_Q8_1_MMVQ 2
#define VDR_IQ2_S_Q8_1_MMQ 2

static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;

    const int qs_packed = get_int_b2(bq2->qs, iqs/2);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq2->qh[iqs/2];

    const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);
    const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;

    const int ls0 = bq2->scales[iqs/2] & 0x0F;
    const int ls1 = bq2->scales[iqs/2] >> 4;

    int sumi0 = 0;
    int sumi1 = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300)));

        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);

        const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
        const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        if (l0 < 4) {
            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
        } else {
            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
        }
    }
    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;

    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ3_XXS_Q8_1_MMVQ 2
#define VDR_IQ3_XXS_Q8_1_MMQ 2

static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;

    const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));
    const uint8_t * q3 = (const uint8_t *) &q3_packed;
    const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);

        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));

        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
    }

    const int ls = aux32 >> 28;
    sumi = (ls*sumi + sumi/2)/2;
    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ3_S_Q8_1_MMVQ 2
#define VDR_IQ3_S_Q8_1_MMQ 2

// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;

    const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq3->qh[iqs/2];

    const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2);
    const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int2 grid_pos = make_int2(
            iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],
            iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);

        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);

        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
    }

    sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);

    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
    return d * sumi;
}

#define VDR_IQ1_S_Q8_1_MMVQ 1
#define VDR_IQ1_S_Q8_1_MMQ 1

static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;

    const int qs_packed = get_int_b2(bq1->qs, iqs);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    const int qh = bq1->qh[iqs];

    int sumi = 0;
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];

        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
        const int grid1 = (grid >> 4) & 0x0F0F0F0F;

        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

        sumi = ggml_cuda_dp4a(grid0, u0, sumi);
        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
    }

    const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
    const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
    const float2 ds = __half22float2(bq8_1[iqs].ds);
    return d1q * (ds.x*sumi + ds.y*delta);
}

#define VDR_IQ1_M_Q8_1_MMVQ 1
#define VDR_IQ1_M_Q8_1_MMQ 1

static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;

    const int qs_packed = get_int_b4(bq1->qs, iqs);
    const uint8_t * qs = (const uint8_t *) &qs_packed;

    int sumi[2] = {0};
    float sumf[2] = {0.0f};
#pragma unroll
    for (int l0 = 0; l0 < 8; l0 += 2) {
        const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));

        const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];

        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
        const int grid1 = (grid >> 4) & 0x0F0F0F0F;

        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);

        sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
        sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);

        const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
        int sumy = 0;
        sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
        sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
        sumf[l0/4] += delta*sumy;
    }

    const uint16_t * sc = (const uint16_t *) bq1->scales;

    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
    const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);

    const int tmp = sc[iqs/2] >> (6*(iqs%2));
    const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
    const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
    return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}

#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4

static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;

    const int * q8 = (const int *) bq8_1->qs + iqs;

    int sumi = 0;
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
        const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);

        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
    }

    const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
    return d * sumi;
}

#define VDR_IQ4_XS_Q8_1_MMVQ 4
#define VDR_IQ4_XS_Q8_1_MMQ 4

static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;

    int sumi = 0;
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);

        const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
        const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);

        sumi = ggml_cuda_dp4a(v.x, u0, sumi);
        sumi = ggml_cuda_dp4a(v.y, u1, sumi);
    }

    const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
    sumi *= ls - 32;

    const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
    return d * sumi;
}