diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 21bd052255564..133f5ca8e97d1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5826,10 +5826,10 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((ctx->device->mul_mat_s[src0_type] && (m * n <= 32 * 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((ctx->device->mul_mat_m[src0_type] && (m * n <= 64 * 64)) || !ctx->device->mul_mat_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; @@ -5892,10 +5892,10 @@ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ct return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((ctx->device->mul_mat_id_s[src0_type] && (m * n <= 32 * 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((ctx->device->mul_mat_id_m[src0_type] && (m * n <= 64 * 64)) || !ctx->device->mul_mat_id_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l;