| 1 | #pragma once |
| 2 | // This file contains primitives that expose the tensor core PTX instructions for CUDA code. |
| 3 | // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. |
| 4 | // The documentation for the PTX instructions can be found under: |
| 5 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction |
| 6 | // |
| 7 | // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. |
| 8 | // A is a row-major matrix with shape M x K. |
| 9 | // B is a column-major matrix with shape K x N. |
| 10 | // C is a column-major matrix with shape M x N. |
| 11 | // A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. |
| 12 | // Note that J is measured in physical 32 bit elements instead of logical elements. |
| 13 | // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. |
| 14 | // All matrix tiles have ne physical 32 bit elements per thread. |
| 15 | // |
| 16 | // As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. |
| 17 | // The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; passing unaligned pointers is undefined behavior. |
| 18 | |
| 19 | #include "common.cuh" |
| 20 | |
| 21 | // On Volta each warp is doing 4 8x8 mma operations in parallel. |
| 22 | // The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile. |
| 23 | // However, the i indices in this file are by default permuted to simplify the index calculations. |
| 24 | // #define GGML_CUDA_MMA_NO_VOLTA_PERM |
| 25 | |
| 26 | #if CUDART_VERSION >= 11080 |
| 27 | |
// Transposes an 8x8 matrix of 16 bit values with the PTX movmatrix instruction.
// x is the calling thread's 32 bit register of the source matrix,
// the return value is its 32 bit register of the transposed matrix.
// The instruction is only emitted if TURING_MMA_AVAILABLE is defined;
// otherwise NO_DEVICE_CODE is invoked and 0 is returned.
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
#ifdef TURING_MMA_AVAILABLE
    int transposed = 0;
    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
        : "=r" (transposed)
        : "r" (x));
    return transposed;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
    return 0;
