#include "out-prod.cuh"

#include <cstdint>

// Computes dst = src0 * src1^T for F32 tensors via cuBLAS SGEMM, with
// broadcasting of src0 over dims 2/3 (used for grouped-query attention).
// All three tensors must be F32; src1 may be stored transposed.
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // Declares the ne0*/nb0* (src0), ne1*/nb1* (src1) and ne*/nb* (dst)
    // shape/byte-stride locals used throughout this function.
    GGML_TENSOR_BINARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // Shared inner dimension, and dst shape = (src0 rows) x (src1 rows).
    GGML_ASSERT(ne01 == ne11);
    GGML_ASSERT(ne0 == ne00);
    GGML_ASSERT(ne1 == ne10);

    // src0 may be broadcast along dims 2/3: dst dims must be exact multiples.
    GGML_ASSERT(ne2 % src0->ne[2] == 0);
    GGML_ASSERT(ne3 % src0->ne[3] == 0);

    // src1 is never broadcast: its dims 2/3 must match dst exactly.
    GGML_ASSERT(ne2 == src1->ne[2]);
    GGML_ASSERT(ne3 == src1->ne[3]);

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float * dst_d = (float *) dst->data;

    cudaStream_t stream = ctx.stream();
    cublasHandle_t handle = ctx.cublas_handle();

    // Plain C = A*B^T accumulation: no scaling, no add into existing dst.
    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Make cuBLAS run on this backend context's stream.
    CUBLAS_CHECK(cublasSetStream(handle, stream));

    // Leading dimensions in elements (cuBLAS wants element counts, ggml
    // stores byte strides).
    const int64_t lda = nb01 / sizeof(float);
    const int64_t ldc = nb1 / sizeof(float);

    // ggml is row-major while cuBLAS is column-major, so a ggml matrix
    // already appears transposed to cuBLAS. If src1 is additionally stored
    // transposed in ggml, the two transposes cancel: use OP_N and take the
    // leading dimension from nb10 instead of nb11.
    const bool src1_T = ggml_is_transposed(src1);
    const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
    // The dimension NOT used as leading dimension must be contiguous.
    GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));

    // data strides in dimensions 2/3 (converted from bytes to elements)
    const size_t s02 = nb02 / sizeof(float);
    const size_t s03 = nb03 / sizeof(float);
    const size_t s12 = nb12 / sizeof(float);
    const size_t s13 = nb13 / sizeof(float);
    const size_t s2 = nb2 / sizeof(float);
    const size_t s3 = nb3 / sizeof(float);

    // dps == dst per src0, used for group query attention
    const int64_t dps2 = ne2 / ne02;
    const int64_t dps3 = ne3 / ne03;

    // One SGEMM per (i2, i3) slice; src0's slice index is divided by the
    // broadcast ratio so each src0 slice serves dps2*dps3 dst slices.
    // TODO batched matrix multiplication
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            CUBLAS_CHECK(
                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                        ne0, ne1, ne01,
                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                src1_d + i3 *s13 + i2 *s12, ldb,
                        &beta,  dst_d  + i3 *s3  + i2 *s2,  ldc));
        }
    }
}