| 1 | #include "rope.cuh" |
| 2 | |
// Pair of floats passed by value to the RoPE kernels.
// The host fills `v` through ggml_rope_yarn_corr_dims(); rope_yarn() forwards
// v[0] and v[1] to rope_yarn_ramp() as its `low` and `high` bounds.
struct rope_corr_dims {
    float v[2];
};
| 6 | |
| 7 | |
// Four integer section sizes passed by value to the multi-section RoPE kernels.
// The host copies them from op_params[11..14]; the kernels compare a per-pair
// "sector" index against these sizes to pick which block of `pos` to read.
struct mrope_sections {
    int v[4];
};
| 11 | |
// Clamped linear ramp over the pair index i0/2 (integer division).
//
// Returns 1.0f when i0/2 <= low and falls linearly to 0.0f as i0/2 approaches `high`
// (for high > low); the result is always within [0, 1].
// The denominator is clamped to 0.001f so that high <= low cannot divide by zero.
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}
| 16 | |
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes the magnitude-scaled cosine and sine of the rotation angle for one element pair.
//   theta_extrap - unscaled angle
//   freq_scale   - factor applied to theta_extrap to obtain the interpolated angle
//   corr_dims    - low/high bounds handed to rope_yarn_ramp()
//   i0           - even element index of the pair (the ramp uses i0/2)
//   ext_factor   - weight of the extrapolated angle in the blend; 0 disables the blend and the
//                  magnitude correction entirely
//   mscale       - magnitude scale, taken by value and adjusted locally
//   cos_theta, sin_theta - outputs; sin_theta is negated when forward == false
template<bool forward>
static __device__ void rope_yarn(
        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
        float mscale, float & cos_theta, float & sin_theta) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }
    cos_theta = cosf(theta) * mscale;
    sin_theta = sinf(theta) * mscale;
    if (!forward) {
        // backward pass: rotate in the opposite direction
        sin_theta *= -1.0f;
    }
}
| 39 | |
// RoPE kernel that rotates adjacent element pairs (i0, i0 + 1) of every row.
//
// Thread mapping: the x index selects the destination row, the y index selects the pair.
// There is no bounds check on the row index, so the launch must cover exactly the rows
// that exist (rope_norm_cuda uses one block of width 1 in x per row).
//   x, dst       - source / destination; dst rows are ne0 elements apart, source elements are
//                  addressed through the element strides s1 (row) and s2 (channel)
//   ne0, ne1     - row length and number of rows per channel
//   n_dims       - pairs with i0 >= n_dims are copied through unrotated
//   pos          - one int32 position per channel
//   theta_scale  - base of the per-pair frequency: angle = pos * theta_scale^(i0/2)
//   freq_factors - per-pair divisor of the angle, only read when has_ff is true
// The remaining scalars are forwarded unchanged to rope_yarn().
template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    // one thread per pair, so i0 is always even
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    const int idst = row_dst*ne0 + i0;
    const int ix   = channel_x*s2 + row_x*s1 + i0;

    if (i0 >= n_dims) {
        // outside the rotated range: plain copy
        dst[idst + 0] = x[ix + 0];
        dst[idst + 1] = x[ix + 1];

        return;
    }

    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + 1];

    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
| 81 | |
// RoPE kernel that rotates split-half pairs: element i0/2 is paired with element
// i0/2 + n_dims/2 of the same row.
//
// Thread mapping: the x index selects the destination row, the y index selects the pair.
// There is no bounds check on the row index, so the launch must cover exactly the rows
// that exist (rope_neox_cuda uses one block of width 1 in x per row).
//   x, dst       - source / destination; dst rows are ne0 elements apart, source elements are
//                  addressed through the element strides s1 (row) and s2 (channel)
//   ne0, ne1     - row length and number of rows per channel
//   n_dims       - threads with i0 >= n_dims copy elements i0 and i0 + 1 through unrotated
//   pos          - one int32 position per channel
//   theta_scale  - base of the per-pair frequency: angle = pos * theta_scale^(i0/2)
//   freq_factors - per-pair divisor of the angle, only read when has_ff is true
// The remaining scalars are forwarded unchanged to rope_yarn().
template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
    // one thread per pair, so i0 is always even
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // offsets of the first element of the pair (index i0/2 within the row)
    const int idst = row_dst*ne0 + i0/2;
    const int ix   = channel_x*s2 + row_x*s1 + i0/2;

    if (i0 >= n_dims) {
        // idst/ix already contain i0/2, adding it again addresses element i0 of the row
        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

        return;
    }

    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
