#include "common.cuh"
#include "cross-entropy-loss.cuh"
#include "sum.cuh"

#include <cmath>
#include <cstdint>

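// Each block processes one row of logits/labels with a single warp.
// The per-row loss is -sum(labels * log_softmax(logits)) / k; the per-row
// results are summed into a single scalar afterwards.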
template <bool use_shared>
static __global__ void cross_entropy_loss_f32(
        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
    extern __shared__ float tmp[];

    logits += int64_t(blockIdx.x)*nclasses;
    labels += int64_t(blockIdx.x)*nclasses;

    // Find maximum for softmax:
    float max_logit = -INFINITY;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = logits[i];
        max_logit = fmaxf(max_logit, val);

        if (use_shared) {
            tmp[i] = val;
        }
    }
    max_logit = warp_reduce_max(max_logit);

    // Compute the log of the softmax denominator, i.e. log(sum(exp(logits - max))):
    float sum = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float logit_i = use_shared ? tmp[i] : logits[i];
        sum += expf(logit_i - max_logit);
    }
    sum = warp_reduce_sum(sum);
    sum = logf(sum);

    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
    float loss = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float logit_i = use_shared ? tmp[i] : logits[i];
        loss += (logit_i - max_logit - sum) * labels[i];
    }
    loss = -warp_reduce_sum(loss) / (float)k;

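    // Only the first lane of the warp writes the per-row result.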
    if (threadIdx.x != 0) {
        return;
    }

    dst[blockIdx.x] = loss;
}

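// Gradient of the cross entropy loss with respect to the logits:
// dst = (softmax(logits) - labels) * grad / nrows, with one block (a single warp) per row.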
template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
        float * __restrict__ dst, const int nclasses) {
    extern __shared__ float tmp[];

    logits += int64_t(blockIdx.x)*nclasses;
    labels += int64_t(blockIdx.x)*nclasses;
    dst    += int64_t(blockIdx.x)*nclasses;

    float maxval = -INFINITY;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = logits[i];
        maxval = fmaxf(maxval, val);

        if (use_shared) {
            tmp[i] = val;
        }
    }
    maxval = warp_reduce_max(maxval);

    float sum = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
        sum += val;

        if (use_shared) {
            tmp[i] = val;
        } else {
            dst[i] = val;
        }
    }
    sum = warp_reduce_sum(sum);
    const float sm_scale = 1.0f/sum;

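    // The incoming gradient is a scalar; divide by the number of rows to account
    // for the mean over rows taken in the forward pass.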
    const float d_by_nrows = *grad/gridDim.x;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = use_shared ? tmp[i] : dst[i];
        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
    }
}

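// Forward pass: computes the scalar cross entropy loss of src0 (logits) against
// src1 (labels), averaged over all rows.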
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int64_t ne00  = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       * dst_d  = (float       *) dst->data;

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t stream = ctx.stream();

    const dim3 blocks_dim(WARP_SIZE, 1, 1);
    const dim3 blocks_num(nrows, 1, 1);
    const size_t nbytes_shared = ne00*sizeof(float);

    const int    id    = ggml_cuda_get_device();
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

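    // One partial loss per row; reduced to a single scalar below.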
    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

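    // Use the shared memory variant only if a full row of logits fits within the
    // device's opt-in shared memory limit per block.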
    if (nbytes_shared <= smpbo) {
        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
    } else {
        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
    }
    CUDA_CHECK(cudaGetLastError());

    // Combine results from individual blocks:
    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}

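// Backward pass: writes the gradient of the cross entropy loss with respect to the logits into dst.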
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0];
    const ggml_tensor * src0f = dst->src[1];
    const ggml_tensor * src1f = dst->src[2];

    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_scalar(grad));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
    GGML_ASSERT(ggml_are_same_shape(src0f, dst));

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    const float * src1f_d = (const float *) src1f->data;
    float       * dst_d   = (float       *) dst->data;

    cudaStream_t stream = ctx.stream();

    const dim3 blocks_dim(WARP_SIZE, 1, 1);
    const dim3 blocks_num(nrows, 1, 1);
    const size_t nbytes_shared = ne00*sizeof(float);

    const int    id    = ggml_cuda_get_device();
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

    if (nbytes_shared <= smpbo) {
        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    } else {
        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    }
}