#include "rope.cuh"

#include <climits>
2
// YaRN correction range, filled on the host by ggml_rope_yarn_corr_dims().
// v[0] is the low end and v[1] the high end of the ramp evaluated by rope_yarn_ramp().
// Wrapped in a struct so the pair can be passed to kernels by value.
struct rope_corr_dims { float v[2]; };
6
7
// Section sizes for multi-section RoPE, copied from op_params[11..14].
// Each entry is a count of rotated dimension pairs assigned to one position stream;
// rope_multi reads all four entries, rope_vision only v[0] and v[1].
struct mrope_sections { int v[4]; };
11
// YaRN blend weight for the dimension pair starting at column i0.
// Returns 1 when the pair index (i0 / 2) is <= low, falls linearly to 0 as the index
// approaches high, and stays 0 beyond it. The denominator is clamped to 0.001 so
// that low == high (or high < low) never divides by zero.
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float span    = max(0.001f, high - low);
    const float ratio   = (i0 / 2 - low) / span;
    const float clamped = min(1.0f, max(0.0f, ratio));
    return 1.0f - clamped;
}
16
17// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
18// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Produces the magnitude-scaled cos/sin of one rotation angle, with YaRN corrections.
//   theta_extrap : extrapolated (unscaled) angle of the pair
//   freq_scale   : interpolation factor; the interpolated angle is freq_scale * theta_extrap
//   corr_dims    : ramp range used to blend the interpolated and extrapolated angles
//   i0           : even column index of the rotated pair (forwarded to rope_yarn_ramp)
//   ext_factor   : blend strength; 0 disables both the blend and the magnitude correction
//   mscale       : base magnitude multiplied into both outputs
//   cos_theta, sin_theta : outputs
// With forward == false the sine is negated, i.e. the inverse rotation is produced
// (this is what the rope_back entry point uses).
template<bool forward>
static __device__ void rope_yarn(
        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
        float mscale, float & cos_theta, float & sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;

    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        // blend toward the extrapolated angle for low-index pairs
        const float mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
        theta = theta_interp * (1 - mix) + theta_extrap * mix;

        // compensate the magnitude for the interpolation
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }

    const float sin_sign = forward ? 1.0f : -1.0f;

    cos_theta = cosf(theta) * mscale;
    sin_theta = sin_sign * (sinf(theta) * mscale);
}
39
// Standard RoPE: rotates adjacent element pairs (x[i0], x[i0 + 1]) of each row.
//
// Launch layout (set up by rope_norm_cuda): blockDim = (1, CUDA_ROPE_BLOCK_SIZE, 1),
// grid = (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1). The x dimension selects the
// row (row_dst); the y dimension selects the even column index i0.
//   x, dst       : source / destination. dst rows are dense with length ne0;
//                  x is addressed with element strides s1 (dim 1) and s2 (dim 2).
//   ne1          : size of dim 1; row_dst is split into (row_x, channel_x) with it.
//   n_dims       : columns [0, n_dims) are rotated, the remaining ones are copied.
//   pos          : rotation position, indexed by channel_x.
//   theta_scale  : per-pair frequency ratio; theta = pos * theta_scale^(i0/2).
//   freq_factors : per-pair divisors of theta, read only when has_ff is true.
//   forward      : forwarded to rope_yarn; false applies the inverse rotation.
template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // 64-bit element offsets: row_dst*ne0 and channel_x*s2 can exceed INT_MAX for
    // large tensors even though every individual factor fits in an int.
    const int64_t idst = (int64_t) row_dst*ne0 + i0;
    const int64_t ix   = (int64_t) channel_x*s2 + (int64_t) row_x*s1 + i0;

    if (i0 >= n_dims) {
        // columns beyond n_dims are passed through unchanged
        dst[idst + 0] = x[ix + 0];
        dst[idst + 1] = x[ix + 1];

        return;
    }

    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + 1];

    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
