#include "pad.cuh"

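// zero-pad src into dst: each dst element either copies the corresponding
// src element (shifted by the left padding) or is set to 0 when it falls
// inside the padded border. lpN/rpN are the left/right padding amounts for
// dimension N; neN are the *destination* extents.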
static __global__ void pad_f32(const float * src, float * dst,
        const int lp0, const int rp0, const int lp1, const int rp1,
        const int lp2, const int rp2, const int lp3, const int rp3,
        const int ne0, const int ne1, const int ne2, const int ne3) {
    // blockIdx.z: i3*ne2+i2
    // blockIdx.y: i1
    // blockIdx.x: i0 / CUDA_PAD_BLOCK_SIZE
    // gridDim.y:  ne1
    int i0 = threadIdx.x + blockIdx.x * blockDim.x;
    int i1 = blockIdx.y;
    int i2 = blockIdx.z % ne2;
    int i3 = blockIdx.z / ne2;
    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    // threads in the interior copy the corresponding src element;
    // threads in the padded border write zero
    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
        (i1 >= lp1 && i1 < ne1 - rp1) &&
        (i2 >= lp2 && i2 < ne2 - rp2) &&
        (i3 >= lp3 && i3 < ne3 - rp3)) {
        const int64_t i00 = i0 - lp0;
        const int64_t i01 = i1 - lp1;
        const int64_t i02 = i2 - lp2;
        const int64_t i03 = i3 - lp3;
        const int64_t ne02 = ne2 - lp2 - rp2;
        const int64_t ne01 = ne1 - lp1 - rp1;
        const int64_t ne00 = ne0 - lp0 - rp0;

        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;

        dst[dst_idx] = src[src_idx];
    } else {
        dst[dst_idx] = 0.0f;
    }
}

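// launch one thread per dst element: blockDim.x covers i0, gridDim.y covers
// i1, and gridDim.z packs (i2, i3) into a single dimension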
static void pad_f32_cuda(const float * src, float * dst,
        const int lp0, const int rp0, const int lp1, const int rp1,
        const int lp2, const int rp2, const int lp3, const int rp3,
        const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
    dim3 gridDim(num_blocks, ne1, ne2*ne3);
    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
}

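// ggml backend entry point: validate the tensors, unpack the padding
// amounts, and launch the kernel on the context's stream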
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));

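    // op_params holds 8 int32 values: (left, right) padding for each of the
    // four dimensions, in dimension order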
    const int32_t * opts = (const int32_t *) dst->op_params;
    const int32_t lp0 = opts[0];
    const int32_t rp0 = opts[1];
    const int32_t lp1 = opts[2];
    const int32_t rp1 = opts[3];
    const int32_t lp2 = opts[4];
    const int32_t rp2 = opts[5];
    const int32_t lp3 = opts[6];
    const int32_t rp3 = opts[7];

    pad_f32_cuda(src0_d, dst_d,
        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}