Skip to content

Commit 777a18e

Browse files
committed
vulkan: add mmq q2_k integer dot support
1 parent 97870e6 commit 777a18e

File tree

10 files changed

+128
-26
lines changed

10 files changed

+128
-26
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2936,6 +2936,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
29362936
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29372937
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29382938
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
2939+
2940+
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
29392941
}
29402942
#endif
29412943

@@ -3055,6 +3057,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
30553057
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30563058
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30573059
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3060+
3061+
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
30583062
}
30593063
#endif
30603064

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
488488

489489
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
490490
const uint scales = data_a[a_offset + ib].scales[scalesi];
491-
const vec2 d = vec2(data_a[a_offset + ib].d);
491+
const vec2 dm = vec2(data_a[a_offset + ib].dm);
492492

493-
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
493+
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
494494
}
495495
vec2 get_dm(uint ib, uint a_offset) {
496496
return vec2(1, 0);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
108108
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
109109
{
110110
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
111-
const f16vec2 d = bl.block.d;
111+
const f16vec2 dm = bl.block.dm;
112112
const uint idx = coordInBlock[1];
113113

114114
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -119,7 +119,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
119119
qs = unpack8(qs)[idx & 1];
120120

121121
const uint scales = bl.block.scales[scalesi];
122-
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
122+
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
123123
return ret;
124124
}
125125

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ void main() {
2424
const uint ql_idx = 32 * ip + il;
2525
const uint8_t qs = data_a[i].qs[32 * ip + il];
2626

27-
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
28-
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
27+
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
28+
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
2929
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
3030
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
3131
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
4141
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
4242
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
4343

44-
vec2 d = vec2(data_a[ib0 + i].d);
45-
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
46-
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
44+
vec2 dm = vec2(data_a[ib0 + i].dm);
45+
const FLOAT_TYPE dall = FLOAT_TYPE(dm.x);
46+
const FLOAT_TYPE dmin = FLOAT_TYPE(dm.y);
4747

4848
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
4949
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
134134
const uint ib = idx / 128; // 2 values per idx
135135
const uint iqs = idx % 128; // 0..127
136136

137-
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
137+
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
138138
const uint scalesi = iqs / 8; // 0..15
139139
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
140140

141-
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
141+
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
142142
const uint scales = data_a[ib].scales[scalesi];
143-
const vec2 d = vec2(data_a[ib].d);
143+
const vec2 dm = vec2(data_a[ib].dm);
144144

145-
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
145+
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
146146

147147
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
148148
#elif defined(DATA_A_Q3_K)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424

2525
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2626

27-
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
27+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
28+
#if defined(A_TYPE_PACKED16)
29+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
30+
#endif
2831
#if defined(A_TYPE_PACKED32)
2932
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3033
#endif
@@ -84,6 +87,11 @@ layout (constant_id = 10) const uint WARP = 32;
8487

8588
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
8689

90+
#ifdef DATA_A_QUANT_K
91+
#define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
92+
shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
93+
#endif
94+
8795
#ifndef COOPMAT
8896
#if QUANT_AUXF == 1
8997
shared FLOAT_TYPE buf_a_dm[BM];
@@ -224,6 +232,10 @@ void main() {
224232
#else
225233
int32_t cache_a_qs[WMITER * TM * BK / 4];
226234

235+
#ifdef DATA_A_QUANT_K
236+
uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
237+
#endif
238+
227239
int32_t cache_b_qs[TN * BK / 4];
228240

229241
ACC_TYPE sums[WMITER * TM * WNITER * TN];
@@ -243,9 +255,9 @@ void main() {
243255

244256
for (uint block = start_k; block < end_k; block += BK) {
245257
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
246-
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
247-
const uint iqs = loadr_a;
248258
const uint buf_ib = loadc_a + l;
259+
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
260+
const uint iqs = loadr_a;
249261

250262
if (iqs == 0) {
251263
#if QUANT_AUXF == 1
@@ -261,6 +273,12 @@ void main() {
261273
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
262274
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
263275
#endif
276+
277+
#ifdef DATA_A_QUANT_K
278+
if (iqs % 4 == 0) {
279+
buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
280+
}
281+
#endif
264282
}
265283
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
266284
#ifdef MUL_MAT_ID
@@ -333,6 +351,11 @@ void main() {
333351
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
334352
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
335353
}
354+
#ifdef DATA_A_QUANT_K
355+
[[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
356+
cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
357+
}
358+
#endif
336359
}
337360
}
338361

@@ -350,13 +373,45 @@ void main() {
350373
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
351374
const uint cache_a_idx = wsir * TM + cr;
352375
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
376+
377+
#if defined(DATA_A_QUANT_LEGACY)
353378
int32_t q_sum = 0;
354379
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
355380
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
356381
cache_b_qs[cc * (BK / 4) + idx_k]);
357382
}
358383

359384
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
385+
#elif defined(DATA_A_QUANT_K)
386+
int32_t sum_d = 0;
387+
int32_t sum_m = 0;
388+
389+
const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
390+
const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
391+
int32_t scale_m = scale0 >> 4;
392+
scale_m |= scale_m << 8;
393+
scale_m |= scale_m << 16;
394+
395+
[[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
396+
sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
397+
cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
398+
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
399+
}
400+
401+
scale_m = scale1 >> 4;
402+
scale_m |= scale_m << 8;
403+
scale_m |= scale_m << 16;
404+
405+
[[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
406+
sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
407+
cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
408+
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
409+
}
410+
411+
sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
412+
#else
413+
#error unsupported
414+
#endif
360415
}
361416
}
362417
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#if defined(DATA_A_Q4_0)
1010
i32vec2 repack(uint ib, uint iqs) {
1111
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
12-
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
13-
data_a[ib].qs[iqs * 2 + 1]);
12+
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
13+
data_a_packed16[ib].qs[iqs * 2 + 1]);
1414
const uint32_t vui = pack32(quants);
1515
return i32vec2( vui & 0x0F0F0F0F,
1616
(vui >> 4) & 0x0F0F0F0F);
@@ -37,8 +37,8 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int
3737
#if defined(DATA_A_Q5_0)
3838
i32vec2 repack(uint ib, uint iqs) {
3939
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
40-
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
41-
data_a[ib].qs[iqs * 2 + 1]);
40+
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
41+
data_a_packed16[ib].qs[iqs * 2 + 1]);
4242
const uint32_t vui = pack32(quants);
4343
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
4444
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
@@ -77,15 +77,40 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int
7777
#if defined(DATA_A_Q8_0)
7878
int32_t repack(uint ib, uint iqs) {
7979
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
80-
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
81-
data_a[ib].qs[iqs * 2 + 1]));
80+
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
81+
data_a_packed16[ib].qs[iqs * 2 + 1]));
8282
}
8383

