@@ -346,8 +346,8 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
346346#endif // MMQ_SHMEM
347347#endif
348348
349- #if defined(DATA_A_Q4_K)
350- // 4-byte loads for Q4_K blocks (144 bytes)
349+ #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
350+ // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
351351ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
352352 return ACC_TYPE(dsb.x * dma.x * float (q_sum) - dma.y * dsb.y);
353353}
@@ -361,10 +361,19 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
361361 const uint qs_shift = ((iqs_k % 16 ) / 8 ) * 4 ;
362362
363363 // Repack 2x4 quants into one int
364+ #if defined(DATA_A_Q4_K)
364365 const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
365366 const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1 ] >> qs_shift) & 0x0F0F0F0F;
366367
367368 buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4 );
369+ #else // defined(DATA_A_Q5_K)
370+ const uint qh_idx = iqs * QUANT_R_MMQ;
371+ const uint qh_shift = iqs_k / 8 ;
372+
373+ buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
374+ (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4 ));
375+ #endif
376+
368377
369378 if (iqs == 0 ) {
370379 // Scale index
@@ -384,7 +393,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
384393void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
385394 cache_a[reg_ib].dm = buf_a[buf_ib].dm;
386395
387- [[unroll]] for (uint iqs = 0 ; iqs < 4 ; iqs++ ) {
396+ [[unroll]] for (uint iqs = 0 ; iqs < 8 / QUANT_R_MMQ ; iqs++ ) {
388397 cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
389398 }
390399}
@@ -393,7 +402,11 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
393402 int32_t q_sum = 0 ;
394403
395404 [[unroll]] for (uint iqs = 0 ; iqs < 8 ; iqs++ ) {
405+ #if defined(DATA_A_Q4_K)
396406 const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2 ] >> ((iqs % 2 ) * 4 )) & 0x0F0F0F0F);
407+ #else // defined(DATA_A_Q5_K)
408+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
409+ #endif
397410
398411 q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
399412 }
0 commit comments