2424
2525layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2626
// Views of the A input buffer. The change shown here retypes the primary
// `data_a` view from A_TYPE_PACKED16 to A_TYPE and moves the packed-16 view
// into its own block (`data_a_packed16`), which is only declared when
// A_TYPE_PACKED16 is defined. The packed-32 view below is guarded the same way.
// Every block in this group is declared with `binding = 0` and `readonly`.
27- layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
27+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
28+ #if defined(A_TYPE_PACKED16)
29+ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
30+ #endif
2831#if defined(A_TYPE_PACKED32)
2932layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3033#endif
@@ -84,6 +87,11 @@ layout (constant_id = 10) const uint WARP = 32;
8487
// Workgroup-shared staging for A's packed quant words: BM rows, each
// SHMEM_STRIDE int32 entries apart (indexed as row * SHMEM_STRIDE + word).
8588shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
8689
// Workgroup-shared staging for A's per-row scale bytes, only present when
// DATA_A_QUANT_K is defined. Rows are SHMEM_SCALES_STRIDE bytes apart, i.e.
// one byte more than the SCALES_PER_32 scale bytes that are read back per row.
// NOTE(review): the extra byte per row is presumably padding to stagger rows
// in shared memory -- confirm the intent with the author.
90+ #ifdef DATA_A_QUANT_K
91+ #define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
92+ shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
93+ #endif
94+
8795#ifndef COOPMAT
8896#if QUANT_AUXF == 1
8997shared FLOAT_TYPE buf_a_dm[BM];
@@ -224,6 +232,10 @@ void main() {
224232#else
225233 int32_t cache_a_qs[WMITER * TM * BK / 4];
226234
235+ #ifdef DATA_A_QUANT_K
236+ uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
237+ #endif
238+
227239 int32_t cache_b_qs[TN * BK / 4];
228240
229241 ACC_TYPE sums[WMITER * TM * WNITER * TN];
@@ -243,9 +255,9 @@ void main() {
243255
244256 for (uint block = start_k; block < end_k; block += BK) {
245257 [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
246- const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
247- const uint iqs = loadr_a;
248258 const uint buf_ib = loadc_a + l;
259+ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
260+ const uint iqs = loadr_a;
249261
250262 if (iqs == 0) {
251263#if QUANT_AUXF == 1
@@ -261,6 +273,12 @@ void main() {
261273 buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
262274 buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
263275#endif
276+
277+ #ifdef DATA_A_QUANT_K
278+ if (iqs % 4 == 0) {
279+ buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
280+ }
281+ #endif
264282 }
265283 [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
266284#ifdef MUL_MAT_ID
@@ -333,6 +351,11 @@ void main() {
333351 [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
334352 cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
335353 }
354+ #ifdef DATA_A_QUANT_K
355+ [[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
356+ cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
357+ }
358+ #endif
336359 }
337360 }
338361
@@ -350,13 +373,45 @@ void main() {
// For each of the TM cached A rows of this wsir tile, accumulate one partial
// result against cached B column `cc` into sums[sums_idx]. Each int32 word of
// cache_a_qs / cache_b_qs is consumed by dotPacked4x8EXT, i.e. as four packed
// 8-bit lanes, and BK / 4 such words are processed per row.
// Exactly one of DATA_A_QUANT_LEGACY / DATA_A_QUANT_K must be defined;
// otherwise compilation stops at the #error below.
350373 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
351374 const uint cache_a_idx = wsir * TM + cr;
352375 const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
376+
// Legacy path: a single integer dot product over all BK / 4 packed words,
// then handed to mul_q8_1 (defined outside this view) together with the
// row's cache_a_dm entry and the column's cache_b_ds entry.
377+ #if defined(DATA_A_QUANT_LEGACY)
353378 int32_t q_sum = 0;
354379 [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
355380 q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
356381 cache_b_qs[cc * (BK / 4) + idx_k]);
357382 }
358383
359384 sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
// K-quant path: the row is split into two halves of BK / 8 packed words.
// Words [0, BK/8) use scale byte 0 and words [BK/8, BK/4) use scale byte 1.
// Within a scale byte, the low nibble multiplies the A*B dot product (sum_d);
// the high nibble is dotted against B alone (sum_m).
// NOTE(review): the high nibble presumably encodes the sub-block minimum --
// confirm against get_scale() and the mul_q8_1 overload used below.
// NOTE(review): only the first two of the SCALES_PER_32 cached scale bytes are
// read here -- confirm SCALES_PER_32 == 2 for every format taking this path.
385+ #elif defined(DATA_A_QUANT_K)
386+ int32_t sum_d = 0;
387+ int32_t sum_m = 0;
388+
389+ const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
390+ const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
// Broadcast the 4-bit high nibble into all four byte lanes (0x0m0m0m0m) so a
// packed dot product with a B word yields nibble * (sum of B's four lanes).
391+ int32_t scale_m = scale0 >> 4;
392+ scale_m |= scale_m << 8;
393+ scale_m |= scale_m << 16;
394+
// First half of the row, weighted by scale0.
395+ [[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
396+ sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
397+ cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
398+ sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
399+ }
400+
// Rebuild the broadcast word from scale1's high nibble for the second half.
401+ scale_m = scale1 >> 4;
402+ scale_m |= scale_m << 8;
403+ scale_m |= scale_m << 16;
404+
// Second half of the row, weighted by scale1.
405+ [[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
406+ sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
407+ cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
408+ sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
409+ }
410+
// This call passes five arguments (sum_d, sum_m, ...) where the legacy path
// passes four; the matching mul_q8_1 overload is defined outside this view.
411+ sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
412+ #else
413+ #error unsupported
414+ #endif
360415 }
361416 }
362417 }
0 commit comments