#include "common.cuh"
#include "cross-entropy-loss.cuh"
#include "sum.cuh"

#include <cmath>
#include <cstdint>

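// Forward kernel, one warp per row: computes the per-row contribution
// -sum_i(labels[i] * log_softmax(logits)[i]) / k to the total cross-entropy loss.
// With use_shared == true the row of logits is cached in shared memory.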
template <bool use_shared>
static __global__ void cross_entropy_loss_f32(
        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
    extern __shared__ float tmp[];

    logits += int64_t(blockIdx.x)*nclasses;
    labels += int64_t(blockIdx.x)*nclasses;

    // Find maximum for softmax:
    float max_logit = -INFINITY;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = logits[i];
        max_logit = fmaxf(max_logit, val);

        if (use_shared) {
            tmp[i] = val;
        }
    }
    max_logit = warp_reduce_max(max_logit);

    // log(softmax(logits)) = (logits - max) - log(sum(exp(logits - max))); accumulate the normalizer:
    float sum = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float logit_i = use_shared ? tmp[i] : logits[i];
        sum += expf(logit_i - max_logit);
    }
    sum = warp_reduce_sum(sum);
    sum = logf(sum);

    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
    float loss = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float logit_i = use_shared ? tmp[i] : logits[i];
        loss += (logit_i - max_logit - sum) * labels[i];
    }
    loss = -warp_reduce_sum(loss) / (float)k;

    if (threadIdx.x != 0) {
        return;
    }

    dst[blockIdx.x] = loss;
}

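// Backward kernel, one warp per row: computes the gradient of the loss w.r.t. the logits,
// dst[i] = (softmax(logits)[i] - labels[i]) * grad / nrows, with gridDim.x == nrows.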
template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
        float * __restrict__ dst, const int nclasses) {
    extern __shared__ float tmp[];

    logits += int64_t(blockIdx.x)*nclasses;
    labels += int64_t(blockIdx.x)*nclasses;
    dst    += int64_t(blockIdx.x)*nclasses;

    float maxval = -INFINITY;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = logits[i];
        maxval = fmaxf(maxval, val);

        if (use_shared) {
            tmp[i] = val;
        }
    }
    maxval = warp_reduce_max(maxval);

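    // Softmax numerators exp(logit - maxval); cache them in shared memory or, without
    // shared memory, stage them in dst (overwritten again in the final loop):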
    float sum = 0.0f;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
        sum += val;

        if (use_shared) {
            tmp[i] = val;
        } else {
            dst[i] = val;
        }
    }
    sum = warp_reduce_sum(sum);
    const float sm_scale = 1.0f/sum;

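    // Gradient w.r.t. the logits: (softmax - labels) scaled by the upstream gradient
    // divided by the number of rows (gridDim.x == nrows):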
    const float d_by_nrows = *grad/gridDim.x;
    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
        const float val = use_shared ? tmp[i] : dst[i];
        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
    }
}

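// Forward pass: launches one warp-sized block per row, writes the per-row losses to a
// temporary buffer and reduces them to a single scalar with sum_f32_cuda.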
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int64_t ne00  = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       * dst_d  = (float       *) dst->data;

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t stream = ctx.stream();

    const dim3 blocks_dim(WARP_SIZE, 1, 1);
    const dim3 blocks_num(nrows, 1, 1);
    const size_t nbytes_shared = ne00*sizeof(float);

    const int    id    = ggml_cuda_get_device();
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

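    // Use the shared memory variant only if one row of logits fits within the
    // device's opt-in shared memory limit per block (smpbo):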
    if (nbytes_shared <= smpbo) {
        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
    } else {
        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
    }
    CUDA_CHECK(cudaGetLastError());

    // Combine results from individual blocks:
    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}

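// Backward pass: dst receives d(loss)/d(logits) for every row, scaled by the scalar
// upstream gradient dst->src[0].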
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0];
    const ggml_tensor * src0f = dst->src[1];
    const ggml_tensor * src1f = dst->src[2];

    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_scalar(grad));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
    GGML_ASSERT(ggml_are_same_shape(src0f, dst));

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    const float * src1f_d = (const float *) src1f->data;
    float       * dst_d   = (float       *) dst->data;

    cudaStream_t stream = ctx.stream();

    const dim3 blocks_dim(WARP_SIZE, 1, 1);
    const dim3 blocks_num(nrows, 1, 1);
    const size_t nbytes_shared = ne00*sizeof(float);

    const int    id    = ggml_cuda_get_device();
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

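    // As in the forward pass, prefer the shared memory variant when a row of logits fits: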
    if (nbytes_shared <= smpbo) {
        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    } else {
        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    }
}