8484
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
8585
return ACC_TYPE(float(q_sum) * da * dsb.x);
8686
}
8787
#endif
8888

89+
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
90+
// iqs still refers to a 32-bit integer, meaning 0..r for 32-wide quants
91+
#if defined(DATA_A_Q2_K)
92+
int32_t repack(uint ib, uint iqs) {
93+
const uint ib_k = ib / 8;
94+
const uint iqs_k = (ib % 8) * 8 + iqs;
95+
96+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
97+
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
98+
99+
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
100+
}
101+
102+
uint8_t get_scale(uint ib, uint iqs) {
103+
const uint ib_k = ib / 8;
104+
const uint iqs_k = (ib % 8) * 8 + iqs;
105+
106+
return data_a[ib_k].scales[iqs_k / 4];
107+
}
108+
109+
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
110+
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
111+
}
112+
#endif
113+
89114
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
90115
FLOAT_TYPE get_d(uint ib) {
91116
return FLOAT_TYPE(data_a[ib].d);
@@ -103,3 +128,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
103128
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
104129
}
105130
#endif
131+
132+
#if defined(DATA_A_Q2_K)
133+
FLOAT_TYPE_VEC2 get_dm(uint ib) {
134+
const uint ib_k = ib / 8;
135+
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
136+
}
137+
#endif

