1#include "pad.cuh"
2
// Zero-padding kernel for contiguous F32 tensors.
//
// Each thread writes exactly one element of dst, whose shape is (ne0, ne1, ne2, ne3).
// src is read as a contiguous tensor of shape
//   (ne0 - lp0 - rp0, ne1 - lp1 - rp1, ne2 - lp2 - rp2, ne3 - lp3 - rp3)
// and is copied into dst at offset (lp0, lp1, lp2, lp3); every dst element
// outside that region is set to 0.0f. No shared memory is used.
//
// Expected launch layout:
//   blockIdx.x*blockDim.x + threadIdx.x : i0 (threads with i0 >= ne0 exit early)
//   blockIdx.y                          : i1            (gridDim.y == ne1)
//   blockIdx.z                          : i3*ne2 + i2   (gridDim.z == ne2*ne3)
static __global__ void pad_f32(const float * src, float * dst,
        const int lp0, const int rp0, const int lp1, const int rp1,
        const int lp2, const int rp2, const int lp3, const int rp3,
        const int ne0, const int ne1, const int ne2, const int ne3) {
    const int i0 = threadIdx.x + blockIdx.x * blockDim.x;
    const int i1 = blockIdx.y;
    const int i2 = blockIdx.z % ne2;
    const int i3 = blockIdx.z / ne2;
    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    // Widen to 64 bits BEFORE multiplying: products such as ne0*ne1*ne2 can
    // exceed INT_MAX for large tensors, and doing them in int would overflow
    // even though the result is stored in an int64_t.
    const int64_t dst_idx = (((int64_t) i3*ne2 + i2)*ne1 + i1)*ne0 + i0;

    // True when this dst element lies inside the copied (non-padding) region.
    const bool inside =
        i0 >= lp0 && i0 < ne0 - rp0 &&
        i1 >= lp1 && i1 < ne1 - rp1 &&
        i2 >= lp2 && i2 < ne2 - rp2 &&
        i3 >= lp3 && i3 < ne3 - rp3;

    if (inside) {
        // Coordinates and extents in the unpadded source tensor.
        const int64_t i00 = i0 - lp0;
        const int64_t i01 = i1 - lp1;
        const int64_t i02 = i2 - lp2;
        const int64_t i03 = i3 - lp3;
        const int64_t ne00 = ne0 - lp0 - rp0;
        const int64_t ne01 = ne1 - lp1 - rp1;
        const int64_t ne02 = ne2 - lp2 - rp2;

        const int64_t src_idx = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
        dst[dst_idx] = src[src_idx];
    } else {
        dst[dst_idx] = 0.0f;
    }
}
40
// Host-side launcher for pad_f32 (zero-padding of contiguous F32 data).
//
// Grid layout: x covers dim 0 in chunks of CUDA_PAD_BLOCK_SIZE threads,
// y = ne1, z = ne2*ne3. No dynamic shared memory is requested; the kernel is
// enqueued on `stream`.
// NOTE: CUDA caps gridDim.y and gridDim.z at 65535, so ne1 and ne2*ne3 must
// not exceed that or the launch fails.
static void pad_f32_cuda(const float * src, float * dst,
        const int lp0, const int rp0, const int lp1, const int rp1,
        const int lp2, const int rp2, const int lp3, const int rp3,
        const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
    // An empty destination has nothing to write, and a zero grid dimension
    // is an invalid launch configuration.
    if (ne0 <= 0 || ne1 <= 0 || ne2 <= 0 || ne3 <= 0) {
        return;
    }

    const int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
    // Named grid_dims rather than gridDim to avoid shadowing the CUDA built-in.
    const dim3 grid_dims(num_blocks, ne1, ne2*ne3);
    pad_f32<<<grid_dims, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
        ne0, ne1, ne2, ne3);
}
49
// CUDA implementation of the pad op for F32 tensors.
//
// Reads eight padding amounts from dst->op_params in the order
// (lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) and writes into dst the contents of
// dst->src[0] surrounded by lpN zeros before and rpN zeros after along each
// dimension N. Work is enqueued on ctx.stream().
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    // The kernel computes flat indices into dst as well, so dst must be contiguous too.
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int32_t * pads = (const int32_t *) dst->op_params;
    const int32_t lp0 = pads[0];
    const int32_t rp0 = pads[1];
    const int32_t lp1 = pads[2];
    const int32_t rp1 = pads[3];
    const int32_t lp2 = pads[4];
    const int32_t rp2 = pads[5];
    const int32_t lp3 = pads[6];
    const int32_t rp3 = pads[7];

    // The kernel derives the source extents as (dst extent - lpN - rpN); make
    // sure that agrees with src0 so it never reads out of bounds.
    GGML_ASSERT(src0->ne[0] + lp0 + rp0 == dst->ne[0]);
    GGML_ASSERT(src0->ne[1] + lp1 + rp1 == dst->ne[1]);
    GGML_ASSERT(src0->ne[2] + lp2 + rp2 == dst->ne[2]);
    GGML_ASSERT(src0->ne[3] + lp3 + rp3 == dst->ne[3]);

    // The launch maps ne1 to gridDim.y and ne2*ne3 to gridDim.z; CUDA limits
    // both to 65535, so fail clearly here instead of with a launch error.
    GGML_ASSERT(dst->ne[1] <= 65535 && dst->ne[2]*dst->ne[3] <= 65535);

    pad_f32_cuda(src0_d, dst_d,
        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
73