@@ -3247,12 +3247,14 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_
32473247template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
32483248template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
32493249
3250+ #if !defined(GGML_METAL_NO_BFLOAT)
32503251template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 >;
32513252template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 >;
32523253template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 >;
32533254template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 >;
32543255template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 >;
32553256template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 >;
3257+ #endif
32563258
32573259template [[host_name(" kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 >;
32583260template [[host_name(" kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 80 >;
@@ -3630,15 +3632,19 @@ kernel void kernel_flash_attn_ext_vec(
36303632typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
36313633
36323634template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
3635+ #if !defined(GGML_METAL_NO_BFLOAT)
36333636template [[host_name(" kernel_flash_attn_ext_vec_bf16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 >;
3637+ #endif
36343638template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 128 >;
36353639template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 128 >;
36363640template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 128 >;
36373641template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 128 >;
36383642template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 128 >;
36393643
36403644template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
3645+ #if !defined(GGML_METAL_NO_BFLOAT)
36413646template [[host_name(" kernel_flash_attn_ext_vec_bf16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 >;
3647+ #endif
36423648template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 >;
36433649template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 >;
36443650template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 >;
0 commit comments