11#version 450
22
33#extension GL_EXT_shader_explicit_arithmetic_types : require
4- #extension GL_KHR_shader_subgroup_arithmetic : require
5- #extension GL_KHR_shader_subgroup_shuffle : require
6- #extension GL_EXT_shader_subgroup_extended_types_int16 : require
4+ #extension GL_KHR_shader_subgroup_arithmetic: require
5+ #extension GL_KHR_shader_subgroup_shuffle: require
76
87#include "mul_mat_vec_base.comp"
98
@@ -12,11 +11,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1211layout (constant_id = 0) const uint BLOCK_SIZE = 32;
1312layout (constant_id = 1) const uint NUM_ROWS = 1;
1413
15- uint16_t blk[BLOCK_SIZE/16][8];
16-
17- uint16_t get_blk_shuffle(uint fbi, uint ix, uint ofst) {
18- return subgroupShuffle(blk[ix][ofst/(104/fbi)], ofst%(104/fbi));
19- }
14+ shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16];
2015
2116uint fill_blkcache_its(uint wg_size) {
2217 // subgroup sizes are always a power of 2
@@ -36,7 +31,7 @@ void fill_blkcache(const int num_blocks, const uint ib0, const uint i0, const ui
3631 [[unroll]] for (int l = 0; l < num_blocks; ++l) {
3732 [[unroll]] for (int m = 0; m < fbi; ++m)
3833 // cache full superblock into shared memory with coalesced reads
39- blk [l][m ] = data_a_packed16[ib0 + i0 + l].blk[tid + m*bc_t];
34+ blkcache [l].blk[tid + m*bc_t ] = data_a_packed16[ib0 + i0 + l].blk[tid + m*bc_t];
4035 }
4136 }
4237}
@@ -91,7 +86,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
9186 fill_blkcache(blim, ib0, i0, tid, fbi);
9287 }
9388
94- FLOAT_TYPE sccache = FLOAT_TYPE(int8_t(bitfieldExtract(get_blk_shuffle(fbi, ix, 96 + itid/2) , int(bcs_offset), 8)));
89+ FLOAT_TYPE sccache = FLOAT_TYPE(int8_t(bitfieldExtract(blkcache[ix].blk[ 96 + itid/2] , int(bcs_offset), 8)));
9590 barrier();
9691
9792 ibi += num_blocks_per_row;
@@ -100,15 +95,15 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
10095
10196 const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib0 + i].d);
10297
103- uint32_t ql0_u32 = uint32_t(get_blk_shuffle(fbi, ix, ql_offset / 2)) | (uint32_t(get_blk_shuffle(fbi, ix, ql_offset / 2 + 1) ) << 16);
104- uint32_t ql32_u32 = uint32_t(get_blk_shuffle(fbi, ix, ql_offset / 2 + 16)) | (uint32_t(get_blk_shuffle(fbi, ix, ql_offset / 2 + 17) ) << 16);
98+ uint32_t ql0_u32 = uint32_t(blkcache[ix].blk[ ql_offset / 2]) | (uint32_t(blkcache[ix].blk[ ql_offset / 2 + 1] ) << 16);
99+ uint32_t ql32_u32 = uint32_t(blkcache[ix].blk[ ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].blk[ ql_offset / 2 + 17] ) << 16);
105100
106101 uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
107102 uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
108103 uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
109104 uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
110105
111- uint32_t qh_u32 = uint32_t(get_blk_shuffle(fbi, ix, 64 + qh_offset / 2)) | (uint32_t(get_blk_shuffle(fbi, ix, 64 + qh_offset / 2 + 1) ) << 16);
106+ uint32_t qh_u32 = uint32_t(blkcache[ix].blk[ 64 + qh_offset / 2]) | (uint32_t(blkcache[ix].blk[ 64 + qh_offset / 2 + 1] ) << 16);
112107 uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
113108 uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
114109 uint32_t qh4_u32 = (qh_u32 & 0x30303030);
0 commit comments