From d122d5c987b8b13483190acf9838535298e585f5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 3 Jan 2025 21:57:55 -0500 Subject: [PATCH 01/23] q6_k scale caching --- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 70e13a56bd730..a2a3623869f12 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,22 +6,22 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; - // 16 threads are used to process each block + // 16 thread groups are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; - - const uint step = 8; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...15 or 0...7 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; @@ -46,11 +46,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); @@ -63,7 +60,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; + uint32_t qh4_u32 = (qh_u32 & 0x30303030); uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; @@ -82,14 +79,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 
32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]); } - temp[j][n] += sum * d; + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); } } } From 6b06d1689011196ff3312277530402adefb53fbb Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Jan 2025 13:32:44 -0500 Subject: [PATCH 02/23] 16 bit unpack --- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 22 +++++++++---------- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index cd1dd8e89c21e..225f0ce70a6fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -15,13 +15,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 4; - - const uint il = itid/step; // 0...3 - const uint ir = itid - step*il; // 0...7 or 0...3 + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...7 or 0...3 const uint n = 4; const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 @@ -49,12 +47,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec2 scale0 = uvec2(unpack8(scale0_u32)); + uvec2 scale4 = uvec2(unpack8(scale4_u32)); + uvec2 scale8 = uvec2(unpack8(scale8_u32)); const uint32_t sc0 = ( scale0.x & 0x3f); const uint32_t sc1 = ( scale0.y & 0x3f); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index a2a3623869f12..e1afd55e0f540 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -14,7 +14,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint num_blocks_per_row = p.ncols / QUANT_K; - // 16 thread groups are used to process each block + // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; const uint itid = tid%16; // 0...15 From 21c6b805c99d90332250adb969e128171d431525 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Jan 2025 14:57:28 -0500 Subject: [PATCH 03/23] q4_k test (slow) --- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 225f0ce70a6fc..ee439f5c4ca31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,6 +6,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sc[BLOCK_SIZE/16][2][8]; + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -25,6 +27,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; + // v_im may be incorrect + const uint itid8 = itid%8; + const uint a = (itid8 > 5) ? 4 : 0; + const uint b = (itid8 > 3) ? 0x0f : 0x3f; + const uint c = (itid8 > 3) ? 0xc0 : 0xFF; + const uint dd = (itid8 > 3) ? 2 : 0; + const uint gg = (itid8 > 1) ? 1 : 0; + const uint g = (itid8 > 3) ? 
8 : gg; + const uint h = itid8%2; + const uint l0 = n * (2 * ir + v_in); // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; @@ -50,18 +62,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u32)); - uvec2 scale4 = uvec2(unpack8(scale4_u32)); - uvec2 scale8 = uvec2(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uvec2 scale[3]; + scale[0] = uvec2(unpack8(scale0_u32)); + scale[1] = uvec2(unpack8(scale4_u32)); + scale[2] = uvec2(unpack8(scale8_u32)); + + sc[ix][v_im][itid8] = FLOAT_TYPE(((scale[gg][h] >> a) & b) | ((scale[g][h] & c) >> dd)); + barrier(); + + //const uint32_t sc0 = ( scale0.x & 0x3f); + //const uint32_t sc1 = ( scale0.y & 0x3f); + //const uint32_t sc2 = ( scale4.x & 0x3f); + //const uint32_t sc3 = ( scale4.y & 0x3f); + //const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + //const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + //const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + //const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; @@ -104,11 +121,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + fma(FLOAT_TYPE(by10.x), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.x), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.x), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.x), sc[ix][v_im][7], + fma(FLOAT_TYPE(by10.y), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.y), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.y), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.y), sc[ix][v_im][7], + fma(FLOAT_TYPE(by10.z), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.z), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.z), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.z), sc[ix][v_im][7], + fma(FLOAT_TYPE(by10.w), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.w), sc[ix][v_im][3], 
fma(FLOAT_TYPE(by20.w), sc[ix][v_im][6], FLOAT_TYPE(by232.w) * sc[ix][v_im][7]))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc[ix][v_im][0], fma(sy, sc[ix][v_im][1], fma(sz, sc[ix][v_im][4], sw * sc[ix][v_im][5]))), fma(-dmin, smin, temp[j][n])); } } } From b0e4ccbeb95f3987052fbc08500459773216a691 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Jan 2025 14:57:40 -0500 Subject: [PATCH 04/23] revert it --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 16 +++--- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 12 ++--- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 51 +++++++------------ .../vulkan-shaders/mul_mat_vec_q5_k.comp | 12 ++--- 4 files changed, 35 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 9342134462416..921759bc8bed2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -14,13 +14,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...15 or 0...7 const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; @@ -44,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = d.x; const FLOAT_TYPE dmin = d.y; - uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; + uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4]; uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; @@ -57,13 +55,13 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; + uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2]; uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; uvec2 qs0 = uvec2(unpack8(qs0_u16)); uvec2 qs16 = uvec2(unpack8(qs16_u16)); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 86b0159d97a89..b9a87f13a99ed 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -14,13 +14,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = 
gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...15 or 0...7 const uint8_t m = uint8_t(1 << (4 * v_im)); @@ -60,7 +58,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]; + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index ee439f5c4ca31..225f0ce70a6fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,8 +6,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sc[BLOCK_SIZE/16][2][8]; - void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -27,16 +25,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; - // v_im may be incorrect - const uint itid8 = itid%8; - const uint a = (itid8 > 5) ? 4 : 0; - const uint b = (itid8 > 3) ? 0x0f : 0x3f; - const uint c = (itid8 > 3) ? 0xc0 : 0xFF; - const uint dd = (itid8 > 3) ? 2 : 0; - const uint gg = (itid8 > 1) ? 1 : 0; - const uint g = (itid8 > 3) ? 
8 : gg; - const uint h = itid8%2; - const uint l0 = n * (2 * ir + v_in); // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; @@ -62,23 +50,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - - uvec2 scale[3]; - scale[0] = uvec2(unpack8(scale0_u32)); - scale[1] = uvec2(unpack8(scale4_u32)); - scale[2] = uvec2(unpack8(scale8_u32)); - - sc[ix][v_im][itid8] = FLOAT_TYPE(((scale[gg][h] >> a) & b) | ((scale[g][h] & c) >> dd)); - barrier(); - - //const uint32_t sc0 = ( scale0.x & 0x3f); - //const uint32_t sc1 = ( scale0.y & 0x3f); - //const uint32_t sc2 = ( scale4.x & 0x3f); - //const uint32_t sc3 = ( scale4.y & 0x3f); - //const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - //const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - //const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - //const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + uvec2 scale0 = uvec2(unpack8(scale0_u32)); + uvec2 scale4 = uvec2(unpack8(scale4_u32)); + uvec2 scale8 = uvec2(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; @@ -121,11 +104,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.x), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.x), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.x), sc[ix][v_im][7], - fma(FLOAT_TYPE(by10.y), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.y), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.y), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.y), sc[ix][v_im][7], - fma(FLOAT_TYPE(by10.z), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.z), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.z), sc[ix][v_im][6], fma(FLOAT_TYPE(by232.z), sc[ix][v_im][7], - fma(FLOAT_TYPE(by10.w), sc[ix][v_im][2], fma(FLOAT_TYPE(by132.w), sc[ix][v_im][3], fma(FLOAT_TYPE(by20.w), sc[ix][v_im][6], FLOAT_TYPE(by232.w) * sc[ix][v_im][7]))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc[ix][v_im][0], fma(sy, sc[ix][v_im][1], fma(sz, sc[ix][v_im][4], sw * sc[ix][v_im][5]))), fma(-dmin, smin, temp[j][n])); + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + 
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 0a68891c35a5f..bea1dc8608f6b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -15,8 +15,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; const uint il = itid/4; // 0...3 const uint ir = itid - 4*il; // 0...7 or 0...3 @@ -49,9 +49,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); + uvec2 scale0 = uvec2(unpack8(scale0_u32)); + uvec2 scale4 = uvec2(unpack8(scale4_u32)); + uvec2 scale8 = uvec2(unpack8(scale8_u32)); const uint32_t sc0 = ( scale0.x & 0x3f); const uint32_t sc1 = ( scale0.y & 0x3f); @@ -74,7 +74,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; + uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; qs0_16_u32_lo4 += qs0_16_lo4_offset16; From 07d0d58bef57366233b52c421632cfb2e54c76d4 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:24:29 -0500 Subject: [PATCH 05/23] q3_k --- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index b9a87f13a99ed..0266417a88f72 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,6 +5,9 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared uint32_t scu8[BLOCK_SIZE/16][12]; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][12]; + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -16,6 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; const uint itid = tid%16; // 0...15 const uint ix = tid/16; + const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
const uint v_in = itid - 8*v_im; // 0...15 or 0...7 @@ -34,8 +38,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { } } - const uint s_shift = 4 * v_im; - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; @@ -43,18 +45,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; - uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; - uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; - uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; - uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; - uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; - u8vec2 s0 = unpack8(s0_16); - u8vec2 s2 = unpack8(s2_16); - u8vec2 s4 = unpack8(s4_16); - u8vec2 s6 = unpack8(s6_16); - u8vec2 s8 = unpack8(s8_16); - u8vec2 s10 = unpack8(s10_16); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32); + barrier(); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -69,14 +61,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 
0 : 4)), sum)))))))); + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); } temp[j][n] = fma(d, sum, temp[j][n]); } From d70a731639d9acd0baa588a94bfaa5f928b26b9c Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 4 Jan 2025 20:48:27 -0500 Subject: [PATCH 06/23] q2_k --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 50 ++++++++----------- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 5 +- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 +-- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 921759bc8bed2..098771493fbe1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,6 +5,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16]; + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -16,6 +18,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; const uint itid = tid%16; // 0...15 const uint ix = tid/16; + const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
const uint v_in = itid - 8*v_im; // 0...15 or 0...7 @@ -42,18 +45,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = d.x; const FLOAT_TYPE dmin = d.y; - uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4]; - uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; - - uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; - uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; - uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; - uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; - - uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); - uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); - uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); - uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); + sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes + sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes + barrier(); uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2]; uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; @@ -73,22 +67,22 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3), + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3), + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3), + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3), + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3), + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3), + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3), + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); } temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); } diff --git 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 0266417a88f72..0658f46bd7688 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,7 +5,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared uint32_t scu8[BLOCK_SIZE/16][12]; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][12]; void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { @@ -61,8 +60,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)), + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)), + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)), fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 
0 : 4)), diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 225f0ce70a6fc..8a3644e968c9c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -47,9 +47,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; uvec2 scale0 = uvec2(unpack8(scale0_u32)); uvec2 scale4 = uvec2(unpack8(scale4_u32)); uvec2 scale8 = uvec2(unpack8(scale8_u32)); From c01ccf8288f1e375ed5be535bab1efc43c213406 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 5 Jan 2025 02:31:28 +0000 Subject: [PATCH 07/23] little stuff --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 0658f46bd7688..54085f90e46a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,7 +5,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][12]; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; @@ -44,7 +44,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32); barrier(); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { From bdd98c74e24b38a820aacecd5d0cef1149b66879 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 5 Jan 2025 14:59:56 -0500 Subject: [PATCH 08/23] try precalculating products of a and q2_k scales --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 098771493fbe1..0a7ee58b39cee 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -54,6 +54,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uvec2 qs0 = uvec2(unpack8(qs0_u16)); uvec2 qs16 = uvec2(unpack8(qs16_u16)); + FLOAT_TYPE sc_q[2][8]; + [[unroll]] for (int l = 0; l < 2; ++l) { + sc_q[l][0] = sccache[ix][v_im][0] * 
FLOAT_TYPE((qs0[l] ) & 3); + sc_q[l][1] = sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3); + sc_q[l][2] = sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3); + sc_q[l][3] = sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3); + sc_q[l][4] = sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3); + sc_q[l][5] = sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3); + sc_q[l][6] = sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3); + sc_q[l][7] = sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3); + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; @@ -67,14 +79,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3), - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3), - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sc_q[l][0], + fma(FLOAT_TYPE(b16[l]), sc_q[l][1], + fma(FLOAT_TYPE(b32[l]), sc_q[l][2], + fma(FLOAT_TYPE(b48[l]), sc_q[l][3], + fma(FLOAT_TYPE(b64[l]), sc_q[l][4], + fma(FLOAT_TYPE(b80[l]), sc_q[l][5], + fma(FLOAT_TYPE(b96[l]), sc_q[l][6], + fma(FLOAT_TYPE(b112[l]), sc_q[l][7], sum1)))))))); sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], From 173077180ff5e63ffeda96a7b451303a9af69543 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 5 Jan 2025 17:01:34 -0500 Subject: [PATCH 09/23] Revert "try precalculating products of a and q2_k scales" This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b. 
--- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 0a7ee58b39cee..098771493fbe1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -54,18 +54,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uvec2 qs0 = uvec2(unpack8(qs0_u16)); uvec2 qs16 = uvec2(unpack8(qs16_u16)); - FLOAT_TYPE sc_q[2][8]; - [[unroll]] for (int l = 0; l < 2; ++l) { - sc_q[l][0] = sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3); - sc_q[l][1] = sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3); - sc_q[l][2] = sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3); - sc_q[l][3] = sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3); - sc_q[l][4] = sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3); - sc_q[l][5] = sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3); - sc_q[l][6] = sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3); - sc_q[l][7] = sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3); - } - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; @@ -79,14 +67,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sc_q[l][0], - fma(FLOAT_TYPE(b16[l]), sc_q[l][1], - fma(FLOAT_TYPE(b32[l]), sc_q[l][2], - fma(FLOAT_TYPE(b48[l]), sc_q[l][3], - fma(FLOAT_TYPE(b64[l]), sc_q[l][4], - fma(FLOAT_TYPE(b80[l]), sc_q[l][5], - fma(FLOAT_TYPE(b96[l]), sc_q[l][6], - fma(FLOAT_TYPE(b112[l]), sc_q[l][7], sum1)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3), + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3), + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3), + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3), + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3), + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3), + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3), + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], From b4ae7005e66cb03d8de601dd271b10a1127970c4 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 5 Jan 2025 21:59:43 -0500 Subject: [PATCH 10/23] unpack should be u16, add vim swap to gitignore (about time) --- .gitignore | 1 + .../ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp | 12 ++++++------ .../ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp | 12 ++++++------ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 1df7cf4a15ecd..694f36e042fb5 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ *.metallib *.o *.so +*.swp *.tmp # IDE / OS diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 8a3644e968c9c..940ba3f51f299 
100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -47,12 +47,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u32)); - uvec2 scale4 = uvec2(unpack8(scale4_u32)); - uvec2 scale8 = uvec2(unpack8(scale8_u32)); + uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; + uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec2 scale0 = uvec2(unpack8(scale0_u16)); + uvec2 scale4 = uvec2(unpack8(scale4_u16)); + uvec2 scale8 = uvec2(unpack8(scale8_u16)); const uint32_t sc0 = ( scale0.x & 0x3f); const uint32_t sc1 = ( scale0.y & 0x3f); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index bea1dc8608f6b..c4aaf9fea525e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -46,12 +46,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u32)); - uvec2 scale4 = uvec2(unpack8(scale4_u32)); - uvec2 scale8 = uvec2(unpack8(scale8_u32)); + uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; + uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; + uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; + uvec2 scale0 = uvec2(unpack8(scale0_u16)); + uvec2 scale4 = uvec2(unpack8(scale4_u16)); + uvec2 scale8 = uvec2(unpack8(scale8_u16)); const uint32_t sc0 = ( scale0.x & 0x3f); const uint32_t sc1 = ( scale0.y & 0x3f); From cdf70cf27fb9c8abb0caa2f3104b66616c207f60 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 5 Jan 2025 22:43:12 -0500 Subject: [PATCH 11/23] better q4_k scales --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 1 - .../vulkan-shaders/mul_mat_vec_q4_k.comp | 94 ++++++++++--------- 2 files changed, 48 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 098771493fbe1..1f00e442defbb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -25,7 +25,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 940ba3f51f299..6761f3d32b3d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -43,55 +43,57 @@ 
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; + const f16vec2 d = data_a[ib0 + i].d; const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; - uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u16)); - uvec2 scale4 = uvec2(unpack8(scale4_u16)); - uvec2 scale8 = uvec2(unpack8(scale8_u16)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; - uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; - - uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; - uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; - uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; - - uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); - uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); - uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); - uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); - - const uint32_t q4_0 = qs0_lo4.x; - const uint32_t q4_1 = qs0_lo4.y; - const uint32_t q4_2 = qs0_lo4.z; - const uint32_t q4_3 = qs0_lo4.w; - const uint32_t q4_4 = qs0_hi4.x; - const uint32_t q4_5 = qs0_hi4.y; - const uint32_t q4_6 = qs0_hi4.z; - const uint32_t q4_7 = qs0_hi4.w; - const uint32_t q4_8 = qs64_lo4.x; - const uint32_t q4_9 = qs64_lo4.y; - const uint32_t q4_10 = qs64_lo4.z; - const uint32_t q4_11 = qs64_lo4.w; - const uint32_t q4_12 = qs64_hi4.x; - const uint32_t q4_13 = qs64_hi4.y; - const uint32_t q4_14 = qs64_hi4.z; - const uint32_t q4_15 = qs64_hi4.w; + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f)); + const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const 
vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4]; From 6f5d62b098a45d9c4a0d833d03ec68848b0c06b5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:13:23 -0500 Subject: [PATCH 12/23] q5_k --- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 6 ++-- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 32 ++++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6761f3d32b3d9..28cde16db7c33 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -52,9 +52,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; - const uint32_t scale_0_4_h = (scale_0_4_l & 0xc0c0c0c0) >> 2; - const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3f3f3f3f)); - const vec4 scale8_f = vec4(unpack8(((((scale8_u32 >> 4) << 16) | scale8_u32) & 0x0f0f0f0f) | scale_0_4_h)); + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); const FLOAT_TYPE sc0 = scale_0_4_l_f.x; const FLOAT_TYPE sc1 = scale_0_4_l_f.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index c4aaf9fea525e..b5e8399fe5f8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -46,21 +46,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE dall = FLOAT_TYPE(d.x); const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - uint16_t scale0_u16 = data_a_packed16[ib0 + i].scales[v_im ]; - uint16_t scale4_u16 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint16_t scale8_u16 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec2 scale0 = uvec2(unpack8(scale0_u16)); - uvec2 scale4 = uvec2(unpack8(scale4_u16)); - uvec2 scale8 = uvec2(unpack8(scale8_u16)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const 
uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); From 91f1d9ce991f068a0bc39befecadf0fa52800e24 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:57:55 -0500 Subject: [PATCH 13/23] better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations --- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 126 ++++++++++-------- 1 file changed, 73 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index e1afd55e0f540..231bfc5da04a1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -8,7 +8,73 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; -void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + 
const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -31,65 +97,19 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); - barrier(); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030); - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = 
uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4]; - B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]; - B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]; - B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]; - - FLOAT_TYPE sum[4] = {0, 0, 0, 0}; - [[unroll]] for (uint l = 0; l < 4; ++l) { - sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]); - sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]); - sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]); - sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]); - } - temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) { + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); } + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } From cc28742ca39b12efea5f9b8d87d44860a3430ccb Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:20:33 -0500 Subject: [PATCH 14/23] q2_k better dequant --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 27 ++++----- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 56 +++++++++---------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 1f00e442defbb..ed924a072a982 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -40,7 +40,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - f16vec2 d = data_a[ib0 + i].d; + const f16vec2 d = data_a[ib0 + i].d; const FLOAT_TYPE dall = d.x; const FLOAT_TYPE dmin = d.y; @@ -48,10 +48,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes barrier(); - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2]; - uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; - uvec2 qs0 = uvec2(unpack8(qs0_u16)); - uvec2 qs16 = uvec2(unpack8(qs16_u16)); + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = 
data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; @@ -66,14 +67,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3), - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3), - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index b5e8399fe5f8d..9039ce1f34cbf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -64,47 +64,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE sc6 = scale8_f.z; const FLOAT_TYPE sc7 = scale8_f.w; - uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); - uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const 
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; qs0_16_u32_lo4 += qs0_16_lo4_offset16; qs0_16_u32_hi4 += qs0_16_hi4_offset16; qs64_80_u32_lo4 += qs64_80_lo4_offset16; qs64_80_u32_hi4 += qs64_80_hi4_offset16; - uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); - uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); - uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); - uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); - - const uint32_t q4_0 = qs0_16_lo4.x; - const uint32_t q4_1 = qs0_16_lo4.y; - const uint32_t q4_2 = qs0_16_lo4.z; - const uint32_t q4_3 = qs0_16_lo4.w; - const uint32_t q4_4 = qs0_16_hi4.x; - const uint32_t q4_5 = qs0_16_hi4.y; - const uint32_t q4_6 = qs0_16_hi4.z; - const uint32_t q4_7 = qs0_16_hi4.w; - const uint32_t q4_8 = qs64_80_lo4.x; - const uint32_t q4_9 = qs64_80_lo4.y; - const uint32_t q4_10 = qs64_80_lo4.z; - const uint32_t q4_11 = qs64_80_lo4.w; - const uint32_t q4_12 = qs64_80_hi4.x; - const uint32_t q4_13 = qs64_80_hi4.y; - const uint32_t q4_14 = qs64_80_hi4.z; - const uint32_t q4_15 = qs64_80_hi4.w; + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2]; From fe71a8c4a12540f0ae666e2b6d518f54a057d833 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 7 Jan 2025 22:05:51 -0500 Subject: [PATCH 15/23] q3_k optimizations --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 29 +++++++++++++------ .../vulkan-shaders/mul_mat_vec_q4_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 2 +- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index ed924a072a982..ae78a2d14123b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
- const uint v_in = itid - 8*v_im; // 0...15 or 0...7 + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 54085f90e46a7..cbe3269d62210 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - 8*v_im; // 0...15 or 0...7 + const uint v_in = itid - 8*v_im; // 0...7 const uint8_t m = uint8_t(1 << (4 * v_im)); @@ -47,6 +47,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32); barrier(); + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in])); + const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8])); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; @@ -60,14 +71,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 
0 : 4)), sum)))))))); + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - FLOAT_TYPE((( hmk0[l] & (m )) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m )) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4), + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4), + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum)))))))); } temp[j][n] = fma(d, sum, temp[j][n]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 28cde16db7c33..f8bc885dccda8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ix = tid/16; const uint il = itid/4; // 0...3 - const uint ir = itid - 4*il; // 0...7 or 0...3 + const uint ir = itid - 4*il; // 0...3 const uint n = 4; const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 9039ce1f34cbf..72b36d9af80c8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ix = tid/16; const uint il = itid/4; // 0...3 - const uint ir = itid - 4*il; // 0...7 or 0...3 + const uint ir = itid - 4*il; // 0...3 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 231bfc5da04a1..efbdfe939cd87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -87,7 +87,7 @@ void compute_outputs(const uint first_row, const uint num_rows) { const uint ix = tid/16; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
- const uint v_in = itid - 8*v_im; // 0...15 or 0...7 + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; From 923e9a8377dfc76a189c0e3f5e06aff4384453e3 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 8 Jan 2025 21:13:09 -0500 Subject: [PATCH 16/23] q3_k use hmask simd from cpu avx version --- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index cbe3269d62210..b9bc0bfcdc389 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -21,9 +21,13 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; const uint v_in = itid - 8*v_im; // 0...7 - const uint8_t m = uint8_t(1 << (4 * v_im)); + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; @@ -44,7 +48,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32); barrier(); // 0, 1, 16, 17 @@ -55,8 +59,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in])); - const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8])); + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -71,14 +78,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - FLOAT_TYPE((( hmk0[l] & (m )) != 0) ? 0 : 4), - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m )) != 0) ? 0 : 4), - fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4), - fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4), - fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 
0 : 4), - fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4), - fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4), - fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum)))))))); + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); } temp[j][n] = fma(d, sum, temp[j][n]); } From 51b5ac507db6c7e4288c57cf2e7955c411ef8490 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 9 Jan 2025 17:06:54 -0500 Subject: [PATCH 17/23] make the caches happy --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 7 ++++--- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 18 +++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index ae78a2d14123b..129442f7ec924 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -40,9 +40,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = d.x; - const FLOAT_TYPE dmin = d.y; sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes @@ -54,6 +51,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + const f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = d.x; + const FLOAT_TYPE dmin = d.y; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index b9bc0bfcdc389..0a48f84de307a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -46,10 +46,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 
2*(itid8/4)) & 0x3) << 4)) - 32); - barrier(); + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); // 0, 1, 16, 17 uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); @@ -59,14 +61,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); - const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); - const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); - const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); - const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32); + barrier(); - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; From 973bc4069f0d8da9e11c2612bc4b085e18d975af Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:06:05 -0500 Subject: [PATCH 18/23] q3_k separate out calculation --- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 121 ++++++++++-------- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 3 +- 2 files changed, 71 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 0a48f84de307a..e395b61438f80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -6,6 +6,69 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + 
continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; @@ -33,64 +96,20 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - - const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); - const vec4 hmk_0 = vec4(unpack8(((hmk & 
hm_m[0]) >> ( v_im4)) << 2)); - const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); - const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); - const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); - - // 0, 1, 16, 17 - uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); - qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; - const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); - const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); - const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); - const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + const uint s_shift = v_im4 + 2*(itid8/4); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32); - barrier(); - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; - B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], - fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], - fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], - fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], - fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], - fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], - fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); - } - temp[j][n] = fma(d, sum, temp[j][n]); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index efbdfe939cd87..55d928e6688e4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -106,9 +106,8 @@ void compute_outputs(const uint first_row, const uint num_rows) { const uint nbr_par_th = num_blocks_per_row%it_size; const uint 
nbr_all_th = num_blocks_per_row - nbr_par_th; uint i0 = 0; - [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) { + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); - } calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); From 6145fc79e5117959e49a667ea76f72649922e705 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:41:50 -0500 Subject: [PATCH 19/23] q2_k separate out --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 130 ++++++++++-------- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 4 +- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 129442f7ec924..99db8f32ba20a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -7,6 +7,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16]; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint itid8, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes + sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); + sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + const f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; + B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; + B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; + B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; + B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; + B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; + B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 
48]; + B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], + fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], + fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], + fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], + fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], + fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], + fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], + fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -27,68 +95,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - - sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes - sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes - barrier(); - - const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); - const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); - const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); - const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - - const f16vec2 d = data_a[ib0 + i].d; - const FLOAT_TYPE dall = d.x; - const FLOAT_TYPE dmin = d.y; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2]; - B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]; - B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]; - B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]; - B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]; - B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]; - B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]; - B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]; - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - 
sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); - } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index e395b61438f80..6d655205d454e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -16,7 +16,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); if (i >= num_blocks_per_row) @@ -38,7 +38,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); if (all_threads) { - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | ((data_a[ib0+i].scales[itid8%4+8] >> s_shift & 3) << 4)) - 32); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); } From 845d572b877a94c91b756e8532787f7b9507458f Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:58:26 -0500 Subject: [PATCH 20/23] little stuff --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 99db8f32ba20a..9c41c55e54bb4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -19,7 +19,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, if (i < num_blocks_per_row) { sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes - } + } barrier(); if (i >= num_blocks_per_row) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 6d655205d454e..a80764f9ce63e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -6,6 +6,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { From 30eacad2905e375540fa800dccb4d62b053d3df5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 11 Jan 2025 17:14:22 -0500 Subject: [PATCH 21/23] use calc_superblock everywhere --- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 161 +++++++------ .../vulkan-shaders/mul_mat_vec_q5_k.comp | 225 +++++++++--------- 2 files changed, 196 insertions(+), 190 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index dc1540a5b7029..f9cde064887a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,6 +6,86 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + 
const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -29,91 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE 
dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - - const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; - const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; - const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); - const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); - - const FLOAT_TYPE sc0 = scale_0_4_l_f.x; - const FLOAT_TYPE sc1 = scale_0_4_l_f.y; - const FLOAT_TYPE sc2 = scale_0_4_l_f.z; - const FLOAT_TYPE sc3 = scale_0_4_l_f.w; - const FLOAT_TYPE sc4 = scale8_f.x; - const FLOAT_TYPE sc5 = scale8_f.y; - const FLOAT_TYPE sc6 = scale8_f.z; - const FLOAT_TYPE sc7 = scale8_f.w; - - const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; - const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; - - const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; - const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; - const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; - const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; - - const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); - const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); - const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); - const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); - - const FLOAT_TYPE q4_0 = qs0_lo4.x; - const FLOAT_TYPE q4_1 = qs0_lo4.y; - const FLOAT_TYPE q4_2 = qs0_lo4.z; - const FLOAT_TYPE q4_3 = qs0_lo4.w; - const FLOAT_TYPE q4_4 = qs0_hi4.x; - const FLOAT_TYPE q4_5 = qs0_hi4.y; - const FLOAT_TYPE q4_6 = qs0_hi4.z; - const FLOAT_TYPE q4_7 = qs0_hi4.w; - const FLOAT_TYPE q4_8 = qs64_lo4.x; - const FLOAT_TYPE q4_9 = qs64_lo4.y; - const FLOAT_TYPE q4_10 = qs64_lo4.z; - const FLOAT_TYPE q4_11 = qs64_lo4.w; - const FLOAT_TYPE q4_12 = qs64_hi4.x; - const FLOAT_TYPE q4_13 = qs64_hi4.y; - const FLOAT_TYPE q4_14 = qs64_hi4.z; - const FLOAT_TYPE q4_15 = qs64_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); - vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); - vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); - vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, 
FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index df2a8bd20913b..6c84ef3cde3ff 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -6,6 +6,118 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 
qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -28,123 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - const uint32_t 
scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - - const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; - const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; - const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); - const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); - - const FLOAT_TYPE sc0 = scale_0_4_l_f.x; - const FLOAT_TYPE sc1 = scale_0_4_l_f.y; - const FLOAT_TYPE sc2 = scale_0_4_l_f.z; - const FLOAT_TYPE sc3 = scale_0_4_l_f.w; - const FLOAT_TYPE sc4 = scale8_f.x; - const FLOAT_TYPE sc5 = scale8_f.y; - const FLOAT_TYPE sc6 = scale8_f.z; - const FLOAT_TYPE sc7 = scale8_f.w; - - const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); - - uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; - uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; - uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - - const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - - const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); - const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; - - qs0_16_u32_lo4 += qs0_16_lo4_offset16; - qs0_16_u32_hi4 += qs0_16_hi4_offset16; - qs64_80_u32_lo4 += qs64_80_lo4_offset16; - qs64_80_u32_hi4 += qs64_80_hi4_offset16; - - const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); - const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); - const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); - const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); - - const FLOAT_TYPE q4_0 = qs0_16_lo4.x; - const FLOAT_TYPE q4_1 = qs0_16_lo4.y; - const FLOAT_TYPE q4_2 = qs0_16_lo4.z; - const FLOAT_TYPE q4_3 = qs0_16_lo4.w; - const FLOAT_TYPE q4_4 = qs0_16_hi4.x; - const FLOAT_TYPE q4_5 = qs0_16_hi4.y; - const FLOAT_TYPE q4_6 = qs0_16_hi4.z; - const FLOAT_TYPE q4_7 = qs0_16_hi4.w; - const FLOAT_TYPE q4_8 = qs64_80_lo4.x; - const FLOAT_TYPE q4_9 = qs64_80_lo4.y; - const FLOAT_TYPE q4_10 = qs64_80_lo4.z; - const FLOAT_TYPE q4_11 = qs64_80_lo4.w; - const FLOAT_TYPE q4_12 = qs64_80_hi4.x; - const FLOAT_TYPE q4_13 = qs64_80_hi4.y; - const FLOAT_TYPE q4_14 = qs64_80_hi4.z; - const FLOAT_TYPE q4_15 = qs64_80_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); - vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); - vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); - vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); - vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); - vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); - vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); - vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); - - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), q4_0, - 
fma(FLOAT_TYPE(by10.y), q4_1, - fma(FLOAT_TYPE(by116.x), q4_2, - FLOAT_TYPE(by116.y) * q4_3))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), q4_4, - fma(FLOAT_TYPE(by132.y), q4_5, - fma(FLOAT_TYPE(by148.x), q4_6, - FLOAT_TYPE(by148.y) * q4_7))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), q4_8, - fma(FLOAT_TYPE(by20.y), q4_9, - fma(FLOAT_TYPE(by216.x), q4_10, - FLOAT_TYPE(by216.y) * q4_11))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), q4_12, - fma(FLOAT_TYPE(by232.y), q4_13, - fma(FLOAT_TYPE(by248.x), q4_14, - FLOAT_TYPE(by248.y) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, - fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, - fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, - (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } From ed1ad94c8411213d7a22eedfa47fd7e205783e90 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sat, 11 Jan 2025 21:46:53 -0500 Subject: [PATCH 22/23] q2_k optimize scale calculation --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 45d6a830c28cc..afa109262d9b3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,11 +5,12 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16]; +shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint itid8, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -17,16 +18,18 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) { - sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); // lower 8 bytes - sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); // upper 8 bytes + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); } barrier(); if (i >= num_blocks_per_row) continue; } else { - 
sccache[ix][0][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8] >> v_im*4) & 0xF); - sccache[ix][1][itid] = FLOAT_TYPE((data_a[ib0 + i].scales[itid8+8] >> v_im*4) & 0xF); + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); barrier(); } @@ -53,22 +56,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * qs_u32_0[l ], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * qs_u32_0[l+2], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * qs_u32_2[l ], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * qs_u32_2[l+2], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * qs_u32_4[l ], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * qs_u32_4[l+2], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * qs_u32_6[l ], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * qs_u32_6[l+2], sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8], - fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9], - fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10], - fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11], - fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12], - fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13], - fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14], - fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15], sum2)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); } temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); } @@ -86,7 +89,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; const uint itid = tid%16; // 0...15 const uint ix = tid/16; - const uint itid8 = itid%8; const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
const uint v_in = itid - 8*v_im; // 0...7 @@ -105,8 +107,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint nbr_all_th = num_blocks_per_row - nbr_par_th; uint i0 = 0; [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) - calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); - calc_superblock(a_offset, b_offset, itid, itid8, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } From 4ae3fc01552cb3a3a2f210fafb7f277c41cb906f Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:21:57 -0500 Subject: [PATCH 23/23] more barriers --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp | 1 + ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp | 2 ++ ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp | 2 ++ 3 files changed, 5 insertions(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index afa109262d9b3..8cdc640e80e31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -16,6 +16,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + barrier(); if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) { const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 58be53bb8cd05..3116fad165be0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -16,6 +16,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); if (i < num_blocks_per_row) sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); @@ -39,6 +40,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); if (all_threads) { + barrier(); sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index b362c46bb07c5..f05f96b5efb9d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -17,6 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ib0 = a_offset / 
QUANT_K + (first_row+n)*num_blocks_per_row; if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); if (i < num_blocks_per_row) sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); @@ -50,6 +51,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 q3 = vec4(unpack8(q3_u32)) - 32; if (all_threads) { + barrier(); sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); }
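
Note: the pattern shared by these patches is to stage the per-superblock scales in shared memory once per block, with each of the 16 threads in a group loading exactly one scale and a barrier() between the store and the reads; PATCH 23/23 additionally places a barrier() before the store when the refill sits inside the block loop, so no thread overwrites a scale that another thread is still reading from the previous iteration. A minimal standalone sketch of that caching pattern follows; the buffer names, the 16-scales-per-superblock layout and BLOCK_SIZE = 32 are assumptions made purely for illustration and are not the shaders' actual interface.

#version 450

#define BLOCK_SIZE 32

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Hypothetical bindings: 16 scales per superblock, pre-widened to uint.
layout(binding = 0) readonly buffer Scales { uint scales[]; };
layout(binding = 1) writeonly buffer Result { float result[]; };

// One row of 16 cached scales per 16-thread group, as in the patches.
shared float sccache[BLOCK_SIZE/16][16];

void main() {
    const uint tid  = gl_LocalInvocationID.x;
    const uint itid = tid % 16; // which of the 16 scales this thread loads
    const uint ix   = tid / 16; // which superblock this 16-thread group handles

    // Each thread loads exactly one scale into shared memory ...
    sccache[ix][itid] = float(scales[ix * 16 + itid]);
    // ... and the barrier makes the whole row visible before any thread reads it.
    barrier();

    // Every thread of the group can now read all 16 scales of its superblock
    // without touching global memory again.
    float acc = 0.0;
    for (uint s = 0; s < 16; ++s) {
        acc += sccache[ix][s];
    }
    result[gl_GlobalInvocationID.x] = acc;
}

Per thread, this trades several redundant global loads of the same scale bytes for one global load plus shared-memory reads, which is the effect the scale-caching commits above aim for.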