8181 GGML_METAL_DECL_KERNEL (get_rows_q6_K);
8282 GGML_METAL_DECL_KERNEL (rms_norm);
8383 GGML_METAL_DECL_KERNEL (norm);
84- GGML_METAL_DECL_KERNEL (mul_mat_f32_f32 );
85- GGML_METAL_DECL_KERNEL (mul_mat_f16_f32 );
86- GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_1row );
87- GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_l4 );
88- GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32 );
89- GGML_METAL_DECL_KERNEL (mul_mat_q4_1_f32 );
90- GGML_METAL_DECL_KERNEL (mul_mat_q8_0_f32 );
91- GGML_METAL_DECL_KERNEL (mul_mat_q2_K_f32 );
92- GGML_METAL_DECL_KERNEL (mul_mat_q3_K_f32 );
93- GGML_METAL_DECL_KERNEL (mul_mat_q4_K_f32 );
94- GGML_METAL_DECL_KERNEL (mul_mat_q5_K_f32 );
95- GGML_METAL_DECL_KERNEL (mul_mat_q6_K_f32 );
84+ GGML_METAL_DECL_KERNEL (mul_mv_f32_f32 );
85+ GGML_METAL_DECL_KERNEL (mul_mv_f16_f32 );
86+ GGML_METAL_DECL_KERNEL (mul_mv_f16_f32_1row );
87+ GGML_METAL_DECL_KERNEL (mul_mv_f16_f32_l4 );
88+ GGML_METAL_DECL_KERNEL (mul_mv_q4_0_f32 );
89+ GGML_METAL_DECL_KERNEL (mul_mv_q4_1_f32 );
90+ GGML_METAL_DECL_KERNEL (mul_mv_q8_0_f32 );
91+ GGML_METAL_DECL_KERNEL (mul_mv_q2_K_f32 );
92+ GGML_METAL_DECL_KERNEL (mul_mv_q3_K_f32 );
93+ GGML_METAL_DECL_KERNEL (mul_mv_q4_K_f32 );
94+ GGML_METAL_DECL_KERNEL (mul_mv_q5_K_f32 );
95+ GGML_METAL_DECL_KERNEL (mul_mv_q6_K_f32 );
9696 GGML_METAL_DECL_KERNEL (mul_mm_f32_f32);
9797 GGML_METAL_DECL_KERNEL (mul_mm_f16_f32);
9898 GGML_METAL_DECL_KERNEL (mul_mm_q4_0_f32);
@@ -262,28 +262,30 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
262262 GGML_METAL_ADD_KERNEL (get_rows_q6_K);
263263 GGML_METAL_ADD_KERNEL (rms_norm);
264264 GGML_METAL_ADD_KERNEL (norm);
265- GGML_METAL_ADD_KERNEL (mul_mat_f32_f32);
266- GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
267- GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_1row);
268- GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_l4);
269- GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
270- GGML_METAL_ADD_KERNEL (mul_mat_q4_1_f32);
271- GGML_METAL_ADD_KERNEL (mul_mat_q8_0_f32);
272- GGML_METAL_ADD_KERNEL (mul_mat_q2_K_f32);
273- GGML_METAL_ADD_KERNEL (mul_mat_q3_K_f32);
274- GGML_METAL_ADD_KERNEL (mul_mat_q4_K_f32);
275- GGML_METAL_ADD_KERNEL (mul_mat_q5_K_f32);
276- GGML_METAL_ADD_KERNEL (mul_mat_q6_K_f32);
277- GGML_METAL_ADD_KERNEL (mul_mm_f32_f32);
278- GGML_METAL_ADD_KERNEL (mul_mm_f16_f32);
279- GGML_METAL_ADD_KERNEL (mul_mm_q4_0_f32);
280- GGML_METAL_ADD_KERNEL (mul_mm_q8_0_f32);
281- GGML_METAL_ADD_KERNEL (mul_mm_q4_1_f32);
282- GGML_METAL_ADD_KERNEL (mul_mm_q2_K_f32);
283- GGML_METAL_ADD_KERNEL (mul_mm_q3_K_f32);
284- GGML_METAL_ADD_KERNEL (mul_mm_q4_K_f32);
285- GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
286- GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
265+ GGML_METAL_ADD_KERNEL (mul_mv_f32_f32);
266+ GGML_METAL_ADD_KERNEL (mul_mv_f16_f32);
267+ GGML_METAL_ADD_KERNEL (mul_mv_f16_f32_1row);
268+ GGML_METAL_ADD_KERNEL (mul_mv_f16_f32_l4);
269+ GGML_METAL_ADD_KERNEL (mul_mv_q4_0_f32);
270+ GGML_METAL_ADD_KERNEL (mul_mv_q4_1_f32);
271+ GGML_METAL_ADD_KERNEL (mul_mv_q8_0_f32);
272+ GGML_METAL_ADD_KERNEL (mul_mv_q2_K_f32);
273+ GGML_METAL_ADD_KERNEL (mul_mv_q3_K_f32);
274+ GGML_METAL_ADD_KERNEL (mul_mv_q4_K_f32);
275+ GGML_METAL_ADD_KERNEL (mul_mv_q5_K_f32);
276+ GGML_METAL_ADD_KERNEL (mul_mv_q6_K_f32);
277+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
278+ GGML_METAL_ADD_KERNEL (mul_mm_f32_f32);
279+ GGML_METAL_ADD_KERNEL (mul_mm_f16_f32);
280+ GGML_METAL_ADD_KERNEL (mul_mm_q4_0_f32);
281+ GGML_METAL_ADD_KERNEL (mul_mm_q8_0_f32);
282+ GGML_METAL_ADD_KERNEL (mul_mm_q4_1_f32);
283+ GGML_METAL_ADD_KERNEL (mul_mm_q2_K_f32);
284+ GGML_METAL_ADD_KERNEL (mul_mm_q3_K_f32);
285+ GGML_METAL_ADD_KERNEL (mul_mm_q4_K_f32);
286+ GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
287+ GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
288+ }
287289 GGML_METAL_ADD_KERNEL (rope_f32);
288290 GGML_METAL_ADD_KERNEL (rope_f16);
289291 GGML_METAL_ADD_KERNEL (alibi_f32);
@@ -296,8 +298,22 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
296298#undef GGML_METAL_ADD_KERNEL
297299 }
298300
299- GGML_METAL_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx->device .hasUnifiedMemory ? " true" : " false" );
300301#if TARGET_OS_OSX
302+ // print MTL GPU family:
303+ GGML_METAL_LOG_INFO (" %s : GPU name: %s \n " , __func__, [[ctx->device name ] UTF8String ]);
304+ GGML_METAL_LOG_INFO (" %s : GPU arch: %s \n " , __func__, [[ctx->device architecture ].name UTF8String ]);
305+
306+ // determine max supported GPU family
307+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
308+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
309+ for (int i = MTLGPUFamilyApple9 + 10 ; i >= MTLGPUFamilyApple1 ; --i) {
310+ if ([ctx->device supportsFamily: i]) {
311+ GGML_METAL_LOG_INFO (" %s : GPU family: MTLGPUFamilyApple%d (%d )\n " , __func__, i - MTLGPUFamilyApple1 + 1 , i);
312+ break ;
313+ }
314+ }
315+
316+ GGML_METAL_LOG_INFO (" %s : hasUnifiedMemory = %s \n " , __func__, ctx->device .hasUnifiedMemory ? " true" : " false" );
301317 GGML_METAL_LOG_INFO (" %s : recommendedMaxWorkingSetSize = %8.2f MB\n " , __func__, ctx->device .recommendedMaxWorkingSetSize / 1024.0 / 1024.0 );
302318 if (ctx->device .maxTransferRate != 0 ) {
303319 GGML_METAL_LOG_INFO (" %s : maxTransferRate = %8.2f MB/s\n " , __func__, ctx->device .maxTransferRate / 1024.0 / 1024.0 );
@@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
339355 GGML_METAL_DEL_KERNEL (get_rows_q6_K);
340356 GGML_METAL_DEL_KERNEL (rms_norm);
341357 GGML_METAL_DEL_KERNEL (norm);
342- GGML_METAL_DEL_KERNEL (mul_mat_f32_f32);
343- GGML_METAL_DEL_KERNEL (mul_mat_f16_f32);
344- GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_1row);
345- GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_l4);
346- GGML_METAL_DEL_KERNEL (mul_mat_q4_0_f32);
347- GGML_METAL_DEL_KERNEL (mul_mat_q4_1_f32);
348- GGML_METAL_DEL_KERNEL (mul_mat_q8_0_f32);
349- GGML_METAL_DEL_KERNEL (mul_mat_q2_K_f32);
350- GGML_METAL_DEL_KERNEL (mul_mat_q3_K_f32);
351- GGML_METAL_DEL_KERNEL (mul_mat_q4_K_f32);
352- GGML_METAL_DEL_KERNEL (mul_mat_q5_K_f32);
353- GGML_METAL_DEL_KERNEL (mul_mat_q6_K_f32);
354- GGML_METAL_DEL_KERNEL (mul_mm_f32_f32);
355- GGML_METAL_DEL_KERNEL (mul_mm_f16_f32);
356- GGML_METAL_DEL_KERNEL (mul_mm_q4_0_f32);
357- GGML_METAL_DEL_KERNEL (mul_mm_q8_0_f32);
358- GGML_METAL_DEL_KERNEL (mul_mm_q4_1_f32);
359- GGML_METAL_DEL_KERNEL (mul_mm_q2_K_f32);
360- GGML_METAL_DEL_KERNEL (mul_mm_q3_K_f32);
361- GGML_METAL_DEL_KERNEL (mul_mm_q4_K_f32);
362- GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
363- GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
358+ GGML_METAL_DEL_KERNEL (mul_mv_f32_f32);
359+ GGML_METAL_DEL_KERNEL (mul_mv_f16_f32);
360+ GGML_METAL_DEL_KERNEL (mul_mv_f16_f32_1row);
361+ GGML_METAL_DEL_KERNEL (mul_mv_f16_f32_l4);
362+ GGML_METAL_DEL_KERNEL (mul_mv_q4_0_f32);
363+ GGML_METAL_DEL_KERNEL (mul_mv_q4_1_f32);
364+ GGML_METAL_DEL_KERNEL (mul_mv_q8_0_f32);
365+ GGML_METAL_DEL_KERNEL (mul_mv_q2_K_f32);
366+ GGML_METAL_DEL_KERNEL (mul_mv_q3_K_f32);
367+ GGML_METAL_DEL_KERNEL (mul_mv_q4_K_f32);
368+ GGML_METAL_DEL_KERNEL (mul_mv_q5_K_f32);
369+ GGML_METAL_DEL_KERNEL (mul_mv_q6_K_f32);
370+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
371+ GGML_METAL_DEL_KERNEL (mul_mm_f32_f32);
372+ GGML_METAL_DEL_KERNEL (mul_mm_f16_f32);
373+ GGML_METAL_DEL_KERNEL (mul_mm_q4_0_f32);
374+ GGML_METAL_DEL_KERNEL (mul_mm_q8_0_f32);
375+ GGML_METAL_DEL_KERNEL (mul_mm_q4_1_f32);
376+ GGML_METAL_DEL_KERNEL (mul_mm_q2_K_f32);
377+ GGML_METAL_DEL_KERNEL (mul_mm_q3_K_f32);
378+ GGML_METAL_DEL_KERNEL (mul_mm_q4_K_f32);
379+ GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
380+ GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
381+ }
364382 GGML_METAL_DEL_KERNEL (rope_f32);
365383 GGML_METAL_DEL_KERNEL (rope_f16);
366384 GGML_METAL_DEL_KERNEL (alibi_f32);
@@ -986,21 +1004,46 @@ void ggml_metal_graph_compute(
9861004 } break ;
9871005 case GGML_OP_MUL_MAT:
9881006 {
989- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
990-
9911007 GGML_ASSERT (ne00 == ne10);
992- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
993- uint gqa = ne12/ne02;
9941008 GGML_ASSERT (ne03 == ne13);
9951009
1010+ const uint gqa = ne12/ne02;
1011+
1012+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1013+ // to the matrix-vector kernel
1014+ int ne11_mm_min = 1 ;
1015+
1016+ #if 0
1017+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1018+ // these numbers do not translate to other devices or model sizes
1019+ // TODO: need to find a better approach
1020+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1021+ switch (src0t) {
1022+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1023+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1024+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1025+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1026+ case GGML_TYPE_Q4_0:
1027+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1028+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1029+ case GGML_TYPE_Q5_0: // not tested yet
1030+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1031+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1032+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1033+ default: ne11_mm_min = 1; break;
1034+ }
1035+ }
1036+ #endif
1037+
9961038 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
9971039 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
998- if (!ggml_is_transposed (src0) &&
1040+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1041+ !ggml_is_transposed (src0) &&
9991042 !ggml_is_transposed (src1) &&
10001043 src1t == GGML_TYPE_F32 &&
1001- [ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1002- ne00% 32 == 0 &&
1003- ne11 > 2 ) {
1044+ ne00 % 32 == 0 &&
1045+ ne11 > ne11_mm_min) {
1046+ // printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10041047 switch (src0->type ) {
10051048 case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f32_f32]; break ;
10061049 case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f16_f32]; break ;
@@ -1029,30 +1072,31 @@ void ggml_metal_graph_compute(
10291072 [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 12 ];
10301073 [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 13 ];
10311074 [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
1032- [encoder dispatchThreadgroups: MTLSizeMake ( (ne11+ 31 )/32 , (ne01+ 63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1075+ [encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/ 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
10331076 } else {
10341077 int nth0 = 32 ;
10351078 int nth1 = 1 ;
10361079 int nrows = 1 ;
1080+ // printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10371081
10381082 // use custom matrix x vector kernel
10391083 switch (src0t) {
10401084 case GGML_TYPE_F32:
10411085 {
1042- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f32_f32 ];
1086+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_f32_f32 ];
10431087 nrows = 4 ;
10441088 } break ;
10451089 case GGML_TYPE_F16:
10461090 {
10471091 nth0 = 32 ;
10481092 nth1 = 1 ;
10491093 if (ne11 * ne12 < 4 ) {
1050- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_1row ];
1094+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_f16_f32_1row ];
10511095 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1052- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_l4 ];
1096+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_f16_f32_l4 ];
10531097 nrows = ne11;
10541098 } else {
1055- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32 ];
1099+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_f16_f32 ];
10561100 nrows = 4 ;
10571101 }
10581102 } break ;
@@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute(
10631107
10641108 nth0 = 8 ;
10651109 nth1 = 8 ;
1066- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0_f32 ];
1110+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_0_f32 ];
10671111 } break ;
10681112 case GGML_TYPE_Q4_1:
10691113 {
@@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute(
10721116
10731117 nth0 = 8 ;
10741118 nth1 = 8 ;
1075- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_1_f32 ];
1119+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_1_f32 ];
10761120 } break ;
10771121 case GGML_TYPE_Q8_0:
10781122 {
@@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute(
10811125
10821126 nth0 = 8 ;
10831127 nth1 = 8 ;
1084- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q8_0_f32 ];
1128+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q8_0_f32 ];
10851129 } break ;
10861130 case GGML_TYPE_Q2_K:
10871131 {
@@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute(
10901134
10911135 nth0 = 2 ;
10921136 nth1 = 32 ;
1093- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q2_K_f32 ];
1137+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q2_K_f32 ];
10941138 } break ;
10951139 case GGML_TYPE_Q3_K:
10961140 {
@@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute(
10991143
11001144 nth0 = 2 ;
11011145 nth1 = 32 ;
1102- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q3_K_f32 ];
1146+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q3_K_f32 ];
11031147 } break ;
11041148 case GGML_TYPE_Q4_K:
11051149 {
@@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute(
11081152
11091153 nth0 = 4 ; // 1;
11101154 nth1 = 8 ; // 32;
1111- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_K_f32 ];
1155+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_K_f32 ];
11121156 } break ;
11131157 case GGML_TYPE_Q5_K:
11141158 {
@@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute(
11171161
11181162 nth0 = 2 ;
11191163 nth1 = 32 ;
1120- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q5_K_f32 ];
1164+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q5_K_f32 ];
11211165 } break ;
11221166 case GGML_TYPE_Q6_K:
11231167 {
@@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute(
11261170
11271171 nth0 = 2 ;
11281172 nth1 = 32 ;
1129- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q6_K_f32 ];
1173+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_q6_K_f32 ];
11301174 } break ;
11311175 default :
11321176 {
@@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute(
11551199 [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
11561200
11571201 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1158- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1202+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
11591203 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
11601204 }
11611205 else if (src0t == GGML_TYPE_Q4_K) {
0 commit comments