#endif // TURING_MMA_AVAILABLE
}
| 40 | |
| 41 | #else |
| 42 | |
// Fallback used when CUDART_VERSION < 11080: emulates the transposition of an 8x8 matrix of 16 bit values
// with two warp shuffles. Every thread holds one 32 bit register, i.e. 2 consecutive 16 bit values of a row.
// The shuffles use the full mask 0xFFFFFFFF, so all threads of the warp have to call this function together.
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
    // Imagine transposing a row-major matrix to a column-major matrix:
    // this thread needs the values at source rows (src_row_lo, src_row_hi) and source column src_col.
    const int src_col    = threadIdx.x / 4;
    const int src_row_lo = 2 * (threadIdx.x % 4);
    const int src_row_hi = src_row_lo + 1;

    // A source row is spread over 4 lanes with 2 columns per lane.
    const int lane_in_row = src_col / 2;
    const int lane_lo     = src_row_lo * 4 + lane_in_row;
    const int lane_hi     = src_row_hi * 4 + lane_in_row;

    // Odd source columns live in the upper 16 bits of the source register, even ones in the lower 16 bits.
    const int shift_lo = (src_col % 2) * 16; // moves the wanted value down into bits  0-15
    const int shift_hi = 16 - shift_lo;      // moves the wanted value up   into bits 16-31

    const int reg_lo = __shfl_sync(0xFFFFFFFF, x, lane_lo, WARP_SIZE);
    const int reg_hi = __shfl_sync(0xFFFFFFFF, x, lane_hi, WARP_SIZE);

    return ((reg_lo >> shift_lo) & 0x0000FFFF) | ((reg_hi << shift_hi) & 0xFFFF0000);
}
| 60 | |
| 61 | #endif // CUDART_VERSION >= 11080 |
| 62 | |
// half2 overload of ggml_cuda_movmatrix: passes the 32 bit pattern of x to the int overload
// and returns the resulting 32 bits reinterpreted as half2.
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
    const int moved = ggml_cuda_movmatrix(*reinterpret_cast<const int *>(&x));
    return *reinterpret_cast<const half2 *>(&moved);
}
| 68 | |
| 69 | namespace ggml_cuda_mma { |
| 70 | |
// Generic matrix tile with I rows and J columns of physical 32 bit elements of type T.
// Every thread stores ne elements in x; get_i(l) and get_j(l) return the row and column
// of the l-th element of the calling thread within the tile.
// Supported shapes and the thread <-> element mapping depend on the compilation target:
// HIP, Volta (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA), or anything else.
// get_i/get_j invoke NO_DEVICE_CODE and return -1 for shapes without a layout.
template <int I_, int J_, typename T>
struct tile {
    static constexpr int I = I_;
    static constexpr int J = J_;

#if defined(GGML_USE_HIP)
    // I * J elements split 64 ways (presumably the 64 threads of an AMD wavefront - TODO confirm).
    static constexpr int ne = I * J / 64;
    T x[ne] = {0};

    // Returns true for the shapes that get_i/get_j implement on this target.
    static constexpr __device__ bool supported() {
        return (I == 64 && J ==  2) ||
               (I == 16 && J ==  8) ||
               (I == 32 && J ==  4) ||
               (I == 16 && J == 16) ||
               (I == 32 && J == 32);
    }

    // Row index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr ((I == 64 && J == 2) || (I == 16 && J == 8)) {
            // <64, 2> is a special tile size to load <16, 4> as <16, 8>, it shares the rows of <16, 8>.
            return threadIdx.x % 16;
        } else if constexpr (I == 32 && J == 4) {
            return threadIdx.x % 32;
        } else if constexpr (I == 16 && J == 16) {
            return 4 * (threadIdx.x / 16) + l;
        } else if constexpr (I == 32 && J == 32) {
            return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr (I == 64 && J == 2) {
            // Special tile size to load <16, 4> as <16, 8>: the column wraps around after 2 groups of 16 threads.
            return (2 * ((threadIdx.x / 16) % 2) + l);
        } else if constexpr (I == 16 && J == 8) {
            return 2 * (threadIdx.x / 16) + l;
        } else if constexpr (I == 32 && J == 4) {
            return 2 * (threadIdx.x / 32) + l;
        } else if constexpr (I == 16 && J == 16) {
            return threadIdx.x % 16;
        } else if constexpr (I == 32 && J == 32) {
            return threadIdx.x % 32;
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    static constexpr int ne = I * J / 32;
    T x[ne] = {0};

    // Returns true for the shapes that get_i/get_j implement on this target.
    static constexpr __device__ bool supported() {
        return I == 32 && J == 8;
    }

    // Row index of the l-th element of the calling thread.
    // The row permutation is used unless GGML_CUDA_MMA_NO_VOLTA_PERM is defined.
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
            return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
#else
            return (l & 2) | (threadIdx.x & ~2);
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr (I == 32 && J == 8) {
            // Bit 1 comes from the thread index, bits 0 and 2 come from l.
            return (threadIdx.x & 2) | (l & 5);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
#else
    static constexpr int ne = I * J / 32;
    T x[ne] = {0};

    // Returns true for the shapes that get_i/get_j implement on this target.
    static constexpr __device__ bool supported() {
        return (I ==  8 && J ==  4) ||
               (I ==  8 && J ==  8) ||
               (I == 16 && J ==  8) ||
               (I == 16 && J == 16) ||
               (I == 32 && J ==  8);
    }

    // Row index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr (I == 8 && (J == 4 || J == 8)) {
            return threadIdx.x / 4;
        } else if constexpr (I == 16 && J == 8) {
            return ((l / 2) * 8) | (threadIdx.x / 4);
        } else if constexpr (I == 16 && J == 16) {
            return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
        } else if constexpr (I == 32 && J == 8) {
            // Memory layout simply repeated with same pattern in i direction.
            return tile<16, 8, T>::get_i(l);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr (I == 8 && J == 4) {
            return threadIdx.x % 4;
        } else if constexpr (I == 8 && J == 8) {
            return (l * 4) | (threadIdx.x % 4);
        } else if constexpr (I == 16 && J == 8) {
            return ((threadIdx.x % 4) * 2) | (l % 2);
        } else if constexpr (I == 16 && J == 16) {
            return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
        } else if constexpr (I == 32 && J == 8) {
            // Memory layout simply repeated with same pattern in i direction.
            return tile<16, 8, T>::get_j(l);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
#endif // defined(GGML_USE_HIP)
};
| 200 | |
// Specialization of tile for half2: every physical 32 bit element holds 2 FP16 values.
// Every thread stores ne elements in x; get_i(l) and get_j(l) return the row and column
// of the l-th element of the calling thread within the tile.
// supported() returns true exactly for the shapes for which get_i/get_j define a layout on the given target.
// get_i/get_j invoke NO_DEVICE_CODE and return -1 for all other shapes.
template <int I_, int J_>
struct tile<I_, J_, half2> {
    static constexpr int I = I_;
    static constexpr int J = J_;

#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    // The 8x8 tile uses 4x as many elements per thread as an even split across the warp would need:
    // according to get_i only 8 distinct rows are spread over the threads and each thread holds a full row.
    static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
    half2 x[ne] = {{0.0f, 0.0f}};

    static constexpr __device__ bool supported() {
        return (I ==  8 && J == 8) ||
               (I == 32 && J == 8);
    }

    // Row index of the l-th element of the calling thread (independent of l on Volta).
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr (I == 8 && J == 8) {
            return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
        } else if constexpr (I == 32 && J == 8) {
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
            return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
#else
            return threadIdx.x;
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread: each thread holds consecutive columns.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr ((I == 8 || I == 32) && J == 8) {
            return l;
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
#else
    static constexpr int ne = I * J / WARP_SIZE;
    half2 x[ne] = {{0.0f, 0.0f}};

    // Must list the same shapes as get_i/get_j below.
    static constexpr __device__ bool supported() {
        return (I ==  8 && J == 8) ||
               (I == 16 && J == 4) ||
               (I == 16 && J == 8) ||
               (I == 32 && J == 8);
    }

    // Row index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr (I == 8 && J == 8) {
            return threadIdx.x / 4;
        } else if constexpr (I == 16 && J == 4) {
            return (l * 8) | (threadIdx.x / 4);
        } else if constexpr (I == 16 && J == 8) {
            return ((l % 2) * 8) | (threadIdx.x / 4);
        } else if constexpr (I == 32 && J == 8) {
            return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr (I == 8 && J == 8) {
            return (l * 4) | (threadIdx.x % 4);
        } else if constexpr (I == 16 && J == 4) {
            return threadIdx.x % 4;
        } else if constexpr (I == 16 && J == 8) {
            return ((l / 2) * 4) | (threadIdx.x % 4);
        } else if constexpr (I == 32 && J == 8) {
            return ((l & 2) * 2) | (threadIdx.x % 4);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
};
| 283 | |
// Specialization of tile for nv_bfloat162: every physical 32 bit element holds 2 BF16 values.
// Every thread stores ne elements in x; get_i(l) and get_j(l) return the row and column
// of the l-th element of the calling thread within the tile.
// get_i/get_j invoke NO_DEVICE_CODE and return -1 for shapes that supported() does not list.
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
    static constexpr int I  = I_;
    static constexpr int J  = J_;
    static constexpr int ne = I * J / WARP_SIZE;
    nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

    static constexpr __device__ bool supported() {
        return (I ==  8 && J == 8) ||
               (I == 16 && J == 4) ||
               (I == 16 && J == 8);
    }

    // Row index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_i(const int l) {
        if constexpr (I == 8 && J == 8) {
            return threadIdx.x / 4;
        } else if constexpr (I == 16 && J == 4) {
            return (l * 8) | (threadIdx.x / 4);
        } else if constexpr (I == 16 && J == 8) {
            return ((l % 2) * 8) | (threadIdx.x / 4);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }

    // Column index of the l-th element of the calling thread.
    static __device__ __forceinline__ int get_j(const int l) {
        if constexpr (I == 8 && J == 8) {
            return (l * 4) | (threadIdx.x % 4);
        } else if constexpr (I == 16 && J == 4) {
            return threadIdx.x % 4;
        } else if constexpr (I == 16 && J == 8) {
            return ((l / 2) * 4) | (threadIdx.x % 4);
        } else {
            NO_DEVICE_CODE;
            return -1;
        }
    }
};
| 324 | |
// Converts a float tile to a half2 tile with half as many physical columns:
// the float elements 2*l and 2*l + 1 of a thread are packed into its half2 element l.
// The number of packed pairs is derived from the element count of the float tile.
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
    tile<I, J/2, half2> packed;
#pragma unroll
    for (int l = 0; 2*l < tile_float.ne; ++l) {
        packed.x[l] = make_half2(tile_float.x[2*l + 0], tile_float.x[2*l + 1]);
    }
    return packed;
}
| 334 | |
// Returns the transposed version of a <16, 4> half2 tile as an <8, 8> half2 tile.
// Each of the 2 registers of t is transposed independently with ggml_cuda_movmatrix,
// only the first 2 elements of the result are written.
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
    tile<8, 8, half2> ret;
    ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
    ret.x[1] = ggml_cuda_movmatrix(t.x[1]);

    return ret;
}
| 342 | |
// Loads a tile from memory without special instructions:
// element l of each thread is read from xs0[get_i(l)*stride + get_j(l)].
//   t:      tile to fill.
//   xs0:    pointer to the element at row 0, column 0 of the tile.
//   stride: distance between consecutive rows in units of T.
// With AMD_MFMA_AVAILABLE the 2 elements of a thread are fetched with a single 64 bit load, but only for the
// layouts where they are adjacent in memory; this needs an 8 byte aligned address, i.e. an even stride.
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
    // The 64 bit load reads row threadIdx.x % I, columns 2*(threadIdx.x / I) + {0, 1}.
    // This matches get_i/get_j only for generic <16, 8> and <32, 4> tiles with 2 elements of 4 bytes per thread.
    // All other tiles, including the special <64, 2> tile to load <16, 4> as <16, 8>, use the element-wise loop.
    constexpr bool use_64_bit_load = sizeof(T) == 4 && tile<I, J, T>::ne == 2 &&
        ((I == 16 && J == 8) || (I == 32 && J == 4));
#else
    constexpr bool use_64_bit_load = false;
#endif // defined(AMD_MFMA_AVAILABLE)

    if constexpr (use_64_bit_load) {
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
    } else {
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
    }
}
| 363 | |
// Loads an <8, 8> tile with the PTX instruction ldmatrix (2 matrices of 8x8 16 bit values).
//   t:      tile to fill, accessed as 2 32 bit registers per thread.
//   xs0:    pointer to the element at row 0, column 0; per the file header this must point to
//           shared memory and be aligned to 16 bytes.
//   stride: distance between consecutive rows in units of T (used as 32 bit units for the address).
// Falls back to load_generic if TURING_MMA_AVAILABLE is not defined.
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
        tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
    // Row address supplied by this thread: row = lane % 8, column alternates between 0 and J/2 per group of 8 lanes.
    const unsigned int row = threadIdx.x % t.I;
    const unsigned int col = ((threadIdx.x / t.I) * (t.J / 2)) % t.J;

    const int * src = reinterpret_cast<const int *>(xs0) + row * stride + col;
    int       * dst = reinterpret_cast<int *>(t.x);
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
        : "=r" (dst[0]), "=r" (dst[1])
        : "l" (src));
#else
    load_generic(t, xs0, stride);
#endif // TURING_MMA_AVAILABLE
}
| 377 | |
// Loads a <16, 4> tile with the PTX instruction ldmatrix (2 matrices of 8x8 16 bit values).
//   t:      tile to fill, accessed as 2 32 bit registers per thread.
//   xs0:    pointer to the element at row 0, column 0; per the file header this must point to
//           shared memory and be aligned to 16 bytes.
//   stride: distance between consecutive rows in units of T (used as 32 bit units for the address).
// Without TURING_MMA_AVAILABLE: no implementation on Volta (NO_DEVICE_CODE), load_generic otherwise.
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
        tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
    // Row address supplied by this thread: row = lane % 16, always column 0.
    const unsigned int row = threadIdx.x % t.I;

    const int * src = reinterpret_cast<const int *>(xs0) + row * stride;
    int       * dst = reinterpret_cast<int *>(t.x);
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
        : "=r" (dst[0]), "=r" (dst[1])
        : "l" (src));
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    GGML_UNUSED_VARS(t, xs0, stride);
    NO_DEVICE_CODE;
