@@ -3249,7 +3249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_
32493249template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
32503250template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
32513251
3252- #if ! defined(GGML_METAL_NO_BFLOAT )
3252+ #if defined(GGML_METAL_USE_BF16 )
32533253template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 >;
32543254template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 >;
32553255template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 >;
@@ -3634,7 +3634,7 @@ kernel void kernel_flash_attn_ext_vec(
36343634typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
36353635
36363636template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
3637- #if ! defined(GGML_METAL_NO_BFLOAT )
3637+ #if defined(GGML_METAL_USE_BF16 )
36383638template [[host_name(" kernel_flash_attn_ext_vec_bf16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 >;
36393639#endif
36403640template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 128 >;
@@ -3644,7 +3644,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_
36443644template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 128 >;
36453645
36463646template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
3647- #if ! defined(GGML_METAL_NO_BFLOAT )
3647+ #if defined(GGML_METAL_USE_BF16 )
36483648template [[host_name(" kernel_flash_attn_ext_vec_bf16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 >;
36493649#endif
36503650template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 >;
0 commit comments