Commit bdd1e4d

support different subgroup sizes (tested)
1 parent 5641108 commit bdd1e4d

2 files changed (+38, -28 lines)

ggml/src/ggml-vulkan/ggml-vulkan.cpp (3 additions, 3 deletions)

@@ -1877,7 +1877,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {64, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -1891,7 +1891,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {64, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -1905,7 +1905,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {64, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);

     // dequant shaders
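Across all three pipeline families (f32_f32, f16_f32, and the _id variants), the Q6_K entry was the only one still compiled with a hardcoded workgroup size of 64; it now receives the same subgroup_size_16 specialization constant as the neighboring K-quant pipelines, so the shader is specialized for whatever subgroup size the device actually provides. A minimal sketch of how such a constant can be derived, assuming it clamps the device's reported subgroup size to at least 16 (the shader works in 16-thread groups); the helper below is illustrative, not the actual ggml code:

#include <algorithm>
#include <cstdint>

// Illustrative sketch (assumed, not the ggml implementation): derive a
// workgroup-size specialization constant from the subgroup size the Vulkan
// device reports, clamped to >= 16 because mul_mat_vec_q6_k splits its
// workgroup into 16-thread groups (itid = tid % 16).
static uint32_t make_subgroup_size_16(uint32_t device_subgroup_size) {
    return std::max(device_subgroup_size, 16u);
}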

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp (35 additions, 25 deletions)

@@ -13,6 +13,29 @@ shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
 shared block_q6_K_packed16 blkcache[BLOCK_SIZE/16];

+uint fill_blkcache_its(uint wg_size) {
+    // subgroup sizes are always a power of 2
+    if (wg_size > 64)
+        return 1;
+    else if (wg_size == 64)
+        return 2;
+    else if (wg_size == 32)
+        return 4;
+    else
+        return 8;
+}
+
+void fill_blkcache(const int num_blocks, const uint ib0, const uint i0, const uint tid, const uint fbi) {
+    uint bc_t = 104 / fbi;
+    if (tid < bc_t) {
+        [[unroll]] for (int l = 0; l < num_blocks; ++l) {
+            [[unroll]] for (int m = 0; m < fbi; ++m)
+                // cache full superblock into shared memory with coalesced reads
+                blkcache[l].blk[tid + m*bc_t] = data_a_packed16[ib0 + i0 + l].blk[tid + m*bc_t];
+        }
+    }
+}
+
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
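The two new helpers generalize a copy that was previously hardcoded for 64 threads. A block_q6_K_packed16 superblock body is 104 uint16 words (64 words of ql, 32 of qh, 8 of scales; the d scale is read separately), and fill_blkcache_its picks how many words each thread moves so that bc_t * fbi covers exactly 104 words for every power-of-two workgroup size from 16 up. A standalone C++ check of that invariant, mirroring the shader's constants (a verification sketch, not shader code):

#include <cassert>
#include <cstdio>

// Mirrors the shader's fill_blkcache_its: words copied per thread.
unsigned fill_blkcache_its(unsigned wg_size) {
    if (wg_size > 64)  return 1;
    if (wg_size == 64) return 2;
    if (wg_size == 32) return 4;
    return 8;                        // wg_size == 16
}

int main() {
    const unsigned words = 104;      // uint16 words per packed Q6_K superblock
    for (unsigned wg = 16; wg <= 128; wg *= 2) {
        unsigned fbi  = fill_blkcache_its(wg);
        unsigned bc_t = words / fbi; // threads participating in the copy
        assert(bc_t * fbi == words); // every word is copied exactly once
        assert(bc_t <= wg);          // enough threads in the workgroup
        printf("wg=%3u: %3u threads x %u words each\n", wg, bc_t, fbi);
    }
}

The wg_size == 64 case reproduces the old behavior exactly: 52 participating threads copying two words each, which is why the removed code below tests tid < 52 and does a second access at tid + 52.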
@@ -24,6 +47,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
     const uint itid = tid%16; // 0...15
     const uint ix = tid/16;
+    const uint fbi = fill_blkcache_its(gl_WorkGroupSize.x);

     const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = itid - 8*v_im; // 0...15 or 0...7
@@ -38,10 +62,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint bcs_offset = (itid%2 == 1) ? 8 : 0;

     FLOAT_TYPE temp[NUM_ROWS];
-
-    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i)
         temp[i] = FLOAT_TYPE(0);
-    }

     [[unroll]] for (uint i0 = 0; i0 < num_blocks_per_row; i0 += it_size) {
         uint i = i0 + ix; // 16 thread group specific counter
@@ -55,33 +77,23 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         uint ibi = first_row*num_blocks_per_row;
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + ibi;
-            ibi += num_blocks_per_row;
+            const int blim = min(int(num_blocks_per_row) - int(i0), int(it_size));

-            // cache full superblock into shared memory with coalesced reads
-            // we assume 64 threads here!
-            const int blim = min(int(num_blocks_per_row) - int(i0), 4);
-            // this is required as this loop is super sensitive to unrolling with hardcoded 4
-            if (blim == 4) {
-                if (tid < 52) {
-                    [[unroll]] for (int l = 0; l < 4; ++l) {
-                        blkcache[l].blk[tid] = data_a_packed16[ib0 + i0 + l].blk[tid];
-                        blkcache[l].blk[tid + 52] = data_a_packed16[ib0 + i0 + l].blk[tid + 52];
-                    }
-                }
+            // fill_blkcache is sensitive to unrolling with hardcoded it_size
+            if (blim == it_size) {
+                fill_blkcache(int(it_size), ib0, i0, tid, fbi);
             } else {
-                if (tid < 52) {
-                    [[unroll]] for (int l = 0; l < blim; ++l) {
-                        blkcache[l].blk[tid] = data_a_packed16[ib0 + i0 + l].blk[tid];
-                        blkcache[l].blk[tid + 52] = data_a_packed16[ib0 + i0 + l].blk[tid + 52];
-                    }
-                }
+                fill_blkcache(blim, ib0, i0, tid, fbi);
             }
+
             sccache[ix][itid] = FLOAT_TYPE(int8_t(bitfieldExtract(blkcache[ix].blk[96 + itid/2], int(bcs_offset), 8)));
             barrier();
+
+            ibi += num_blocks_per_row;
             if (i >= num_blocks_per_row)
                 continue;

-            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib0 + i].d);

             uint32_t ql0_u32 = uint32_t(blkcache[ix].blk[ql_offset / 2]) | (uint32_t(blkcache[ix].blk[ql_offset / 2 + 1]) << 16);
             uint32_t ql32_u32 = uint32_t(blkcache[ix].blk[ql_offset / 2 + 16]) | (uint32_t(blkcache[ix].blk[ql_offset / 2 + 17]) << 16);
@@ -115,9 +127,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
                 sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]);
             }

-            [[unroll]] for (uint l = 0; l < 4; ++l)
-                sum[l] *= sccache[ix][s_offset + l*2];
-            temp[n] += (sum[0] + sum[1] + sum[2] + sum[3]) * d;
+            temp[n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[n]);
         }
     }

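The accumulation at the end of the superblock loop is also restructured: instead of scaling the four partial sums in a loop and then adding their total, the new code folds the per-group scales and the superblock scale d into one nested fma chain, spelling out sccache[ix][s_offset + l*2] as the explicit offsets 0, 2, 4, 6. A small C++ sanity check of the algebra (plain float stands in for FLOAT_TYPE; the two forms can differ only by the intermediate roundings fma skips):

#include <cmath>
#include <cstdio>

int main() {
    float sum[4] = {1.5f, -2.25f, 0.75f, 4.0f};
    // stand-in for sccache[ix][s_offset + l*2], l = 0..3 (odd slots unused)
    float sc[8]  = {0.5f, 0.f, -1.0f, 0.f, 2.0f, 0.f, 0.25f, 0.f};
    float d = 3.0f, temp_old = 10.0f, temp_new = 10.0f;

    // old form: scale each partial sum, then accumulate the total
    float s[4];
    for (int l = 0; l < 4; ++l) s[l] = sum[l] * sc[l*2];
    temp_old += (s[0] + s[1] + s[2] + s[3]) * d;

    // new form: one nested fma chain, as in the shader
    temp_new = std::fma(std::fma(sum[0], sc[0],
                        std::fma(sum[1], sc[2],
                        std::fma(sum[2], sc[4], sum[3] * sc[6]))), d, temp_new);

    printf("old=%.6f new=%.6f\n", temp_old, temp_new); // both print 26.500000
}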