From 82d745ebb3ac2e27381747c1f38b93c5ee284595 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 27 Sep 2025 12:21:35 -0500 Subject: [PATCH] vulkan: Fix validation failure in quantized flash attention --- .../vulkan-shaders/flash_attn_base.comp | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index f73e17e1fa8d9..c09205cf98cc9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -65,30 +65,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } } #endif