@@ -396,7 +396,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
396396 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
397397 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
398398 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
399- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
399+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
400+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
401+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
402+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
403+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
404+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
400405 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
401406 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
402407 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1425,7 +1430,12 @@ @implementation GGMLMetalClass
14251430 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14261431 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14271432 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1428- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1433+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1434+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1435+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1436+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1437+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1438+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14291439 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14301440 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14311441 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3935,7 +3945,17 @@ static int ggml_metal_encode_node(
39353945
39363946 id <MTLComputePipelineState > pipeline = nil ;
39373947
3938- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline ;
3948+ pipeline = nil ;
3949+
3950+ switch (ne20) {
3951+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline ; break ;
3952+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline ; break ;
3953+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline ; break ;
3954+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline ; break ;
3955+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline ; break ;
3956+ case 16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline ; break ;
3957+ default : GGML_ABORT (" missing specialization for ne20 = %d " , (int ) ne20);
3958+ }
39393959
39403960 GGML_ASSERT (ne02 <= (int ) pipeline.maxTotalThreadsPerThreadgroup );
39413961
0 commit comments