#else
    load_generic(t, xs0, stride);
#endif // TURING_MMA_AVAILABLE
}
| 396 | |
// Loads a <16, 8> tile with the PTX instruction ldmatrix (4 matrices of 8x8 16 bit values).
//   t:      tile to fill, accessed as 4 32 bit registers per thread.
//   xs0:    pointer to the element at row 0, column 0; per the file header this must point to
//           shared memory and be aligned to 16 bytes.
//   stride: distance between consecutive rows in units of T (used as 32 bit units for the address).
// Without TURING_MMA_AVAILABLE: no implementation on Volta (NO_DEVICE_CODE), load_generic otherwise.
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
        tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
    // Row address supplied by this thread: row = lane % 16, column 0 for lanes 0-15 and J/2 for lanes 16-31.
    const unsigned int row = threadIdx.x % t.I;
    const unsigned int col = (threadIdx.x / t.I) * (t.J / 2);

    const int * src = reinterpret_cast<const int *>(xs0) + row * stride + col;
    int       * dst = reinterpret_cast<int *>(t.x);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
        : "=r" (dst[0]), "=r" (dst[1]), "=r" (dst[2]), "=r" (dst[3])
        : "l" (src));
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    GGML_UNUSED_VARS(t, xs0, stride);
    NO_DEVICE_CODE;
