@@ -257,7 +257,17 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
257257 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
258258 // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
259259 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
260- // GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
260+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
261+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
262+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
263+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
264+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
265+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
266+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
267+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
268+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
269+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
270+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
261271 GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
262272 GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
263273 GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -712,7 +722,17 @@ @implementation GGMLMetalClass
712722 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
713723 // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
714724 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
715- // GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
725+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
726+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
727+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
728+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
729+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
730+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
731+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
732+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
733+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
734+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
735+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
716736 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
717737 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
718738 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
@@ -869,13 +889,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
869889 case GGML_OP_LEAKY_RELU:
870890 return true ;
871891 case GGML_OP_FLASH_ATTN_EXT:
872- if (op->src [1 ]->type != GGML_TYPE_F16) {
873- return false ;
874- }
875- if (op->src [2 ]->type != GGML_TYPE_F16) {
876- return false ;
877- }
878- if (op->src [0 ]->ne [0 ] == 256 ) {
892+ if (op->src [1 ]->type != GGML_TYPE_F16 && op->src [0 ]->ne [0 ] % 128 != 0 ) {
879893 return false ;
880894 }
881895 return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
@@ -2868,14 +2882,14 @@ static void ggml_metal_encode_node(
28682882
28692883 bool use_vec_kernel = false ;
28702884
2871- if (ne01 >= 4 || (ne00%128 != 0 )) {
2885+ if (src1-> type == GGML_TYPE_F16 && ne00 < 256 && ( ne01 >= 4 || (ne00%128 != 0 ) )) {
28722886 switch (ne00) {
28732887 case 64 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline ; break ;
28742888 case 80 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline ; break ;
28752889 case 96 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline ; break ;
28762890 case 112 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline ; break ;
28772891 case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline ; break ;
2878- // case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2892+ // case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
28792893 default :
28802894 {
28812895 GGML_LOG_ERROR (" unsupported size: %lld \n " , ne00);
@@ -2887,8 +2901,40 @@ static void ggml_metal_encode_node(
28872901 use_vec_kernel = true ;
28882902
28892903 switch (ne00) {
2890- case 128 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline ; break ;
2891- // case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2904+ case 128 :
2905+ {
2906+ switch (src1->type ) {
2907+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline ; break ;
2908+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline ; break ;
2909+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline ; break ;
2910+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline ; break ;
2911+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline ; break ;
2912+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline ; break ;
2913+ default :
2914+ {
2915+ GGML_LOG_ERROR (" unsupported type: %d \n " , src1->type );
2916+ GGML_LOG_ERROR (" add template specialization for this type\n " );
2917+ GGML_ABORT (" add template specialization for this type" );
2918+ }
2919+ }
2920+ } break ;
2921+ case 256 :
2922+ {
2923+ switch (src1->type ) {
2924+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline ; break ;
2925+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline ; break ;
2926+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline ; break ;
2927+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline ; break ;
2928+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline ; break ;
2929+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline ; break ;
2930+ default :
2931+ {
2932+ GGML_LOG_ERROR (" unsupported type: %d \n " , src1->type );
2933+ GGML_LOG_ERROR (" add template specialization for this type\n " );
2934+ GGML_ABORT (" add template specialization for this type" );
2935+ }
2936+ }
2937+ } break ;
28922938 default :
28932939 {
28942940 GGML_LOG_ERROR (" unsupported size: %lld \n " , ne00);
0 commit comments