#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used similarly to the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// As with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the l-th element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; unaligned pointers are considered undefined behavior.
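//
// Illustrative usage sketch (not part of this file's code; the tile shapes, pointer names, and
// strides below are assumptions chosen for the example): load two fp16 tiles from 16 byte aligned
// shared memory, multiply-accumulate them into an fp32 tile, and write the result back through
// get_i/get_j.
//
//     using namespace ggml_cuda_mma;
//
//     tile<16, 8, half2> A; // M x K = 16 x 16 logical fp16 values, row-major
//     tile< 8, 8, half2> B; // K x N = 16 x  8 logical fp16 values, column-major
//     tile<16, 8, float> C; // M x N = 16 x  8 accumulator,         column-major
//
//     load_ldmatrix(A, shmem_A, stride_A); // shmem_A/shmem_B: 16 byte aligned shared memory
//     load_ldmatrix(B, shmem_B, stride_B); // strides are in units of the tile's data type
//     mma(C, A, B);
//
//     for (int l = 0; l < C.ne; ++l) {
//         dst[C.get_j(l)*stride_dst + C.get_i(l)] = C.x[l]; // dst assumed column-major here
//     }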

#include "common.cuh"

// On Volta each warp is doing 4 8x8 mma operations in parallel.
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
// However, the i indices in this file are by default permuted to simplify the index calculations.
// #define GGML_CUDA_MMA_NO_VOLTA_PERM

#if CUDART_VERSION >= 11080

static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
    int ret = 0;

#ifdef TURING_MMA_AVAILABLE
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
        : "=r"(ret) : "r"(x));
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // defined(TURING_MMA_AVAILABLE)
    return ret;
}

#else

static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
    // Imagine transposing row-major matrix to column-major matrix.
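    // Layout assumed below (standard m8n8 b16 fragment): lane t holds one 32 bit register with the
    // two fp16 values of row t/4, columns 2*(t%4) and 2*(t%4)+1.
    // Worked example for lane 5: after the transpose it must hold elements (2,1) and (3,1) of the
    // original matrix, i.e. the high halves of the registers in lanes 8 and 12 respectively.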
    const int src_i_low = 2 * (threadIdx.x % 4);
    const int src_i_high = src_i_low + 1;
    const int src_j = threadIdx.x / 4;

    const int src_laneid_low = src_i_low * 4 + src_j / 2;
    const int src_laneid_high = src_i_high * 4 + src_j / 2;

    const int shift_low = ((src_j + 0) % 2) * 16;
    const int shift_high = ((src_j + 1) % 2) * 16;

    const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

    return ret_low | ret_high;
}

#endif // CUDART_VERSION >= 11080

static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
    half2 ret;
    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
    return ret;
}

namespace ggml_cuda_mma {

    template <int I_, int J_, typename T>
    struct tile {
        static constexpr int I = I_;
        static constexpr int J = J_;

#if defined(GGML_USE_HIP)
        static constexpr int ne = I * J / 64;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 64 && J == 2) return true;
            if (I == 16 && J == 8) return true;
            if (I == 32 && J == 4) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 32) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return threadIdx.x % 16;
            } else if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 4) {
                return threadIdx.x % 32;
            } else if constexpr (I == 16 && J == 16) {
                return 4 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 32) {
                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
                return (2 * ((threadIdx.x / 16) % 2) + l);
            } else if constexpr (I == 16 && J == 8) {
                return 2 * (threadIdx.x / 16) + l;
            } else if constexpr (I == 32 && J == 4) {
                return 2 * (threadIdx.x / 32) + l;
            } else if constexpr (I == 16 && J == 16) {
                return threadIdx.x % 16;
            } else if constexpr (I == 32 && J == 32) {
                return threadIdx.x % 32;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
#else
                return (l & 2) | (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 32 && J == 8) {
                return (threadIdx.x & 2) | (l & (4 + 1));
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x / 4;
            } else if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 16) {
                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 32 && J == 8) {
                return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 8 && J == 8) {
                return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((threadIdx.x % 4) * 2) | (l % 2);
            } else if constexpr (I == 16 && J == 16) {
                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
            } else if constexpr (I == 32 && J == 8) {
                return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // defined(GGML_USE_HIP)
    };

    template <int I_, int J_>
    struct tile<I_, J_, half2> {
        static constexpr int I = I_;
        static constexpr int J = J_;

#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 8) return true;
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
#else
                return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr ((I == 8 || I == 32) && J == 8) {
                return l;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#else
        static constexpr int ne = I * J / WARP_SIZE;
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 4) return true;
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 16) return true;
            if (I == 32 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
                return (l * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((l % 2) * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 32 && J == 8) {
                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
                return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 32 && J == 8) {
                return ((l & 2) * 2) | (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    };

    template <int I_, int J_>
    struct tile<I_, J_, nv_bfloat162> {
        static constexpr int I = I_;
        static constexpr int J = J_;
        static constexpr int ne = I * J / WARP_SIZE;
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
            if (I == 8 && J == 8) return true;
            if (I == 16 && J == 4) return true;
            if (I == 16 && J == 8) return true;
            return false;
        }

        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 8 && J == 8) {
                return threadIdx.x / 4;
            } else if constexpr (I == 16 && J == 4) {
                return (l * 8) | (threadIdx.x / 4);
            } else if constexpr (I == 16 && J == 8) {
                return ((l % 2) * 8) | (threadIdx.x / 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 8 && J == 8) {
                return (l * 4) | (threadIdx.x % 4);
            } else if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 4;
            } else if constexpr (I == 16 && J == 8) {
                return ((l / 2) * 4) | (threadIdx.x % 4);
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
    };

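    // Convert an fp32 tile into a half2 tile with the same logical shape by packing pairs of
    // consecutive fp32 values into half2 (the physical column count J is halved).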
    template <int I, int J>
    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
        tile<I, J/2, half2> ret;
#pragma unroll
        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
        }
        return ret;
    }

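    // Transpose a <16, 4> half2 tile (16 x 8 logical fp16 values) into a <8, 8> half2 tile
    // (8 x 16 logical fp16 values) by transposing each 8x8 fp16 block in registers via ggml_cuda_movmatrix.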
    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
        tile<8, 8, half2> ret;
        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);

        return ret;
    }

    template <int I, int J, typename T>
    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
            for (int l = 0; l < t.ne; ++l) {
                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
            }
        } else {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
        }
#else
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
#endif // defined(AMD_MFMA_AVAILABLE)
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
                     : "=r"(xi[0]), "=r"(xi[1])
                     : "l"(xs));
#else
        load_generic(t, xs0, stride);
#endif // TURING_MMA_AVAILABLE
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
                     : "=r"(xi[0]), "=r"(xi[1])
                     : "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#else
        load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
                     : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
                     : "l"(xs));
#else
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#else
        load_generic(t, xs0, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#endif // TURING_MMA_AVAILABLE
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
        // TODO: more generic handling
        static_assert(sizeof(T) == 4, "bad type size");
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
#else
        load_generic(t, xs0, stride);
#endif // 1
#else
        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
        load_ldmatrix(t16[0], xs0 + 0*stride, stride);
        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }

    template <typename T>
    static __device__ __forceinline__ void load_ldmatrix_trans(
            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
        int * xi = (int *) t.x;
        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
                     : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
                     : "l"(xs));
#else
        GGML_UNUSED_VARS(t, xs0, stride);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[0]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[1]), "r"(B.x[0]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[0]), "+r"(D.x[1])
            : "r"(A.x[2]), "r"(B.x[1]));
        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
            : "+r"(D.x[2]), "+r"(D.x[3])
            : "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[0]), "+r"(Dxi[1])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
            : "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
        // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        int32x4_t * acc = (int32x4_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
                                                      B.x[0],
                                                      acc[0],
                                                      0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
                                                      B.x[1],
                                                      acc[0],
                                                      0, 0, 0);
#endif // defined(CDNA3)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
        int32x16_t * acc = (int32x16_t *) D.x;
#if defined(CDNA3)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
                                                       ((int64_t *) B.x)[0],
                                                       acc[0],
                                                       0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
                                                     B.x[0],
                                                     acc[0],
                                                     0, 0, 0);
        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
                                                     B.x[1],
                                                     acc[0],
                                                     0, 0, 0);
#endif // defined(CDNA3)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
    }

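    // Generic fallback for tiles with 32 rows: D and A are split into two 16-row halves that
    // reuse the 16-row mma implementations above.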
    template <typename T1, typename T2, int J, int K>
    static __device__ __forceinline__ void mma(
            tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
        mma(D16[0], A16[0], B);
        mma(D16[1], A16[1], B);
    }

    static __device__ __forceinline__ void mma(
            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
        const int * Axi = (const int *) A.x;
        const int * Bxi = (const int *) B.x;
        int * Dxi = (int *) D.x;
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
        mma(D16[0], A16[0], B);
        mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    }
}