5858 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
5959 GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
6060 GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
6162 GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6263 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
6364 GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
8384 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
8485 GGML_METAL_KERNEL_TYPE_NORM,
8586 GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
8691 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
8792 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
8893 GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
131136 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132137 GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133138 GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
134140 GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135141 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136142 GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
194200 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195201 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196202 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
197204 GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
198206 GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199207 GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
200208 GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -514,6 +522,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
514522 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true );
515523 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
516524 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
525+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
517526 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
518527 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
519528 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
@@ -539,6 +548,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
539548 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction );
540549 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
541550 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction );
551+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction );
552+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction );
553+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction );
554+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction );
542555 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction );
543556 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction );
544557 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction );
@@ -587,6 +600,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
587600 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction );
588601 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction );
589602 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
603+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm );
590604 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm );
591605 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm );
592606 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm );
@@ -649,8 +663,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
649663 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650664 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction );
651665 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true );
652667 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
653668 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
669+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
654670 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
655671 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
656672 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
@@ -736,8 +752,13 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
736752
737753static bool ggml_metal_supports_op (const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
// Scan the first three source slots of the op for BF16 inputs.
738754 for (size_t i = 0 , n = 3 ; i < n; ++i) {
// Diff view: the two `-` lines are the previous check (any BF16 source => return false);
// the `+` lines that follow replace it.
739- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
740- return false ;
// New check: a BF16 source is tolerated only for GET_ROWS, MUL_MAT, VIEW and CPY.
755+ if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16 &&
756+ op->op != GGML_OP_GET_ROWS &&
757+ op->op != GGML_OP_MUL_MAT &&
758+ op->op != GGML_OP_VIEW &&
759+ op->op != GGML_OP_CPY) {
// NOTE(review): for any other op with a BF16 source this now prints the op/type and
// hits GGML_ASSERT(false) instead of returning false as the removed lines did.
// This looks like a debugging aid in a "supports" query - confirm it is intended.
760+ printf (" op = %s , src[%zu ] = %s \n " , ggml_op_name (op->op ), i, ggml_type_name (op->src [i]->type ));
761+ GGML_ASSERT (false );
741762 }
742763 }
743764
@@ -811,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
811832 case GGML_TYPE_F32:
812833 switch (op->type ) {
813834 case GGML_TYPE_F32:
835+ case GGML_TYPE_BF16:
814836 case GGML_TYPE_F16:
815837 case GGML_TYPE_Q8_0:
816838 case GGML_TYPE_Q4_0:
@@ -830,6 +852,14 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
830852 default :
831853 return false ;
832854 }
// BF16 branch of the enclosing switch (presumably keyed on the source tensor type -
// the switch header is outside this hunk): supported when op->type is F32 or F16.
// NOTE(review): the BF16 copy dispatch visible later in this diff selects a pipeline
// only for an F32 destination (CPY_BF16_F32) and asserts "not implemented" otherwise.
// If this branch gates GGML_OP_CPY, the F16 case below would reach that assert - confirm.
855+ case GGML_TYPE_BF16:
856+ switch (op->type ) {
857+ case GGML_TYPE_F32:
858+ case GGML_TYPE_F16:
859+ return true ;
860+ default :
861+ return false ;
862+ }
833863 default :
834864 return false ;
835865 };
@@ -1581,6 +1611,7 @@ static enum ggml_status ggml_metal_graph_compute(
15811611 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15821612 switch (src0->type ) {
15831613 case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1614+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
15841615 case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
15851616 default : break ;
15861617 }
@@ -1589,6 +1620,7 @@ static enum ggml_status ggml_metal_graph_compute(
15891620
15901621 switch (src0->type ) {
15911622 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
1623+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
15921624 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
15931625 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
15941626 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
@@ -1665,6 +1697,25 @@ static enum ggml_status ggml_metal_graph_compute(
16651697 nrows = 4 ;
16661698 }
16671699 } break ;
// BF16 src0: choose the mat-vec pipeline from src1's type and the tensor dimensions.
1700+ case GGML_TYPE_BF16:
1701+ {
// nth0/nth1 are presumably the threadgroup dimensions used at dispatch - TODO confirm.
1702+ nth0 = 32 ;
1703+ nth1 = 1 ;
1704+ if (src1t == GGML_TYPE_F32) {
// Fewer than 4 src1 rows in total (ne11 * ne12): use the 1ROW kernel.
// nrows is not assigned in this branch, so it keeps its value from before the switch.
1705+ if (ne11 * ne12 < 4 ) {
1706+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline ;
// Large enough rows with ne00 divisible by 4: use the L4 kernel with nrows = ne11.
1707+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0 ) {
1708+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline ;
1709+ nrows = ne11;
// Everything else with F32 src1: generic BF16 x F32 kernel, nrows = 4.
1710+ } else {
1711+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline ;
1712+ nrows = 4 ;
1713+ }
// NOTE(review): every non-F32 src1 type lands here; src1t is not checked to be BF16
// before the BF16 x BF16 kernel is selected - verify callers guarantee that.
1714+ } else {
1715+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline ;
1716+ nrows = 4 ;
1717+ }
1718+ } break ;
16681719 case GGML_TYPE_Q4_0:
16691720 {
16701721 nth0 = 8 ;
@@ -2161,6 +2212,7 @@ static enum ggml_status ggml_metal_graph_compute(
21612212
21622213 switch (src0->type ) {
21632214 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
2215+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
21642216 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
21652217 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
21662218 case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
@@ -2776,6 +2828,7 @@ static enum ggml_status ggml_metal_graph_compute(
27762828
27772829 switch (dstt) {
27782830 case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ; break ;
2831+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline ; break ;
27792832 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline ; break ;
27802833 case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
27812834 case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
@@ -2794,6 +2847,13 @@ static enum ggml_status ggml_metal_graph_compute(
27942847 default : GGML_ASSERT (false && " not implemented" );
27952848 };
27962849 } break ;
// BF16 branch of the copy pipeline selection: only an F32 destination type has a
// pipeline (CPY_BF16_F32); any other dstt hits the "not implemented" assert.
2850+ case GGML_TYPE_BF16:
2851+ {
2852+ switch (dstt) {
2853+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
2854+ default : GGML_ASSERT (false && " not implemented" );
2855+ };
2856+ } break ;
27972857 default : GGML_ASSERT (false && " not implemented" );
27982858 }
27992859
0 commit comments