@@ -994,7 +994,7 @@ void ggml_metal_graph_compute(
994994 GGML_ASSERT (ne03 == ne13);
995995
996996 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
997- // to the matrix-vector kernel. the numbers below are measure on M2 Ultra
997+ // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
998998 // not sure if this translates across all chips
999999 int ne11_mm_min = 1 ;
10001000
@@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute(
10151015
10161016 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
10171017 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1018- if (!ggml_is_transposed (src0) &&
1018+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1019+ !ggml_is_transposed (src0) &&
10191020 !ggml_is_transposed (src1) &&
10201021 src1t == GGML_TYPE_F32 &&
1021- [ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1022- ne00%32 == 0 &&
1022+ ne00 % 32 == 0 &&
10231023 ne11 > ne11_mm_min) {
1024+ // printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10241025 switch (src0->type ) {
10251026 case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f32_f32]; break ;
10261027 case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f16_f32]; break ;
@@ -1049,11 +1050,12 @@ void ggml_metal_graph_compute(
10491050 [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 12 ];
10501051 [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 13 ];
10511052 [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
1052- [encoder dispatchThreadgroups: MTLSizeMake ( (ne11+ 31 )/32 , (ne01+ 63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1053+ [encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/ 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
10531054 } else {
10541055 int nth0 = 32 ;
10551056 int nth1 = 1 ;
10561057 int nrows = 1 ;
1058+ // printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10571059
10581060 // use custom matrix x vector kernel
10591061 switch (src0t) {
@@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute(
11751177 [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
11761178
11771179 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1178- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1180+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
11791181 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
11801182 }
11811183 else if (src0t == GGML_TYPE_Q4_K) {
0 commit comments