@@ -440,6 +440,65 @@ void block_b_to_registers(const uint ib) {
440440}
441441#endif
442442
443+ #if defined(DATA_A_Q6_K)
444+ // 2-byte loads for Q6_K blocks (210 bytes)
445+ #ifdef MMQ_SHMEM
446+ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
447+ const uint ib_k = ib / 8 ;
448+ const uint iqs_k = (ib % 8 ) * 8 + iqs;
449+
450+ const uint ql_idx = (iqs_k / 32 ) * 16 + iqs_k % 16 ;
451+ const uint ql_shift = ((iqs_k % 32 ) / 16 ) * 4 ;
452+
453+ const uint qh_idx = (iqs_k / 32 ) * 8 + iqs;
454+ const uint qh_shift = ((iqs_k % 32 ) / 8 ) * 2 ;
455+
456+ const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
457+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4 ))) - int8_t(32 );
458+ const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1 ] >> ql_shift) & uint16_t(0x0F0F))) |
459+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1 ] >> qh_shift) & uint16_t(0x0303)) << 4 ))) - int8_t(32 );
460+ buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
461+
462+ if (iqs == 0 ) {
463+ const uint is = iqs_k / 4 ;
464+ const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2 ]);
465+
466+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
467+ }
468+ }
469+
470+ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
471+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
472+
473+ [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
474+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
475+ }
476+ }
477+
478+ ACC_TYPE mmq_dot_product(const uint ib_a) {
479+ float result = 0.0 ;
480+ int32_t q_sum = 0 ;
481+
482+ [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
483+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
484+
485+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
486+ }
487+ result += float (cache_a[ib_a].d_scales[0 ]) * float (q_sum);
488+ q_sum = 0 ;
489+
490+ [[unroll]] for (uint iqs = 4 ; iqs < 8 ; iqs++ ) {
491+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
492+
493+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
494+ }
495+ result += float (cache_a[ib_a].d_scales[1 ]) * float (q_sum);
496+
497+ return ACC_TYPE(cache_b.ds.x * result);
498+ }
499+ #endif // MMQ_SHMEM
500+ #endif
501+
443502#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
444503FLOAT_TYPE get_d(uint ib) {
445504 return FLOAT_TYPE(data_a[ib].d);
0 commit comments