@@ -18,6 +18,21 @@ typedef struct {
1818 uint8_t qs[QK4_1 / 2 ]; // nibbles / quants
1919} block_q4_1;
2020
21+ #define QK5_0 32
22+ typedef struct {
23+ half d; // delta
24+ uint8_t qh[4 ]; // 5-th bit of quants
25+ uint8_t qs[QK5_0 / 2 ]; // nibbles / quants
26+ } block_q5_0;
27+
28+ #define QK5_1 32
29+ typedef struct {
30+ half d; // delta
31+ half m; // min
32+ uint8_t qh[4 ]; // 5-th bit of quants
33+ uint8_t qs[QK5_1 / 2 ]; // nibbles / quants
34+ } block_q5_1;
35+
2136#define QK8_0 32
2237typedef struct {
2338 half d; // delta
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
399414// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
400415inline float block_q_n_dot_y (device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
401416 float d = qb_curr->d ;
417+
402418 float2 acc = 0 .f ;
419+
403420 device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2 );
421+
404422 for (int i = 0 ; i < 8 ; i+=2 ) {
405423 acc[0 ] += yl[i + 0 ] * (qs[i / 2 ] & 0x000F )
406424 + yl[i + 1 ] * (qs[i / 2 ] & 0x0F00 );
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
417435inline float block_q_n_dot_y (device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
418436 float d = qb_curr->d ;
419437 float m = qb_curr->m ;
420- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/ 2 );
438+
421439 float2 acc = 0 .f ;
440+
441+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2 );
442+
422443 for (int i = 0 ; i < 8 ; i+=2 ) {
423444 acc[0 ] += yl[i + 0 ] * (qs[i / 2 ] & 0x000F )
424445 + yl[i + 1 ] * (qs[i / 2 ] & 0x0F00 );
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
428449 return d * (acc[0 ] + acc[1 ]) + sumy * m;
429450}
430451
452+ // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
453+ // il indicates where the q5 quants begin (0 or QK5_0/4)
454+ // we assume that the yl's have been multiplied with the appropriate scale factor
455+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
456+ inline float block_q_n_dot_y (device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
457+ float d = qb_curr->d ;
458+
459+ float2 acc = 0 .f ;
460+
461+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2 );
462+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh );
463+
464+ for (int i = 0 ; i < 8 ; i+=2 ) {
465+ acc[0 ] += yl[i + 0 ] * ((qs[i / 2 ] & 0x000F ) | ((qh >> (i+0 +il ) << 4 ) & 0x00010 ))
466+ + yl[i + 1 ] * ((qs[i / 2 ] & 0x0F00 ) | ((qh >> (i+1 +il ) << 12 ) & 0x01000 ));
467+ acc[1 ] += yl[i + 8 ] * ((qs[i / 2 ] & 0x00F0 ) | ((qh >> (i+0 +il+QK5_0/2 ) << 8 ) & 0x00100 ))
468+ + yl[i + 9 ] * ((qs[i / 2 ] & 0xF000 ) | ((qh >> (i+1 +il+QK5_0/2 ) << 16 ) & 0x10000 ));
469+ }
470+ return d * (sumy * -16 .f + acc[0 ] + acc[1 ]);
471+ }
472+
473+ // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
474+ // il indicates where the q5 quants begin (0 or QK5_1/4)
475+ // we assume that the yl's have been multiplied with the appropriate scale factor
476+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
477+ inline float block_q_n_dot_y (device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
478+ float d = qb_curr->d ;
479+ float m = qb_curr->m ;
480+
481+ float2 acc = 0 .f ;
482+
483+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2 );
484+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh );
485+
486+ for (int i = 0 ; i < 8 ; i+=2 ) {
487+ acc[0 ] += yl[i + 0 ] * ((qs[i / 2 ] & 0x000F ) | ((qh >> (i+0 +il ) << 4 ) & 0x00010 ))
488+ + yl[i + 1 ] * ((qs[i / 2 ] & 0x0F00 ) | ((qh >> (i+1 +il ) << 12 ) & 0x01000 ));
489+ acc[1 ] += yl[i + 8 ] * ((qs[i / 2 ] & 0x00F0 ) | ((qh >> (i+0 +il+QK5_0/2 ) << 8 ) & 0x00100 ))
490+ + yl[i + 9 ] * ((qs[i / 2 ] & 0xF000 ) | ((qh >> (i+1 +il+QK5_0/2 ) << 16 ) & 0x10000 ));
491+ }
492+ return d * (acc[0 ] + acc[1 ]) + sumy * m;
493+ }
494+
431495// putting them in the kernel cause a significant performance penalty
432496#define N_DST 4 // each SIMD group works on 4 rows
433497#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
525589 mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
526590}
527591
592+ kernel void kernel_mul_mv_q5_0_f32 (
593+ device const void * src0,
594+ device const float * src1,
595+ device float * dst,
596+ constant int64_t & ne00,
597+ constant int64_t & ne01[[buffer(4 )]],
598+ constant int64_t & ne02[[buffer(5 )]],
599+ constant int64_t & ne10[[buffer(9 )]],
600+ constant int64_t & ne12[[buffer(11 )]],
601+ constant int64_t & ne0[[buffer(15 )]],
602+ constant int64_t & ne1[[buffer(16 )]],
603+ constant uint & gqa[[buffer(17 )]],
604+ uint3 tgpig[[threadgroup_position_in_grid]],
605+ uint tiisg[[thread_index_in_simdgroup]],
606+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
607+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
608+ }
609+
610+ kernel void kernel_mul_mv_q5_1_f32 (
611+ device const void * src0,
612+ device const float * src1,
613+ device float * dst,
614+ constant int64_t & ne00,
615+ constant int64_t & ne01[[buffer(4 )]],
616+ constant int64_t & ne02[[buffer(5 )]],
617+ constant int64_t & ne10[[buffer(9 )]],
618+ constant int64_t & ne12[[buffer(11 )]],
619+ constant int64_t & ne0[[buffer(15 )]],
620+ constant int64_t & ne1[[buffer(16 )]],
621+ constant uint & gqa[[buffer(17 )]],
622+ uint3 tgpig[[threadgroup_position_in_grid]],
623+ uint tiisg[[thread_index_in_simdgroup]],
624+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
625+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
626+ }
627+
628+
528629#define NB_Q8_0 8
529630
530631kernel void kernel_mul_mv_q8_0_f32 (
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
21492250 }
21502251}
21512252
2253+ template <typename type4x4>
2254+ void dequantize_q5_0 (device const block_q5_0 *xb, short il, thread type4x4 & reg) {
2255+ device const uint16_t * qs = ((device const uint16_t *)xb + 3 );
2256+ const float d = xb->d ;
2257+ const float md = -16 .h * xb->d ;
2258+ const ushort mask = il ? 0x00F0 : 0x000F ;
2259+
2260+ const uint32_t qh = *((device const uint32_t *)xb->qh );
2261+
2262+ const int x_mv = il ? 4 : 0 ;
2263+
2264+ const int gh_mv = il ? 12 : 0 ;
2265+ const int gh_bk = il ? 0 : 4 ;
2266+
2267+ for (int i = 0 ; i < 8 ; i++) {
2268+ // extract the 5-th bits for x0 and x1
2269+ const uint8_t xh_0 = ((qh >> (gh_mv + 2 *i )) << gh_bk) & 0x10 ;
2270+ const uint8_t xh_1 = ((qh >> (gh_mv + 2 *i+1 )) << gh_bk) & 0x10 ;
2271+
2272+ // combine the 4-bits from qs with the 5th bit
2273+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2274+ const int32_t x1 = ((((qs[i] >> 8 ) & mask) >> x_mv) | xh_1);
2275+
2276+ reg[i/2 ][2 *(i%2 )+0 ] = d * x0 + md;
2277+ reg[i/2 ][2 *(i%2 )+1 ] = d * x1 + md;
2278+ }
2279+ }
2280+
2281+ template <typename type4x4>
2282+ void dequantize_q5_1 (device const block_q5_1 *xb, short il, thread type4x4 & reg) {
2283+ device const uint16_t * qs = ((device const uint16_t *)xb + 4 );
2284+ const float d = xb->d ;
2285+ const float m = xb->m ;
2286+ const ushort mask = il ? 0x00F0 : 0x000F ;
2287+
2288+ const uint32_t qh = *((device const uint32_t *)xb->qh );
2289+
2290+ const int x_mv = il ? 4 : 0 ;
2291+
2292+ const int gh_mv = il ? 12 : 0 ;
2293+ const int gh_bk = il ? 0 : 4 ;
2294+
2295+ for (int i = 0 ; i < 8 ; i++) {
2296+ // extract the 5-th bits for x0 and x1
2297+ const uint8_t xh_0 = ((qh >> (gh_mv + 2 *i )) << gh_bk) & 0x10 ;
2298+ const uint8_t xh_1 = ((qh >> (gh_mv + 2 *i+1 )) << gh_bk) & 0x10 ;
2299+
2300+ // combine the 4-bits from qs with the 5th bit
2301+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2302+ const int32_t x1 = ((((qs[i] >> 8 ) & mask) >> x_mv) | xh_1);
2303+
2304+ reg[i/2 ][2 *(i%2 )+0 ] = d * x0 + m;
2305+ reg[i/2 ][2 *(i%2 )+1 ] = d * x1 + m;
2306+ }
2307+ }
2308+
21522309template <typename type4x4>
21532310void dequantize_q8_0 (device const block_q8_0 *xb, short il, thread type4x4 & reg) {
21542311 device const int8_t * qs = ((device const int8_t *)xb->qs );
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
24902647template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
24912648template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_t kernel_get_rows<block_q4_0, 2 , dequantize_q4_0>;
24922649template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_t kernel_get_rows<block_q4_1, 2 , dequantize_q4_1>;
2650+ template [[host_name(" kernel_get_rows_q5_0" )]] kernel get_rows_t kernel_get_rows<block_q5_0, 2 , dequantize_q5_0>;
2651+ template [[host_name(" kernel_get_rows_q5_1" )]] kernel get_rows_t kernel_get_rows<block_q5_1, 2 , dequantize_q5_1>;
24932652template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_t kernel_get_rows<block_q8_0, 2 , dequantize_q8_0>;
24942653template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
24952654template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
25182677template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
25192678template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2 , dequantize_q4_0>;
25202679template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2 , dequantize_q4_1>;
2680+ template [[host_name(" kernel_mul_mm_q5_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2 , dequantize_q5_0>;
2681+ template [[host_name(" kernel_mul_mm_q5_1_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2 , dequantize_q5_1>;
25212682template [[host_name(" kernel_mul_mm_q8_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2 , dequantize_q8_0>;
25222683template [[host_name(" kernel_mul_mm_q2_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
25232684template [[host_name(" kernel_mul_mm_q3_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
0 commit comments