#else
    load_generic(t, xs0, stride);
#endif // TURING_MMA_AVAILABLE
}
| 415 | |
// Loads a <32, 8> tile.
//   t:      tile to fill.
//   xs0:    pointer to the element at row 0, column 0.
//   stride: distance between consecutive rows in units of T.
// Volta:  the 8 elements of a thread are copied in 2 chunks of 4 elements starting at columns 0 and 4
//         of the rows get_i(0) and get_i(4); this only supports 4 byte types.
// Others: the tile is treated as 2 stacked <16, 8> tiles (rows 0-15 and 16-31) that are loaded separately.
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
        tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if 1
    // TODO: more generic handling
    static_assert(sizeof(T) == 4, "bad type size");
    // ggml_cuda_memcpy_1 presumably copies the given number of bytes - TODO confirm against common.cuh.
#pragma unroll
    for (int l0 = 0; l0 < t.ne; l0 += 4) {
        ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + l0, xs0 + t.get_i(l0)*stride + l0);
    }
#else
    load_generic(t, xs0, stride);
#endif // 1
#else
    tile<16, 8, T> * halves = reinterpret_cast<tile<16, 8, T> *>(&t);
    load_ldmatrix(halves[0], xs0 +  0*stride, stride);
    load_ldmatrix(halves[1], xs0 + 16*stride, stride);
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
| 434 | |
// Loads a <16, 8> tile with the transposing variant of the PTX instruction ldmatrix.
//   t:      tile to fill, accessed as 4 32 bit registers per thread.
//   xs0:    pointer to the element at row 0, column 0; per the file header this must point to
//           shared memory and be aligned to 16 bytes.
//   stride: distance between consecutive rows in units of T (used as 32 bit units for the address).
// The thread addresses are the same as for the non-transposed <16, 8> load,
// but the 2nd and 3rd loaded matrix are written to the registers 2 and 1 respectively.
// There is no fallback: without TURING_MMA_AVAILABLE NO_DEVICE_CODE is invoked.
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
        tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef TURING_MMA_AVAILABLE
    const unsigned int row = threadIdx.x % t.I;
    const unsigned int col = (threadIdx.x / t.I) * (t.J / 2);

    const int * src = reinterpret_cast<const int *>(xs0) + row * stride + col;
    int       * dst = reinterpret_cast<int *>(t.x);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
        : "=r" (dst[0]), "=r" (dst[2]), "=r" (dst[1]), "=r" (dst[3])
        : "l" (src));
