1#include "common.cuh"
2#include "fattn-tile.cuh"
3#include "fattn-wmma-f16.cuh"
4
5void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
6 const ggml_tensor * K = dst->src[1];
7 const ggml_tensor * V = dst->src[2];
8 switch (K->ne[0]) {
9 case 40: {
10 GGML_ASSERT(V->ne[0] == K->ne[0]);
11 ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
12 } break;
13 case 64: {
14 GGML_ASSERT(V->ne[0] == K->ne[0]);
15 ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
16 } break;
17 case 72: {
18 GGML_ASSERT(V->ne[0] == K->ne[0]);
19 ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
20 } break;
21 case 80: {
22 GGML_ASSERT(V->ne[0] == K->ne[0]);
23 ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
24 } break;
25 case 96: {
26 GGML_ASSERT(V->ne[0] == K->ne[0]);
27 ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
28 } break;
29 case 112: {
30 GGML_ASSERT(V->ne[0] == K->ne[0]);
31 ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
32 } break;
33 case 128: {
34 GGML_ASSERT(V->ne[0] == K->ne[0]);
35 ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
36 } break;
37 case 256: {
38 GGML_ASSERT(V->ne[0] == K->ne[0]);
39 ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
40 } break;
41 case 576: {
42 GGML_ASSERT(V->ne[0] == 512);
43 ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
44 } break;
45 default: {
46 GGML_ABORT("Unsupported head size");
47 } break;
48 }
49}
50