@@ -120,6 +120,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
120120 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
121121 GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
122122 GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
123+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
123124 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
124125 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
125126 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -150,6 +151,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
150151 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
151152 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152153 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
154+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
155+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
156+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
157+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
153158 GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154159 GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155160 GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -195,6 +200,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
195200 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
196201 GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197202 GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
203+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
198204 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
199205 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200206 GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -300,8 +306,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
300306 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
301307 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
302308 GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
309+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
303310 GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
304311 GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
312+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
305313 GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
306314 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
307315 GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -615,6 +623,7 @@ @implementation GGMLMetalClass
615623 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
616624 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
617625 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
626+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
618627 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
619628 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
620629 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -641,6 +650,10 @@ @implementation GGMLMetalClass
641650 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
642651 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
643652 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
653+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
654+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
655+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
656+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
644657 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
645658 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
646659 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
@@ -690,6 +703,7 @@ @implementation GGMLMetalClass
690703 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
691704 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
692705 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
706+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
693707 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
694708 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
695709 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -793,10 +807,12 @@ @implementation GGMLMetalClass
793807 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
794808 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
795809 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
796- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
797810 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
798- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
811+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
812+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true );
799813 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
814+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
815+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
800816 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
801817 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
802818 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -887,8 +903,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
887903
888904static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
889905 for (size_t i = 0 , n = 3 ; i < n; ++i) {
890- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
891- return false ;
906+ if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
907+ op->op != GGML_OP_GET_ROWS &&
908+ op->op != GGML_OP_MUL_MAT &&
909+ op->op != GGML_OP_VIEW &&
910+ op->op != GGML_OP_CPY) {
911+ GGML_LOG_ERROR (" unsupported BF16 op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
912+ GGML_ASSERT (false );
892913 }
893914 }
894915
@@ -969,6 +990,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
969990 switch (op->type ) {
970991 case GGML_TYPE_F32:
971992 case GGML_TYPE_F16:
993+ case GGML_TYPE_BF16:
972994 case GGML_TYPE_Q8_0:
973995 case GGML_TYPE_Q4_0:
974996 case GGML_TYPE_Q4_1:
@@ -980,11 +1002,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
9801002 return false ;
9811003 }
9821004 case GGML_TYPE_F16:
1005+ case GGML_TYPE_BF16:
9831006 switch (op->type ) {
984- case GGML_TYPE_F32:
985- case GGML_TYPE_F16:
1007+ case GGML_TYPE_F32:
1008+ case GGML_TYPE_F16:
1009+ case GGML_TYPE_BF16:
9861010 return true ;
987- default :
1011+ default :
9881012 return false ;
9891013 }
9901014 default :
@@ -1855,6 +1879,7 @@ static void ggml_metal_encode_node(
18551879 switch (src0->type ) {
18561880 case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
18571881 case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1882+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
18581883 default : break ;
18591884 }
18601885
@@ -1863,6 +1888,7 @@ static void ggml_metal_encode_node(
18631888 switch (src0->type ) {
18641889 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
18651890 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
1891+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
18661892 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
18671893 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
18681894 case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline ; break ;
@@ -1940,6 +1966,25 @@ static void ggml_metal_encode_node(
19401966 nrows = 4 ;
19411967 }
19421968 } break ;
1969+ case GGML_TYPE_BF16:
1970+ {
1971+ nth0 = 32 ;
1972+ nth1 = 1 ;
1973+ if (src1t == GGML_TYPE_F32) {
1974+ if (ne11 * ne12 < 4 ) {
1975+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline ;
1976+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1977+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline ;
1978+ nrows = ne11;
1979+ } else {
1980+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline ;
1981+ nrows = 4 ;
1982+ }
1983+ } else {
1984+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline ;
1985+ nrows = 4 ;
1986+ }
1987+ } break ;
19431988 case GGML_TYPE_Q4_0:
19441989 {
19451990 nth0 = 8 ;
@@ -2438,6 +2483,7 @@ static void ggml_metal_encode_node(
24382483 switch (src0->type ) {
24392484 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
24402485 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
2486+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
24412487 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
24422488 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
24432489 case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline ; break ;
@@ -3237,6 +3283,7 @@ static void ggml_metal_encode_node(
32373283 switch (dstt) {
32383284 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ; break ;
32393285 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline ; break ;
3286+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline ; break ;
32403287 case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
32413288 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
32423289 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline ; break ;
@@ -3254,6 +3301,13 @@ static void ggml_metal_encode_node(
32543301 default : GGML_ABORT (" not implemented" );
32553302 };
32563303 } break ;
3304+ case GGML_TYPE_BF16:
3305+ {
3306+ switch (dstt) {
3307+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
3308+ default : GGML_ASSERT (false && " not implemented" );
3309+ };
3310+ } break ;
32573311 default : GGML_ABORT (" not implemented" );
32583312 }
32593313
0 commit comments