ggml/src/ggml-vulkan/vulkan-shaders/types.glsl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
6666
#define QUANT_AUXF 1
6767
#define A_TYPE block_q4_0
6868
#define A_TYPE_PACKED16 block_q4_0_packed16
69+
#define DATA_A_QUANT_LEGACY
6970
#endif
7071

7172
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
9899
#define A_TYPE block_q4_1
99100
#define A_TYPE_PACKED16 block_q4_1_packed16
100101
#define A_TYPE_PACKED32 block_q4_1_packed32
102+
#define DATA_A_QUANT_LEGACY
101103
#endif
102104

103105
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
123125
#define QUANT_AUXF 1
124126
#define A_TYPE block_q5_0
125127
#define A_TYPE_PACKED16 block_q5_0_packed16
128+
#define DATA_A_QUANT_LEGACY
126129
#endif
127130

128131
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
158161
#define A_TYPE block_q5_1
159162
#define A_TYPE_PACKED16 block_q5_1_packed16
160163
#define A_TYPE_PACKED32 block_q5_1_packed32
164+
#define DATA_A_QUANT_LEGACY
161165
#endif
162166

163167
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
186190
#define A_TYPE block_q8_0
187191
#define A_TYPE_PACKED16 block_q8_0_packed16
188192
#define A_TYPE_PACKED32 block_q8_0_packed32
193+
#define DATA_A_QUANT_LEGACY
189194
#endif
190195

191196
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
226231
{
227232
uint8_t scales[QUANT_K_Q2_K/16];
228233
uint8_t qs[QUANT_K_Q2_K/4];
229-
f16vec2 d;
234+
f16vec2 dm;
230235
};
231236

232237
struct block_q2_K_packed16
233238
{
234239
uint16_t scales[QUANT_K_Q2_K/16/2];
235240
uint16_t qs[QUANT_K_Q2_K/4/2];
236-
f16vec2 d;
241+
f16vec2 dm;
237242
};
238243

239244
struct block_q2_K_packed32
240245
{
241246
uint32_t scales[QUANT_K_Q2_K/16/4];
242247
uint32_t qs[QUANT_K_Q2_K/4/4];
243-
f16vec2 d;
248+
f16vec2 dm;
244249
};
245250

246251
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
249254
#define A_TYPE block_q2_K
250255
#define A_TYPE_PACKED16 block_q2_K_packed16
251256
#define A_TYPE_PACKED32 block_q2_K_packed32
257+
#define SCALES_PER_32 2
258+
#define DATA_A_QUANT_K
252259
#endif
253260

254261
#define QUANT_K_Q3_K 256
@@ -274,6 +281,7 @@ struct block_q3_K_packed16
274281
#define QUANT_R 1
275282
#define A_TYPE block_q3_K
276283
#define A_TYPE_PACKED16 block_q3_K_packed16
284+
#define DATA_A_QUANT_K
277285
#endif
278286

279287
#define QUANT_K_Q4_K 256
@@ -310,6 +318,7 @@ struct block_q4_K_packed128
310318
#define A_TYPE block_q4_K
311319
#define A_TYPE_PACKED16 block_q4_K_packed16
312320
#define A_TYPE_PACKED32 block_q4_K_packed32
321+
#define DATA_A_QUANT_K
313322
#endif
314323

315324
#define QUANT_K_Q5_K 256
@@ -340,6 +349,7 @@ struct block_q5_K_packed128
340349
#define QUANT_R 1
341350
#define A_TYPE block_q5_K
342351
#define A_TYPE_PACKED16 block_q5_K_packed16
352+
#define DATA_A_QUANT_K
343353
#endif
344354

345355
#define QUANT_K_Q6_K 256
@@ -365,6 +375,7 @@ struct block_q6_K_packed16
365375
#define QUANT_R 1
366376
#define A_TYPE block_q6_K
367377
#define A_TYPE_PACKED16 block_q6_K_packed16
378+
#define DATA_A_QUANT_K
368379
#endif
369380

370381
// IQuants

0 commit comments

Comments
 (0)