| 123 | |
// Multi-section RoPE kernel. Pairs are laid out as in rope_neox (element i0/2 with element
// i0/2 + n_dims/2), but the position used for a pair is read from one of four blocks of `pos`,
// each ne2 entries long, chosen from the pair's "sector" index.
//
// Thread mapping: the x index selects the destination row, the y index selects the pair.
// There is no bounds check on the row index, so the launch must cover exactly the rows
// that exist (rope_multi_cuda uses one block of width 1 in x per row).
//   x, dst       - source / destination; dst rows are ne0 elements apart, source elements are
//                  addressed through the element strides s1 (row) and s2 (channel)
//   ne0, ne1     - row length and number of rows per channel
//   ne2          - length of one position block inside `pos`
//   n_dims       - threads with i0 >= n_dims copy elements i0 and i0 + 1 through unrotated
//   pos          - position blocks; block k starts at pos + k*ne2 and is indexed by channel
//   sections     - section sizes; sector = (i0/2) % (sum of all four sizes)
//   is_imrope    - true: interleaved assignment (block = sector % 3 while sector < 3*size of
//                  that block), false: contiguous ranges [0, v0), [v0, v0+v1), [v0+v1, v0+v1+v2);
//                  anything not matched uses block 3
//   freq_factors - per-pair divisor of the angle, only read when has_ff is true
// The remaining scalars are forwarded unchanged to rope_yarn().
// NOTE(review): the sum of the four section sizes is a modulo divisor and is assumed non-zero.
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
    // one thread per pair, so i0 is always even
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // offsets of the first element of the pair (index i0/2 within the row)
    const int idst = row_dst*ne0 + i0/2;
    const int ix   = channel_x*s2 + row_x*s1 + i0/2;

    if (i0 >= n_dims) {
        // idst/ix already contain i0/2, adding it again addresses element i0 of the row
        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

        return;
    }

    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    // index of the ne2-sized block of `pos` that supplies the position; 3 is the fallback
    int pos_block = 3;
    if (is_imrope) {
        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
            pos_block = 1;
        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
            pos_block = 2;
        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
            pos_block = 0;
        }
    } else {
        if (sector < sections.v[0]) {
            pos_block = 0;
        } else if (sector < sec_w) {
            pos_block = 1;
        } else if (sector < sec_w + sections.v[2]) {
            pos_block = 2;
        }
    }

    const float theta_base = pos[channel_x + ne2 * pos_block]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
| 193 | |
// Two-section RoPE kernel. Element i0/2 of a row is paired with element i0/2 + n_dims
// (a full n_dims apart, not n_dims/2). There is no copy-through branch: every pair with
// i0 < ne0 is rotated.
//
// Thread mapping: the x index selects the destination row, the y index selects the pair.
// There is no bounds check on the row index, so the launch must cover exactly the rows
// that exist (rope_vision_cuda uses one block of width 1 in x per row).
//   x, dst       - source / destination; dst rows are ne0 elements apart, source elements are
//                  addressed through the element strides s1 (row) and s2 (channel)
//   ne0, ne1     - row length and number of rows per channel
//   ne2          - offset of the second position block inside `pos`
//   pos          - sectors [0, v0) read pos[channel], sectors [v0, v0+v1) read pos[channel + ne2]
//   sections     - only v[0] and v[1] are used; sector = (i0/2) % (v[0] + v[1])
//   theta_scale  - raised to the sector's index within its own section, not to i0/2
//   freq_factors - per-pair divisor of the angle, only read when has_ff is true
// The remaining scalars are forwarded unchanged to rope_yarn().
// NOTE(review): v[0] + v[1] is a modulo divisor and is assumed non-zero; nothing visible in
// this file asserts that for this kernel.
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
    // one thread per pair, so i0 is always even
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    // offsets of the first element of the pair (index i0/2 within the row)
    const int idst = row_dst*ne0 + i0/2;
    const int ix   = channel_x*s2 + row_x*s1 + i0/2;

    const int sect_dims = sections.v[0] + sections.v[1];
    const int sec_w     = sections.v[1] + sections.v[0];
    const int sector    = (i0 / 2) % sect_dims;

    // stays 0 when the sector falls in neither range
    float theta_base = 0.0f;
    if (sector < sections.v[0]) {
        const int p = sector;
        theta_base = pos[channel_x]*powf(theta_scale, p);
    } else if (sector < sec_w) {
        const int p = sector - sections.v[0];
        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims];

    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