#else
    GGML_UNUSED_VARS(t, xs0, stride);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 449 | |
// int8 matrix multiplication with int32 accumulation into D, logical shape m16n8k16.
// D is used as both the accumulator input and the output of the instruction(s).
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef TURING_MMA_AVAILABLE
    int       * d = D.x;
    const int * a = A.x;
    const int * b = B.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]));
#else
    // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
    // a[0] updates d[0..1], a[1] updates d[2..3].
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (b[0]));
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[1]), "r" (b[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 471 | |
// int8 matrix multiplication with int32 accumulation into D, logical shape m16n8k32.
// D is used as both the accumulator input and the output of the instruction(s).
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef TURING_MMA_AVAILABLE
    int       * d = D.x;
    const int * a = A.x;
    const int * b = B.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[1]));
#else
    // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
    // a[0]/a[1] with b[0] cover the first half of k, a[2]/a[3] with b[1] cover the second half of k.
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (b[0]));
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[1]), "r" (b[0]));
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[2]), "r" (b[1]));
    asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[3]), "r" (b[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 499 | |
// FP16 matrix multiplication with FP16 accumulation into D, logical shape m16n8k16.
// D is used as both the accumulator input and the output of the instruction(s).
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the half2 data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[1]));
#else
    // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead (one per half of k):
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]));
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[2]), "r" (a[3]), "r" (b[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 524 | |
// FP16 matrix multiplication with FP16 accumulation into D where B has 16 rows.
// The work is split into 2 instruction groups with n == 8:
// d[0..1] is computed from b[0]/b[2], d[2..3] is computed from b[1]/b[3].
// D is used as both the accumulator input and the output of the instructions.
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the half2 data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[2]));
    asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[1]), "r" (b[3]));
