@@ -398,7 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
398398 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
399399 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
400400 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
401- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
401+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
402+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
403+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
404+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
405+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
406+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
402407 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
403408 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
404409 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1427,7 +1432,12 @@ @implementation GGMLMetalClass
14271432 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14281433 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14291434 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1430- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1435+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1436+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1437+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1438+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1439+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1440+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14311441 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14321442 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14331443 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3937,7 +3947,17 @@ static int ggml_metal_encode_node(
39373947
39383948 id <MTLComputePipelineState > pipeline = nil ;
39393949
3940- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline ;
3950+ pipeline = nil ;
3951+
3952+ switch (ne20) {
3953+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline ; break ;
3954+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline ; break ;
3955+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline ; break ;
3956+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline ; break ;
3957+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline ; break ;
3958+ case 16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline ; break ;
3959+ default : GGML_ABORT (" missing specialization for ne20 = %d " , (int ) ne20);
3960+ }
39413961
39423962 GGML_ASSERT (ne02 <= (int ) pipeline.maxTotalThreadsPerThreadgroup );
39433963
0 commit comments