1#include "out-prod.cuh"
2
3#include <cstdint>
4
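// Implements GGML_OP_OUT_PROD for F32 tensors on CUDA: dst = src0 * src1^T,
// computed with one cuBLAS SGEMM per 2D slice in dimensions 2/3.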
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

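    // shape check: dst is ne00 x ne10, with the reduction running over ne01 == ne11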
    GGML_ASSERT(ne01 == ne11);
    GGML_ASSERT(ne0  == ne00);
    GGML_ASSERT(ne1  == ne10);

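    // src0 may be broadcast along dimensions 2/3; src1 must match dst exactly there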
    GGML_ASSERT(ne2 % src0->ne[2] == 0);
    GGML_ASSERT(ne3 % src0->ne[3] == 0);

    GGML_ASSERT(ne2 == src1->ne[2]);
    GGML_ASSERT(ne3 == src1->ne[3]);

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       *  dst_d = (float       *)  dst->data;

    cudaStream_t   stream = ctx.stream();
    cublasHandle_t handle = ctx.cublas_handle();

    const float alpha = 1.0f;
    const float beta  = 0.0f;

    CUBLAS_CHECK(cublasSetStream(handle, stream));

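    // leading dimensions for cuBLAS, derived from the ggml byte strides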
    const int64_t lda = nb01 / sizeof(float);
    const int64_t ldc = nb1  / sizeof(float);

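    // a transposed src1 can be fed to cuBLAS as-is (CUBLAS_OP_N), otherwise the GEMM
    // transposes it (CUBLAS_OP_T); the assert checks that the stride not used as ldb
    // is exactly one element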
    const bool src1_T = ggml_is_transposed(src1);
    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
    GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));

    // data strides in dimensions 2/3, in units of float
    const size_t s02 = nb02 / sizeof(float);
    const size_t s03 = nb03 / sizeof(float);
    const size_t s12 = nb12 / sizeof(float);
    const size_t s13 = nb13 / sizeof(float);
    const size_t s2  =  nb2 / sizeof(float);
    const size_t s3  =  nb3 / sizeof(float);

    // dps == dst slices per src0 slice, used for grouped-query attention
    const int64_t dps2 = ne2 / ne02;
    const int64_t dps3 = ne3 / ne03;

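    // one SGEMM per slice in dimensions 2/3; the i2/i3 indices of src0 are divided
    // by dps2/dps3 to implement broadcasting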
    // TODO batched matrix multiplication
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            CUBLAS_CHECK(
                    cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                        ne0, ne1, ne01,
                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                src1_d +  i3       *s13 +  i2       *s12, ldb,
                        &beta,   dst_d +  i3       *s3  +  i2       *s2,  ldc));
        }
    }
}