#else
    // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]));
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[0]), "+r" (d[1])
        : "r" (a[2]), "r" (a[3]), "r" (b[2]));
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (b[1]));
    asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
        : "+r" (d[2]), "+r" (d[3])
        : "r" (a[2]), "r" (a[3]), "r" (b[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 558 | |
// TF32 matrix multiplication with FP32 accumulation into D, logical shape m16n8k8.
// D is used as both the accumulator input and the output of the instruction.
// Requires AMPERE_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the float data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
    asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[1]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
| 573 | |
// FP16 matrix multiplication with FP32 accumulation into D, logical shape m16n8k16.
// D is used as both the accumulator input and the output of the instruction(s).
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the tile data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[1]));
#else
    // On Turing m16n8k16 mma is not available, use 2x m16n8k8 mma instead (one per half of k):
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]));
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[2]), "r" (a[3]), "r" (b[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 598 | |
// BF16 matrix multiplication with FP32 accumulation into D, logical shape m16n8k16.
// D is used as both the accumulator input and the output of the instruction.
// Requires AMPERE_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the tile data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
    asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[1]));
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
| 613 | |
// FP16 matrix multiplication with FP32 accumulation into D where B has 16 rows.
// The work is split into 2 instruction groups with n == 8:
// d[0..3] is computed from b[0]/b[2], d[4..7] is computed from b[1]/b[3].
// D is used as both the accumulator input and the output of the instructions.
// Requires TURING_MMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
    // The PTX operands are plain 32 bit registers, so the tile data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[0]), "r" (b[2]));
    asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
        : "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[0]), "r" (a[1]), "r" (a[2]), "r" (a[3]), "r" (b[1]), "r" (b[3]));