| 240 | |
// Host launcher for the rope_norm kernel on `stream`.
//
// Grid: nr blocks in x (one row each, block width 1) and enough blocks in y that
// CUDA_ROPE_BLOCK_SIZE threads per block cover the ne0/2 pairs of a row.
//   nr           - number of rows to process
//   freq_base    - only used to derive theta_scale = freq_base^(-2/n_dims)
//   freq_factors - may be null; selects the has_ff = false kernel instantiation
// All other arguments are forwarded to the kernel unchanged. ne0 must be even.
template<bool forward, typename T>
static void rope_norm_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    // each thread handles 2 elements, so one block covers 2*CUDA_ROPE_BLOCK_SIZE elements of a row
    const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_norm<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors);
    } else {
        rope_norm<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors);
    }
}
| 263 | |
// Host launcher for the rope_neox kernel on `stream`.
//
// Grid: nr blocks in x (one row each, block width 1) and enough blocks in y that
// CUDA_ROPE_BLOCK_SIZE threads per block cover the ne0/2 pairs of a row.
//   nr           - number of rows to process
//   freq_base    - only used to derive theta_scale = freq_base^(-2/n_dims)
//   freq_factors - may be null; selects the has_ff = false kernel instantiation
// All other arguments are forwarded to the kernel unchanged. ne0 must be even.
template<bool forward, typename T>
static void rope_neox_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    // each thread handles 2 elements, so one block covers 2*CUDA_ROPE_BLOCK_SIZE elements of a row
    const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors);
    } else {
        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors);
    }
}
| 286 | |
// Host launcher for the rope_multi kernel on `stream`.
//
// Grid: nr blocks in x (one row each, block width 1) and enough blocks in y that
// CUDA_ROPE_BLOCK_SIZE threads per block cover the ne0/2 pairs of a row.
//   nr           - number of rows to process
//   freq_base    - only used to derive theta_scale = freq_base^(-2/n_dims)
//   freq_factors - may be null; selects the has_ff = false kernel instantiation
//   sections, is_imrope - forwarded to the kernel, which uses them to pick the position block
// All other arguments are forwarded to the kernel unchanged. ne0 must be even.
template<bool forward, typename T>
static void rope_multi_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    // each thread handles 2 elements, so one block covers 2*CUDA_ROPE_BLOCK_SIZE elements of a row
    const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    } else {
        rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    }
}
| 309 | |
// Host launcher for the rope_vision kernel on `stream`.
//
//   nr           - number of rows to process
//   freq_base    - only used to derive theta_scale = freq_base^(-2/n_dims)
//   freq_factors - may be null; selects the has_ff = false kernel instantiation
//   sections     - forwarded to the kernel, which uses v[0] and v[1]
// All other arguments are forwarded to the kernel unchanged. ne0 must be even.
template<bool forward, typename T>
static void rope_vision_cuda(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    // each thread handles 2 elements, so one block covers 2*CUDA_ROPE_BLOCK_SIZE elements of a row
    const int n_blocks_y = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nr, n_blocks_y, 1);
    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    } else {
        rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    }
}
| 334 | |
// Shared implementation of the RoPE op for both directions (forward == false negates the sine
// inside the kernels).
//
// Inputs are taken from dst->src:
//   src[0] - data to rotate, F32 or F16, same type as dst
//   src[1] - positions, read as int32
//   src[2] - optional frequency factors, read as float; may be null
// Scalar parameters are unpacked from dst->op_params (32-bit slots): [1] n_dims, [2] mode,
// [4] n_ctx_orig, [5..10] freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
// (floats), [11..14] the four section sizes.
// The kernel is chosen from `mode`: neox, multi-section (optionally interleaved), vision, or
// the default adjacent-pair kernel. All work is enqueued on ctx.stream().
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    // typed as float here; re-cast to half below when the tensor is F16
    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    // The multi/vision kernels use ne02 as the offset between position blocks in src1.
    // NOTE(review): the original comment said "num heads", which looks copy-pasted from ne01;
    // presumably this is the number of positions - confirm against the caller.
    const int64_t ne02 = src0->ne[2];
    const int64_t nr = ggml_nrows(src0);

    // source strides in elements (nb[] is in bytes)
    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);

    //const int n_past     = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    // op_params is an int32 array; the float parameters are bit-copied out of their slots
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);

    // neox and mrope are bit tests, imrope and vision are exact matches of the mode value
    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        // the vision kernel pairs element i with element i + n_dims
        GGML_ASSERT(n_dims == ne00/2);
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32) {
            rope_neox_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_neox_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else {
            GGML_ABORT("fatal error" );
        }
    } else if (is_mrope && !is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else {
            GGML_ABORT("fatal error" );
        }
    } else if (is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
        } else {
            GGML_ABORT("fatal error" );
        }
    } else {
        if (src0->type == GGML_TYPE_F32) {
            rope_norm_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_norm_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
        } else {
            GGML_ABORT("fatal error" );
        }
    }
}
| 456 | |
// Forward RoPE op: runs the shared implementation with forward = true.
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<true>(ctx, dst);
}
| 460 | |
// Backward RoPE op: runs the shared implementation with forward = false, which makes the
// kernels negate the sine term and therefore rotate in the opposite direction.
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<false>(ctx, dst);
}
| 464 | |