#include "ggml.h"
#include "mmf.cuh"
#include "mmid.cuh"

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS;
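    // Declares local ne*/nb* copies of the dimension and byte-stride fields of
    // src0 (ne00..ne03/nb00..nb03), src1 (ne10..ne13/nb10..nb13) and dst (ne0..ne3/nb0..nb3).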

    const size_t ts_src0 = ggml_type_size(src0->type);
    const size_t ts_src1 = ggml_type_size(src1->type);
    const size_t ts_dst  = ggml_type_size(dst->type);

    GGML_ASSERT(ne13 == ne3);

    GGML_ASSERT( nb00 == ts_src0);
    GGML_ASSERT( nb10 == ts_src1);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
    GGML_ASSERT( nb0 == ts_dst);

    const float * src1_d = (const float *) src1->data;
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
    float * dst_d = (float *) dst->data;
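
    // ggml stores strides (nb*) in bytes; convert them to strides in elements for the kernels.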
    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1  = dst->nb[1]  / ts_dst;
    const int64_t s02 = src0->nb[2] / ts_src0;
    const int64_t s12 = src1->nb[2] / ts_src1;
    const int64_t s2  = dst->nb[2]  / ts_dst;
    const int64_t s03 = src0->nb[3] / ts_src0;
    const int64_t s13 = src1->nb[3] / ts_src1;
    const int64_t s3  = dst->nb[3]  / ts_dst;

    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
    const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

    mmf_ids_data ids_info{};
    mmf_ids_data * ids_info_ptr = nullptr;
    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;

    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
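    // dst is {ne0 = ne01, ne1 = n_expert_used, ne2 = n_tokens}, so the roles of ne1/ne2 (and of the
    // corresponding strides) as "columns" and "channels" are swapped relative to plain MUL_MAT.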
    const int64_t ncols_dst          = ids ? ne2 : ne1;
    const int64_t nchannels_dst      = ids ? ne1 : ne2;

    const int64_t stride_col_dst     = ids ? s2  : s1;
    const int64_t stride_col_y       = ids ? s12 : s11;
    const int64_t stride_channel_dst = ids ? s1  : s2;

    int64_t stride_channel_y = ids ? s11  : s12;
    int64_t nchannels_y      = ids ? ne11 : ne12;

    // mul_mat_id: handle broadcast
    if (ids && nchannels_y == 1) {
        stride_channel_y = 0;
        nchannels_y      = ids->ne[0];
    }
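
    // For larger token batches (ncols_dst > 16) the (expert, token) routing from ids is pre-compacted
    // so that each expert's work can be processed as one contiguous batch by the kernel.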
    if (ids && ncols_dst > 16) {
        const int64_t n_expert_used = ids->ne[0];
        const int64_t n_experts = ne02;
        const int64_t n_tokens = ne12;
        const int64_t ne_get_rows = n_tokens * n_expert_used;

        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
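
        // si1: element stride between rows of ids; sis1: number of src1 dim-1 slices per token (nb12/nb11).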
        const int si1 = static_cast<int>(ids_s1);
        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);

        GGML_ASSERT(sis1 > 0);

        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
        CUDA_CHECK(cudaGetLastError());
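
        // The helper produced per-expert compacted lists of src1/dst row indices plus per-expert
        // offsets in expert_bounds_dev; hand them to the kernel via mmf_ids_data.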
        ids_info.ids_src_compact = ids_src_compact_dev.get();
        ids_info.ids_dst_compact = ids_dst_compact_dev.get();
        ids_info.expert_bounds_dev = expert_bounds_dev.get();
        ids_info.n_experts = static_cast<int>(n_experts);
        ids_info.sis1 = sis1;
        ids_info_ptr = &ids_info;
    }
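
    // For F16/BF16 src0 is reinterpreted as packed 2-element vectors (half2/nv_bfloat162) and the
    // affected counts and strides are divided by vals_per_T accordingly.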
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            constexpr int vals_per_T = 1;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_F16: {
            const half2 * src0_d = (const half2 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
    }
}
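
// Returns whether the float (non-quantized) tensor-core mul_mat kernel should be used for this case.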
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
                              const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
    if (ggml_is_quantized(type)) {
        return false;
    }
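
    // Each thread loads 4 bytes at a time (1 float or 2 half/bf16 values), so the row length
    // must be a multiple of warp_size*(4/ts) elements.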
    const size_t ts = ggml_type_size(type);
    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
        return false;
    }

    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }
    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
        return false;
    }
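
    // Heuristic batch-size cut-offs beyond which other mat-mul paths are expected to be faster.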
    if (mul_mat_id) {
        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
            return false;
        } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
            return false;
        }
    } else {
        if (src1_ncols > 16) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc);
        default:
            return false;
    }
}