#else
    // On Turing m16n8k16 mma is not available, use 4x m16n8k8 mma instead:
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]));
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3])
        : "r" (a[2]), "r" (a[3]), "r" (b[2]));
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[0]), "r" (a[1]), "r" (b[1]));
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[2]), "r" (a[3]), "r" (b[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}
| 647 | |
// int8 matrix multiplication with int32 accumulation into a 16x16 D tile using AMD MFMA builtins.
// The 4 int32 accumulators of a thread are passed to the builtins as a single vector.
// CDNA3:      1x 16x16x32 instruction, the 2 elements of A and B are each passed as one 64 bit value.
// CDNA2/CDNA: 2x 16x16x16 instruction, one per element of A and B.
// Requires AMD_MFMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
    using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
    int32x4_t * acc = reinterpret_cast<int32x4_t *>(D.x);
#if defined(CDNA3)
    const int64_t a64 = reinterpret_cast<const int64_t *>(A.x)[0];
    const int64_t b64 = reinterpret_cast<const int64_t *>(B.x)[0];
    acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(a64, b64, acc[0], 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
    acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
    acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
#endif // defined(CDNA3)
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
| 673 | |
// int8 matrix multiplication with int32 accumulation into a 32x32 D tile using AMD MFMA builtins.
// The 16 int32 accumulators of a thread are passed to the builtins as a single vector.
// CDNA3:      1x 32x32x16 instruction, the 2 elements of A and B are each passed as one 64 bit value.
// CDNA2/CDNA: 2x 32x32x8 instruction, one per element of A and B.
// Requires AMD_MFMA_AVAILABLE, otherwise NO_DEVICE_CODE is invoked.
static __device__ __forceinline__ void mma(
        tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
    using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
    int32x16_t * acc = reinterpret_cast<int32x16_t *>(D.x);
#if defined(CDNA3)
    const int64_t a64 = reinterpret_cast<const int64_t *>(A.x)[0];
    const int64_t b64 = reinterpret_cast<const int64_t *>(B.x)[0];
    acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(a64, b64, acc[0], 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
    acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0);
    acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0);
#endif // defined(CDNA3)
#else
    GGML_UNUSED_VARS(D, A, B);
    NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
| 699 | |
// Generic matrix multiplication for D and A tiles with 32 rows:
// D and A are each reinterpreted as 2 consecutive tiles with 16 rows
// and the matching 16 row mma overload is applied to both halves with the same B.
template <typename T1, typename T2, int J, int K>
static __device__ __forceinline__ void mma(
        tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
    tile<16, J, T1>       * D_halves = reinterpret_cast<tile<16, J, T1> *>(&D);
    const tile<16, K, T2> * A_halves = reinterpret_cast<const tile<16, K, T2> *>(&A);

    mma(D_halves[0], A_halves[0], B); // rows  0-15
    mma(D_halves[1], A_halves[1], B); // rows 16-31
}
| 708 | |
// FP16 matrix multiplication with FP32 accumulation into a <32, 8> D tile.
// Volta:  4x m8n8k4 mma, each consuming 2 registers of A and 2 registers of B and updating all 8 registers of D.
// Others: D and A are reinterpreted as 2 consecutive tiles with 16 rows
//         and the <16, 8> overload is applied to both halves with the same B.
// D is used as both the accumulator input and the output.
static __device__ __forceinline__ void mma(
        tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
    // The PTX operands are plain 32 bit registers, so the tile data is accessed as int.
    int       * d = reinterpret_cast<int *>(D.x);
    const int * a = reinterpret_cast<const int *>(A.x);
    const int * b = reinterpret_cast<const int *>(B.x);
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3]), "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[0]), "r" (a[1]), "r" (b[0]), "r" (b[1]));
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3]), "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[2]), "r" (a[3]), "r" (b[2]), "r" (b[3]));
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3]), "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[4]), "r" (a[5]), "r" (b[4]), "r" (b[5]));
    asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
        : "+r" (d[0]), "+r" (d[1]), "+r" (d[2]), "+r" (d[3]), "+r" (d[4]), "+r" (d[5]), "+r" (d[6]), "+r" (d[7])
        : "r" (a[6]), "r" (a[7]), "r" (b[6]), "r" (b[7]));
#else
    tile<16, 8, float>       * D16 = reinterpret_cast<tile<16, 8, float> *>(&D);
    const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
    mma(D16[0], A16[0], B); // rows  0-15
    mma(D16[1], A16[1], B); // rows 16-31
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
| 738 | } |
| 739 | |