81
// NeoX-style RoPE: element i of a row is rotated against element i + n_dims/2,
// for i in [0, n_dims/2); columns >= n_dims are copied unchanged.
//
// Launch layout and parameters match rope_norm (set up by rope_neox_cuda): each
// thread owns the even column index i0 and either rotates the pair
// (i0/2, i0/2 + n_dims/2) or copies columns (i0, i0 + 1) when i0 >= n_dims.
template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // 64-bit element offsets: row_dst*ne0 and channel_x*s2 can exceed INT_MAX for
    // large tensors even though every individual factor fits in an int.
    const int64_t idst = (int64_t) row_dst*ne0 + i0/2;
    const int64_t ix   = (int64_t) channel_x*s2 + (int64_t) row_x*s1 + i0/2;

    if (i0 >= n_dims) {
        // idst/ix already include i0/2, so +i0/2 addresses columns i0 and i0+1
        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

        return;
    }

    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
123
// Multi-section RoPE with NeoX pairing: element i of a row is rotated against element
// i + n_dims/2 for i in [0, n_dims/2); columns >= n_dims are copied unchanged.
//
// Each pair is driven by one of four position streams pos[channel_x + k*ne2], k = 0..3,
// chosen from its index `sector` within a period of sections.v[0..3] pairs:
//   - contiguous layout: the first v[0] pairs use k = 0, the next v[1] use k = 1,
//     the next v[2] use k = 2, the rest use k = 3;
//   - interleaved layout (is_imrope): k = sector % 3 while sector < 3*v[k], else k = 3.
// Launch layout and the remaining parameters match rope_norm (see rope_multi_cuda).
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // 64-bit element offsets: row_dst*ne0 and channel_x*s2 can exceed INT_MAX for
    // large tensors even though every individual factor fits in an int.
    const int64_t idst = (int64_t) row_dst*ne0 + i0/2;
    const int64_t ix   = (int64_t) channel_x*s2 + (int64_t) row_x*s1 + i0/2;

    if (i0 >= n_dims) {
        // idst/ix already include i0/2, so +i0/2 addresses columns i0 and i0+1
        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

        return;
    }

    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    // select the position stream (0 = t, 1 = h, 2 = w, 3 = extra) for this pair
    int axis;
    if (is_imrope) {
        const int k = sector % 3;
        axis = sector < 3 * sections.v[k] ? k : 3;
    } else if (sector < sections.v[0]) {
        axis = 0;
    } else if (sector < sec_w) {
        axis = 1;
    } else if (sector < sec_w + sections.v[2]) {
        axis = 2;
    } else {
        axis = 3;
    }

    const float theta_base = pos[channel_x + ne2 * axis]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
193
// Vision RoPE (two position streams): element i of a row is rotated against element
// i + n_dims for i in [0, ne0/2). The host asserts n_dims == ne0/2, so every column is
// rotated and there is no pass-through path.
//
// Within each period of sections.v[0] + sections.v[1] pairs, the first v[0] pairs use
// pos[channel_x] and the rest use pos[channel_x + ne2]; unlike the other kernels, the
// frequency exponent p restarts at 0 at the start of each section.
// Launch layout and the remaining parameters match rope_norm (see rope_vision_cuda).
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // 64-bit element offsets: row_dst*ne0 and channel_x*s2 can exceed INT_MAX for
    // large tensors even though every individual factor fits in an int.
    const int64_t idst = (int64_t) row_dst*ne0 + i0/2;
    const int64_t ix   = (int64_t) channel_x*s2 + (int64_t) row_x*s1 + i0/2;

    const int sect_dims = sections.v[0] + sections.v[1];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    float theta_base = 0.0f;
    if (sector < sections.v[0]) {
        const int p = sector;
        theta_base = pos[channel_x]*powf(theta_scale, p);
    }
    else if (sector >= sections.v[0] && sector < sec_w) {
        const int p = sector - sections.v[0];
        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims];

    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
