From acf3c890fa22d2dd5aaa20c6a2ec5724795c26ef Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 4 Sep 2025 16:06:48 +0000 Subject: [PATCH 1/4] add mul_mat variant for embed gpu --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 112 ++++++++++++ .../vulkan-shaders/mul_mm_embed.comp | 160 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + 3 files changed, 275 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cd1c66ba7b476..a3a0d8a5a15a2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -77,6 +77,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_ARM 0x13B5 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 @@ -448,6 +449,8 @@ struct vk_device_struct { vk_matmul_pipeline pipeline_matmul_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; + vk_pipeline pipeline_matmul_f16_f32_embed; + vk_pipeline pipeline_matmul_f32_f32_embed; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; @@ -702,6 +705,15 @@ struct vk_mat_vec_id_push_constants { uint32_t nei0; uint32_t ne11; }; +struct vk_mat_mat_embed_push_constants { + uint32_t M; + uint32_t N; + uint32_t K; + uint32_t stride_a; + uint32_t stride_b; + uint32_t stride_d; +}; + struct vk_flash_attn_push_constants { uint32_t N; uint32_t KV; @@ -2901,6 +2913,16 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); } } + + if (device->vendor_id == VK_VENDOR_ID_ARM) { + ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32_embed, "mul_mat_f16_f32_embed", + mul_mat_f16_f32_embed_len, mul_mat_f16_f32_embed_data, "main", 3, + sizeof(vk_mat_mat_embed_push_constants), {64, 64, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f32_embed, "mul_mat_f32_f32_embed", + mul_mat_f32_f32_embed_len, mul_mat_f32_f32_embed_data, "main", 3, + sizeof(vk_mat_mat_embed_push_constants), {64, 64, 1}, {}, 1); + } + // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) @@ -5726,6 +5748,96 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; + if (ctx->device->vendor_id == VK_VENDOR_ID_ARM && + (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + ne02 == 1 && ne03 == 1 && + ne12 == 1 && ne13 == 1) { + + const uint32_t M = ne01; + const uint32_t N = ne11; + const uint32_t K = ne10; + + vk_pipeline pipeline = nullptr; + vk_buffer d_X; + uint64_t x_buf_offset; + uint32_t stride_a; + + if (ggml_is_quantized(src0->type)) { + vk_pipeline dequant_pipeline = ggml_vk_get_to_fp16(ctx, src0->type); + + if (dequant_pipeline) { + const uint64_t x_sz = sizeof(ggml_fp16_t) * M * K; + + if (dryrun) { + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + } + ggml_pipeline_request_descriptor_sets(ctx, dequant_pipeline, 1); + 
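/*
 * Illustration (not part of the patch): in the dryrun pass this branch only sizes
 * the scratch buffer and reserves descriptor sets; the dequant and matmul
 * dispatches happen on the later, non-dryrun pass. Assuming a hypothetical F16
 * token-embedding weight of M = 32000 rows and K = 4096 columns, the dequantized
 * copy needs
 *
 *     x_sz = sizeof(ggml_fp16_t) * M * K = 2 * 32000 * 4096
 *          = 262,144,000 bytes  (~250 MiB of prealloc_x)
 *
 * One descriptor set is reserved for the dequant pipeline above and one for the
 * embed matmul pipeline below.
 */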
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_f16_f32_embed, 1); + return; + } + + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + + const std::vector pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dequant_pipeline, { + vk_subbuffer{ src0_buf_ctx->dev_buffer, vk_tensor_offset(src0) + src0->view_offs, VK_WHOLE_SIZE }, + vk_subbuffer{ ctx->prealloc_x, 0, VK_WHOLE_SIZE } + }, pc, { (uint32_t)(ggml_nelements(src0)), 1, 1}); + + d_X = ctx->prealloc_x; + x_buf_offset = 0; + stride_a = K; + pipeline = ctx->device->pipeline_matmul_f16_f32_embed; + } + } else { + if (src0->type == GGML_TYPE_F16) { + pipeline = ctx->device->pipeline_matmul_f16_f32_embed; + } else { + pipeline = ctx->device->pipeline_matmul_f32_f32_embed; + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + //GGML_LOG_INFO("%s: using Mali-optimized shader for %s x %s, %ux%u matrix (K=%u)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), M, N, K); + + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + stride_a = src0->nb[1] / ggml_type_size(src0->type); + } + + if (pipeline != nullptr) { + ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant (if any) is finished + + const uint32_t stride_b = src1->nb[1] / ggml_type_size(src1->type); + const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); + + const vk_mat_mat_embed_push_constants pc = { M, N, K, stride_a, stride_b, stride_d }; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + vk_buffer d_Y = src1_buf_ctx->dev_buffer; + const uint64_t y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_buf_offset, VK_WHOLE_SIZE }, + vk_subbuffer{ d_Y, y_buf_offset, VK_WHOLE_SIZE }, + vk_subbuffer{ d_D, d_buf_offset, VK_WHOLE_SIZE }, + }, pc, { M, N, 1 }); + + return; + } + } + const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp new file mode 100644 index 0000000000000..133e12b604a45 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp @@ -0,0 +1,160 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + +#ifdef A_TYPE_FP16 + #define A_VEC4_TYPE f16vec4 + #define A_SCALAR_TYPE float16_t + #define A_VEC4_ZERO f16vec4(0.0hf) + #define A_VEC4_CAST(v) vec4(v) +#else + #define A_VEC4_TYPE vec4 + #define A_SCALAR_TYPE float + #define A_VEC4_ZERO vec4(0.0f) + #define A_VEC4_CAST(v) (v) +#endif + +layout(local_size_x = 16, local_size_y = 8, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A_BUFFER { A_SCALAR_TYPE data_a[]; }; +layout (binding = 1) readonly buffer B_BUFFER { float data_b[]; }; +layout (binding = 2) writeonly buffer D_BUFFER { float data_d[]; }; + +layout 
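/*
 * Push-constant block used by this shader (values are filled in by
 * ggml_vk_mul_mat_q_f16 on the host): M is the number of rows of A (src0),
 * N the number of rows of B (src1), K the shared inner dimension, and the
 * strides are row strides in elements, e.g. stride_a = src0->nb[1] / type size.
 * For an illustrative contiguous F16 A with K = 4096, nb[1] = 8192 bytes, so
 * stride_a = 8192 / 2 = 4096. The result element (m, n) is written to
 * data_d[n * stride_d + m].
 */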
(push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; +} p; + +const uint BM = 32; +const uint BN = 32; +const uint BK = 32; + +const uint TM = 4; +const uint TN = 2; + +const uint WG_X = 16; +const uint WG_Y = 8; +const uint WG_SIZE = WG_X * WG_Y; + +const uint VEC_K = BK / 4; + +shared A_VEC4_TYPE buf_a[BM][VEC_K]; +shared vec4 buf_b[BN][VEC_K]; + +void main() { + const uint lidx = gl_LocalInvocationID.x; + const uint lidy = gl_LocalInvocationID.y; + const uint lid = lidy * WG_X + lidx; + + const uint group_m = gl_WorkGroupID.x * BM; + const uint group_n = gl_WorkGroupID.y * BN; + + float sums[TM][TN]; + #pragma unroll + for (uint i = 0; i < TM; i++) { + #pragma unroll + for (uint j = 0; j < TN; j++) { + sums[i][j] = 0.0f; + } + } + + const uint num_k_tiles = (p.K + BK - 1) / BK; + + for (uint t = 0; t < num_k_tiles; t++) { + const uint k_tile_start = t * BK; + + #pragma unroll + for(uint i = 0; i < 2; ++i) { + uint load_idx = lid + i * WG_SIZE; + uint m = load_idx / VEC_K; + uint k = load_idx % VEC_K; + uint global_m = group_m + m; + uint k_scalar = k_tile_start + k * 4; + if (global_m < p.M && k_scalar + 3 < p.K) { + uint base_idx = global_m * p.stride_a + k_scalar; + buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]); + } else { + buf_a[m][k] = A_VEC4_ZERO; + } + } + + #pragma unroll + for(uint i = 0; i < 2; ++i) { + uint load_idx = lid + i * WG_SIZE; + uint n = load_idx / VEC_K; + uint k = load_idx % VEC_K; + uint global_n = group_n + n; + uint k_scalar = k_tile_start + k * 4; + if (global_n < p.N && k_scalar + 3 < p.K) { + uint base_idx = global_n * p.stride_b + k_scalar; + buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]); + } else { + buf_b[n][k] = vec4(0.0f); + } + } + + barrier(); + + uint m_base = lidy * TM; + uint n_base = lidx * TN; + + A_VEC4_TYPE a_reg[TM]; + vec4 b_reg[TN]; + A_VEC4_TYPE a_reg_next[TM]; + vec4 b_reg_next[TN]; + + #pragma unroll + for (uint i = 0; i < TM; i++) { a_reg[i] = buf_a[m_base + i][0]; } + #pragma unroll + for (uint j = 0; j < TN; j++) { b_reg[j] = buf_b[n_base + j][0]; } + + for (uint k = 1; k < VEC_K; k++) { + #pragma unroll + for (uint i = 0; i < TM; i++) { a_reg_next[i] = buf_a[m_base + i][k]; } + #pragma unroll + for (uint j = 0; j < TN; j++) { b_reg_next[j] = buf_b[n_base + j][k]; } + + #pragma unroll + for (uint i = 0; i < TM; i++) { + #pragma unroll + for (uint j = 0; j < TN; j++) { + sums[i][j] += dot(A_VEC4_CAST(a_reg[i]), b_reg[j]); + } + } + + a_reg = a_reg_next; + b_reg = b_reg_next; + } + + #pragma unroll + for (uint i = 0; i < TM; i++) { + #pragma unroll + for (uint j = 0; j < TN; j++) { + sums[i][j] += dot(A_VEC4_CAST(a_reg[i]), b_reg[j]); + } + } + + barrier(); + } + + uint m_base = lidy * TM; + uint n_base = lidx * TN; + + #pragma unroll + for (uint i = 0; i < TM; i++) { + uint global_m = group_m + m_base + i; + #pragma unroll + for (uint j = 0; j < TN; j++) { + uint global_n = group_n + n_base + j; + if (global_m < p.M && global_n < p.N) { + data_d[global_n * p.stride_d + global_m] = sums[i][j]; + } + } + } +} \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 613498d0d50b7..997a904f499de 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -454,6 +454,9 @@ void 
process_shaders() { } } + string_to_spv("mul_mat_f16_f32_embed", "mul_mm_embed.comp", {{"A_TYPE_FP16", "1"}}); + string_to_spv("mul_mat_f32_f32_embed", "mul_mm_embed.comp", {}); + // flash attention for (const auto& f16acc : {false, true}) { std::map fa_base_dict = base_dict; From e103de2a09df2bee7bb9259dbfc95f4f1a878ce0 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 4 Sep 2025 19:04:29 +0000 Subject: [PATCH 2/4] add line ending --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp index 133e12b604a45..17745c5de8c6a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp @@ -157,4 +157,4 @@ void main() { } } } -} \ No newline at end of file +} From d520e63a6744246471aa9a056c81515e9ac5ac60 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 5 Sep 2025 21:23:34 +0000 Subject: [PATCH 3/4] refactor and opt mulmat shaders and adress review comments --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 85 +++++++---- .../vulkan-shaders/mul_mm_embed.comp | 139 +++++++++--------- 2 files changed, 128 insertions(+), 96 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a3a0d8a5a15a2..0abefc67c54e6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -705,15 +705,6 @@ struct vk_mat_vec_id_push_constants { uint32_t nei0; uint32_t ne11; }; -struct vk_mat_mat_embed_push_constants { - uint32_t M; - uint32_t N; - uint32_t K; - uint32_t stride_a; - uint32_t stride_b; - uint32_t stride_d; -}; - struct vk_flash_attn_push_constants { uint32_t N; uint32_t KV; @@ -2915,12 +2906,28 @@ static void ggml_vk_load_shaders(vk_device& device) { } if (device->vendor_id == VK_VENDOR_ID_ARM) { + // Shader workgroup size is 16x8 = 128 + const uint32_t wg_x = 16; + const uint32_t wg_y = 8; + + // Tile sizes for the workgroup + const uint32_t bm = 64; + const uint32_t bn = 64; + const uint32_t bk = 16; + + // Threads per tile + const uint32_t tm = bm / wg_y; + const uint32_t tn = bn / wg_x; + + const std::vector embed_spec_constants = {bm, bn, bk, tm, tn}; + const std::array embed_wg_denoms = {bm, bn, 1}; + ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32_embed, "mul_mat_f16_f32_embed", mul_mat_f16_f32_embed_len, mul_mat_f16_f32_embed_data, "main", 3, - sizeof(vk_mat_mat_embed_push_constants), {64, 64, 1}, {}, 1); + sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f32_embed, "mul_mat_f32_f32_embed", mul_mat_f32_f32_embed_len, mul_mat_f32_f32_embed_data, "main", 3, - sizeof(vk_mat_mat_embed_push_constants), {64, 64, 1}, {}, 1); + sizeof(vk_mat_mat_push_constants), embed_wg_denoms, embed_spec_constants, 1); } // reusing CREATE_MM from the fp32 path @@ -5750,10 +5757,30 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (ctx->device->vendor_id == VK_VENDOR_ID_ARM && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && - src1->type == GGML_TYPE_F32 && + src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) && ne02 == 1 && ne03 == 1 && ne12 == 1 && ne13 == 1) { + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * 
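/*
 * Note on the buffer resolution below: on unified-memory (UMA) devices,
 * ggml_vk_host_get() is used to look up whether a tensor's data pointer already
 * resides in a host-visible Vulkan allocation; when it does, that buffer and
 * offset are used directly. When the lookup returns null (or the device is not
 * UMA), the code falls back to the backend buffer context's dev_buffer with
 * vk_tensor_offset(tensor) + view_offs, as elsewhere in ggml-vulkan.cpp.
 */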
src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + vk_buffer d_Qx = nullptr, d_Qy = nullptr, d_D = nullptr; + size_t qx_buf_offset = 0, qy_buf_offset = 0, d_buf_offset = 0; + bool src0_uma = false, src1_uma = false, dst_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + dst_uma = d_D != nullptr; + } + + if (!src0_uma) { d_Qx = src0_buf_ctx->dev_buffer; qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; } + if (!src1_uma) { d_Qy = src1_buf_ctx->dev_buffer; qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; } + if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; } + const uint32_t M = ne01; const uint32_t N = ne11; const uint32_t K = ne10; @@ -5762,11 +5789,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_buffer d_X; uint64_t x_buf_offset; uint32_t stride_a; + bool dequantized = false; if (ggml_is_quantized(src0->type)) { vk_pipeline dequant_pipeline = ggml_vk_get_to_fp16(ctx, src0->type); if (dequant_pipeline) { + dequantized = true; const uint64_t x_sz = sizeof(ggml_fp16_t) * M * K; if (dryrun) { @@ -5778,12 +5807,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub return; } - ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; - const std::vector pc = { (uint32_t)M, (uint32_t)K, (uint32_t)K, (uint32_t)K, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, dequant_pipeline, { - vk_subbuffer{ src0_buf_ctx->dev_buffer, vk_tensor_offset(src0) + src0->view_offs, VK_WHOLE_SIZE }, + vk_subbuffer{ d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, vk_subbuffer{ ctx->prealloc_x, 0, VK_WHOLE_SIZE } }, pc, { (uint32_t)(ggml_nelements(src0)), 1, 1}); @@ -5804,29 +5831,23 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub return; } - //GGML_LOG_INFO("%s: using Mali-optimized shader for %s x %s, %ux%u matrix (K=%u)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), M, N, K); - - ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; - d_X = src0_buf_ctx->dev_buffer; - x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + d_X = d_Qx; + x_buf_offset = qx_buf_offset; stride_a = src0->nb[1] / ggml_type_size(src0->type); } if (pipeline != nullptr) { - ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant (if any) is finished + if (dequantized) { + ggml_vk_sync_buffers(ctx, subctx); // Ensure dequant is finished + } const uint32_t stride_b = src1->nb[1] / ggml_type_size(src1->type); const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); - const vk_mat_mat_embed_push_constants pc = { M, N, K, stride_a, stride_b, stride_d }; - - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + const vk_mat_mat_push_constants pc = { M, N, K, stride_a, stride_b, stride_d, M * K, K * N, M * N, K, 1, 1, 1, 1, N }; - vk_buffer d_D = 
dst_buf_ctx->dev_buffer; - const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; - vk_buffer d_Y = src1_buf_ctx->dev_buffer; - const uint64_t y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + vk_buffer d_Y = d_Qy; + const uint64_t y_buf_offset = qy_buf_offset; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, VK_WHOLE_SIZE }, @@ -5834,6 +5855,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_subbuffer{ d_D, d_buf_offset, VK_WHOLE_SIZE }, }, pc, { M, N, 1 }); + if (dequantized) { + ctx->prealloc_x_need_sync = true; + } + return; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp index 17745c5de8c6a..c67ad5a45bb3a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_embed.comp @@ -2,6 +2,7 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_control_flow_attributes : require #ifdef A_TYPE_FP16 #define A_VEC4_TYPE f16vec4 @@ -15,6 +16,17 @@ #define A_VEC4_CAST(v) (v) #endif +layout(constant_id = 0) const uint BM = 64; +layout(constant_id = 1) const uint BN = 64; +layout(constant_id = 2) const uint BK = 16; +layout(constant_id = 3) const uint TM = 4; +layout(constant_id = 4) const uint TN = 8; + +const uint WG_X = BN / TN; +const uint WG_Y = BM / TM; +const uint WG_SIZE = WG_X * WG_Y; +const uint VEC_K = BK / 4; + layout(local_size_x = 16, local_size_y = 8, local_size_z = 1) in; layout (binding = 0) readonly buffer A_BUFFER { A_SCALAR_TYPE data_a[]; }; @@ -31,19 +43,6 @@ layout (push_constant) uniform parameter uint stride_d; } p; -const uint BM = 32; -const uint BN = 32; -const uint BK = 32; - -const uint TM = 4; -const uint TN = 2; - -const uint WG_X = 16; -const uint WG_Y = 8; -const uint WG_SIZE = WG_X * WG_Y; - -const uint VEC_K = BK / 4; - shared A_VEC4_TYPE buf_a[BM][VEC_K]; shared vec4 buf_b[BN][VEC_K]; @@ -56,44 +55,64 @@ void main() { const uint group_n = gl_WorkGroupID.y * BN; float sums[TM][TN]; - #pragma unroll + [[unroll]] for (uint i = 0; i < TM; i++) { - #pragma unroll + [[unroll]] for (uint j = 0; j < TN; j++) { sums[i][j] = 0.0f; } } const uint num_k_tiles = (p.K + BK - 1) / BK; + const uint A_LOADS_PER_THREAD = (BM * VEC_K) / WG_SIZE; + const uint B_LOADS_PER_THREAD = (BN * VEC_K) / WG_SIZE; for (uint t = 0; t < num_k_tiles; t++) { const uint k_tile_start = t * BK; - #pragma unroll - for(uint i = 0; i < 2; ++i) { + [[unroll]] + for(uint i = 0; i < A_LOADS_PER_THREAD; ++i) { uint load_idx = lid + i * WG_SIZE; uint m = load_idx / VEC_K; uint k = load_idx % VEC_K; uint global_m = group_m + m; uint k_scalar = k_tile_start + k * 4; - if (global_m < p.M && k_scalar + 3 < p.K) { + + if (global_m < p.M && k_scalar < p.K) { uint base_idx = global_m * p.stride_a + k_scalar; - buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]); + if (k_scalar + 3 < p.K) { + buf_a[m][k] = A_VEC4_TYPE(data_a[base_idx], data_a[base_idx+1], data_a[base_idx+2], data_a[base_idx+3]); + } else { + A_SCALAR_TYPE temp[4] = {A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0), A_SCALAR_TYPE(0)}; + if (k_scalar < p.K) temp[0] = data_a[base_idx]; + if (k_scalar + 1 < p.K) temp[1] = data_a[base_idx+1]; + if (k_scalar + 2 < p.K) temp[2] = data_a[base_idx+2]; + buf_a[m][k] = A_VEC4_TYPE(temp[0], temp[1], temp[2], temp[3]); + } } else { 
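/*
 * The zero fill below covers tile slots whose row is outside M or whose 4-wide
 * K chunk starts at or beyond K; chunks that merely straddle the end of K were
 * handled by the scalar tail loads above. Illustrative numbers: with K = 4097
 * and BK = 16, the last tile spans k = 4096..4111; the chunk at k_scalar = 4096
 * takes the tail path (only temp[0] is in range), while the chunks at 4100,
 * 4104 and 4108 fall through to this zero fill.
 */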
buf_a[m][k] = A_VEC4_ZERO; } } - #pragma unroll - for(uint i = 0; i < 2; ++i) { + [[unroll]] + for(uint i = 0; i < B_LOADS_PER_THREAD; ++i) { uint load_idx = lid + i * WG_SIZE; uint n = load_idx / VEC_K; uint k = load_idx % VEC_K; uint global_n = group_n + n; uint k_scalar = k_tile_start + k * 4; - if (global_n < p.N && k_scalar + 3 < p.K) { + + if (global_n < p.N && k_scalar < p.K) { uint base_idx = global_n * p.stride_b + k_scalar; - buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]); + if (k_scalar + 3 < p.K) { + buf_b[n][k] = vec4(data_b[base_idx], data_b[base_idx+1], data_b[base_idx+2], data_b[base_idx+3]); + } else { + float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + if (k_scalar < p.K) temp[0] = data_b[base_idx]; + if (k_scalar + 1 < p.K) temp[1] = data_b[base_idx+1]; + if (k_scalar + 2 < p.K) temp[2] = data_b[base_idx+2]; + buf_b[n][k] = vec4(temp[0], temp[1], temp[2], temp[3]); + } } else { buf_b[n][k] = vec4(0.0f); } @@ -101,59 +120,47 @@ void main() { barrier(); - uint m_base = lidy * TM; - uint n_base = lidx * TN; - - A_VEC4_TYPE a_reg[TM]; - vec4 b_reg[TN]; - A_VEC4_TYPE a_reg_next[TM]; - vec4 b_reg_next[TN]; - - #pragma unroll - for (uint i = 0; i < TM; i++) { a_reg[i] = buf_a[m_base + i][0]; } - #pragma unroll - for (uint j = 0; j < TN; j++) { b_reg[j] = buf_b[n_base + j][0]; } - - for (uint k = 1; k < VEC_K; k++) { - #pragma unroll - for (uint i = 0; i < TM; i++) { a_reg_next[i] = buf_a[m_base + i][k]; } - #pragma unroll - for (uint j = 0; j < TN; j++) { b_reg_next[j] = buf_b[n_base + j][k]; } - - #pragma unroll + [[unroll]] + for (uint k = 0; k < VEC_K; k++) { + A_VEC4_TYPE a_reg[TM]; + [[unroll]] for (uint i = 0; i < TM; i++) { - #pragma unroll - for (uint j = 0; j < TN; j++) { - sums[i][j] += dot(A_VEC4_CAST(a_reg[i]), b_reg[j]); - } + a_reg[i] = buf_a[lidy + i * WG_Y][k]; } - a_reg = a_reg_next; - b_reg = b_reg_next; - } - - #pragma unroll - for (uint i = 0; i < TM; i++) { - #pragma unroll + vec4 b_reg[TN]; + [[unroll]] for (uint j = 0; j < TN; j++) { - sums[i][j] += dot(A_VEC4_CAST(a_reg[i]), b_reg[j]); + b_reg[j] = buf_b[lidx + j * WG_X][k]; } - } + [[unroll]] + for (uint i = 0; i < TM; i++) { + vec4 a_f32 = A_VEC4_CAST(a_reg[i]); + + sums[i][0] += a_f32.x * b_reg[0].x + a_f32.y * b_reg[0].y + a_f32.z * b_reg[0].z + a_f32.w * b_reg[0].w; + sums[i][1] += a_f32.x * b_reg[1].x + a_f32.y * b_reg[1].y + a_f32.z * b_reg[1].z + a_f32.w * b_reg[1].w; + sums[i][2] += a_f32.x * b_reg[2].x + a_f32.y * b_reg[2].y + a_f32.z * b_reg[2].z + a_f32.w * b_reg[2].w; + sums[i][3] += a_f32.x * b_reg[3].x + a_f32.y * b_reg[3].y + a_f32.z * b_reg[3].z + a_f32.w * b_reg[3].w; + sums[i][4] += a_f32.x * b_reg[4].x + a_f32.y * b_reg[4].y + a_f32.z * b_reg[4].z + a_f32.w * b_reg[4].w; + sums[i][5] += a_f32.x * b_reg[5].x + a_f32.y * b_reg[5].y + a_f32.z * b_reg[5].z + a_f32.w * b_reg[5].w; + sums[i][6] += a_f32.x * b_reg[6].x + a_f32.y * b_reg[6].y + a_f32.z * b_reg[6].z + a_f32.w * b_reg[6].w; + sums[i][7] += a_f32.x * b_reg[7].x + a_f32.y * b_reg[7].y + a_f32.z * b_reg[7].z + a_f32.w * b_reg[7].w; + } + } barrier(); } - uint m_base = lidy * TM; - uint n_base = lidx * TN; - - #pragma unroll + [[unroll]] for (uint i = 0; i < TM; i++) { - uint global_m = group_m + m_base + i; - #pragma unroll - for (uint j = 0; j < TN; j++) { - uint global_n = group_n + n_base + j; - if (global_m < p.M && global_n < p.N) { - data_d[global_n * p.stride_d + global_m] = sums[i][j]; + uint global_m = group_m + lidy + i * WG_Y; + if (global_m < p.M) { + [[unroll]] + for (uint j 
= 0; j < TN; j++) { + uint global_n = group_n + lidx + j * WG_X; + if (global_n < p.N) { + data_d[global_n * p.stride_d + global_m] = sums[i][j]; + } } } } From 850b9bf40f51dfc86fc3d7ea1c2a17dafd1fda65 Mon Sep 17 00:00:00 2001 From: rmatif Date: Sat, 6 Sep 2025 09:59:40 +0000 Subject: [PATCH 4/4] add qcom support --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0abefc67c54e6..f92686566efb0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -78,6 +78,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de #define VK_VENDOR_ID_ARM 0x13B5 +#define VK_VENDOR_ID_QUALCOMM 0x5143 + #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 @@ -2905,15 +2907,23 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - if (device->vendor_id == VK_VENDOR_ID_ARM) { + if (device->vendor_id == VK_VENDOR_ID_ARM || device->vendor_id == VK_VENDOR_ID_QUALCOMM) { // Shader workgroup size is 16x8 = 128 const uint32_t wg_x = 16; const uint32_t wg_y = 8; // Tile sizes for the workgroup - const uint32_t bm = 64; - const uint32_t bn = 64; - const uint32_t bk = 16; + uint32_t bm, bn, bk; + + if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) { + bm = 32; + bn = 128; + bk = 8; + } else { + bm = 64; + bn = 64; + bk = 16; + } // Threads per tile const uint32_t tm = bm / wg_y; @@ -5755,7 +5765,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; - if (ctx->device->vendor_id == VK_VENDOR_ID_ARM && + if ((ctx->device->vendor_id == VK_VENDOR_ID_ARM || ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && ggml_vk_dim01_contiguous(src1) && ne02 == 1 && ne03 == 1 &&
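/*
 * Worked numbers for the tile configuration above (illustration, not part of
 * the patch): the shader's workgroup is fixed at 16 x 8 = 128 threads. For Arm
 * the host derives tm = bm / wg_y = 64 / 8 = 8 and tn = bn / wg_x = 64 / 16 = 4;
 * for Qualcomm it derives tm = 32 / 8 = 4 and tn = 128 / 16 = 8. Assuming the
 * F16 variant of the shader, the shared-memory tiles occupy bm*bk*2 + bn*bk*4
 * bytes per workgroup: 64*16*2 + 64*16*4 = 6144 bytes for Arm and
 * 32*8*2 + 128*8*4 = 4608 bytes for Qualcomm.
 */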