240
// Host launcher for rope_norm on `stream`.
// One thread per column pair: blockDim = (1, CUDA_ROPE_BLOCK_SIZE, 1) and
// grid = (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1), so blockIdx.x picks the row.
// The has_ff specialisation is selected by whether freq_factors is null.
template<bool forward, typename T>
static void rope_norm_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const int cols_per_block = 2*CUDA_ROPE_BLOCK_SIZE;
    const int col_blocks     = (ne0 + cols_per_block - 1) / cols_per_block;

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, col_blocks, 1);

    // ratio between the frequencies of consecutive pairs
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const auto kernel = freq_factors == nullptr
        ? &rope_norm<forward, false, T>
        : &rope_norm<forward, true,  T>;

    kernel<<<block_nums, block_dims, 0, stream>>>(
        x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors);
}
263
// Host launcher for rope_neox on `stream`.
// One thread per column pair: blockDim = (1, CUDA_ROPE_BLOCK_SIZE, 1) and
// grid = (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1), so blockIdx.x picks the row.
// The has_ff specialisation is selected by whether freq_factors is null.
template<bool forward, typename T>
static void rope_neox_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const int cols_per_block = 2*CUDA_ROPE_BLOCK_SIZE;
    const int col_blocks     = (ne0 + cols_per_block - 1) / cols_per_block;

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, col_blocks, 1);

    // ratio between the frequencies of consecutive pairs
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const auto kernel = freq_factors == nullptr
        ? &rope_neox<forward, false, T>
        : &rope_neox<forward, true,  T>;

    kernel<<<block_nums, block_dims, 0, stream>>>(
        x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors);
}
286
// Host launcher for rope_multi on `stream`.
// One thread per column pair: blockDim = (1, CUDA_ROPE_BLOCK_SIZE, 1) and
// grid = (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1), so blockIdx.x picks the row.
// The has_ff specialisation is selected by whether freq_factors is null; `sections`
// and `is_imrope` are passed straight through to the kernel.
template<bool forward, typename T>
static void rope_multi_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);

    const int cols_per_block = 2*CUDA_ROPE_BLOCK_SIZE;
    const int col_blocks     = (ne0 + cols_per_block - 1) / cols_per_block;

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, col_blocks, 1);

    // ratio between the frequencies of consecutive pairs
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const auto kernel = freq_factors == nullptr
        ? &rope_multi<forward, false, T>
        : &rope_multi<forward, true,  T>;

    kernel<<<block_nums, block_dims, 0, stream>>>(
        x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
309
// Host launcher for rope_vision on `stream`.
// Same layout as the other RoPE launchers: blockDim = (1, CUDA_ROPE_BLOCK_SIZE, 1) and
// grid = (nr, ceil(ne0 / (2*CUDA_ROPE_BLOCK_SIZE)), 1), i.e. blockIdx.x selects the row
// and the y dimension covers that row's ne0/2 column pairs.
// The has_ff specialisation is selected by whether freq_factors is null.
template<bool forward, typename T>
static void rope_vision_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    // rope_vision reduces the pair index modulo sections.v[0] + sections.v[1];
    // a zero sum would be an integer division by zero on the device.
    GGML_ASSERT(sections.v[0] + sections.v[1] > 0);

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_x, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    } else {
        rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    }
}
334
// Shared CUDA implementation of the RoPE op (forward = true) and its backward pass
// (forward = false). Reads the RoPE parameters from dst->op_params and dispatches to the
// NeoX, multi-section, vision or standard kernel for F32 or F16 data.
//   dst->src[0] : input tensor, F32 or F16 (must match dst->type)
//   dst->src[1] : int32 positions
//   dst->src[2] : optional float frequency factors (null when absent)
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    const int64_t ne02 = src0->ne[2]; // dim indexed by pos (multi-section modes read further pos blocks at offsets k*ne02)
    const int64_t nr = ggml_nrows(src0);

    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);

    // The launchers and kernels take these as int; reject tensors that would otherwise
    // be silently truncated instead of producing wrong results.
    GGML_ASSERT(ne00 <= INT_MAX && ne01 <= INT_MAX && ne02 <= INT_MAX && nr <= INT_MAX);
    GGML_ASSERT(s01 <= (size_t) INT_MAX && s02 <= (size_t) INT_MAX);

    //const int n_past = ((int32_t *) dst->op_params)[0];
    const int n_dims = ((int32_t *) dst->op_params)[1];
    const int mode = ((int32_t *) dst->op_params)[2];
    //const int n_ctx = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        // rope_vision rotates x[i] against x[i + n_dims] across the whole row
        GGML_ASSERT(n_dims == ne00/2);
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_mrope && !is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
            rope_norm_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_norm_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    }
}
456
// Forward RoPE entry point: applies the rotation with the forward (non-negated sine) kernels.
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_rope_impl<true>(ctx, dst);
}
460
// Backward RoPE entry point: runs the same kernels with forward = false, so rope_yarn
// negates sin(theta) and the inverse rotation is applied.
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_rope_impl<false>(ctx, dst);
}
464