@@ -473,16 +473,23 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
473473//
474474
475475#if __AVX__ || __AVX2__ || __AVX512F__
476- // multiply int8_t, add results pairwise twice
477- static inline __m128i mul_sum_i8_pairs (const __m128i x , const __m128i y ) {
478- // Get absolute values of x vectors
479- const __m128i ax = _mm_sign_epi8 (x , x );
480- // Sign the values of the y vectors
481- const __m128i sy = _mm_sign_epi8 (y , x );
482- // Perform multiplication and create 16-bit values
483- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
484- const __m128i ones = _mm_set1_epi16 (1 );
485- return _mm_madd_epi16 (ones , dot );
476+ // Unpack 16 4-bit fields, packed two per byte in the 8 bytes at rsi, into 16 bytes
477+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval; output byte 2k is the low nibble and byte 2k+1 the high nibble of input byte k
478+ static inline __m128i bytes_from_nibbles_16 (const uint8_t * rsi )
479+ {
480+ // Load 8 bytes from memory into the low half of the register (the upper half is zeroed)
481+ __m128i tmp = _mm_loadl_epi64 ( ( const __m128i * )rsi );
482+
483+ // Zero-extend each of the 8 loaded bytes into its own uint16_t lane
484+ __m128i bytes = _mm_cvtepu8_epi16 ( tmp );
485+
486+ // Split every 16-bit lane so that the low byte holds the low nibble and the high byte holds the high nibble
487+ const __m128i lowMask = _mm_set1_epi8 ( 0xF );
488+ __m128i high = _mm_andnot_si128 ( lowMask , bytes ); // (~lowMask) & bytes: keeps bits 4..7 of each lane
489+ __m128i low = _mm_and_si128 ( lowMask , bytes ); // keeps bits 0..3 of each lane
490+ high = _mm_slli_epi16 ( high , 4 ); // moves bits 4..7 up to bits 8..11, i.e. into the lane's high byte
491+ bytes = _mm_or_si128 ( low , high );
492+ return bytes ;
486493}
487494
488495// horizontally add 8 floats
@@ -529,10 +536,19 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
529536// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
530537static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi )
531538{
532- const __m128i tmp = _mm_loadu_si128 ((const __m128i * )rsi );
533- const __m256i bytes = _mm256_set_m128i (_mm_srli_epi16 (tmp , 4 ), tmp );
539+ // Load 16 bytes (32 packed 4-bit fields) from memory; the load is unaligned, so rsi needs no particular alignment
540+ __m128i tmp = _mm_loadu_si128 ( ( const __m128i * )rsi );
541+
542+ // Zero-extend each of the 16 loaded bytes into its own uint16_t lane
543+ __m256i bytes = _mm256_cvtepu8_epi16 ( tmp );
544+
545+ // Split every 16-bit lane so that the low byte holds the low nibble and the high byte holds the high nibble
534546 const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
535- return _mm256_and_si256 (lowMask , bytes );
547+ __m256i high = _mm256_andnot_si256 ( lowMask , bytes ); // (~lowMask) & bytes: keeps bits 4..7 of each lane
548+ __m256i low = _mm256_and_si256 ( lowMask , bytes ); // keeps bits 0..3 of each lane
549+ high = _mm256_slli_epi16 ( high , 4 ); // moves bits 4..7 up to bits 8..11, i.e. into the lane's high byte
550+ bytes = _mm256_or_si256 ( low , high ); // output byte 2k = low nibble, byte 2k+1 = high nibble of input byte k
551+ return bytes ;
536552}
537553
538554// add int16_t pairwise and return as float vector
@@ -2109,23 +2125,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
21092125 // Compute combined scale for the block
21102126 const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
21112127
2112- const __m128i lowMask = _mm_set1_epi8 (0xF );
2113- const __m128i off = _mm_set1_epi8 (8 );
2128+ __m128i i32 [2 ];
2129+ for (int j = 0 ; j < 2 ; ++ j ) {
2130+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2131+ __m128i bx = bytes_from_nibbles_16 (x [i ].qs + 8 * j );
2132+ __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 * j ));
2133+
2134+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2135+ const __m128i off = _mm_set1_epi8 ( 8 );
2136+ bx = _mm_sub_epi8 ( bx , off );
21142137
2115- const __m128i tmp = _mm_loadu_si128 ((const __m128i * )x [i ].qs );
2138+ // Get absolute values of x vectors
2139+ const __m128i ax = _mm_sign_epi8 (bx , bx );
21162140
2117- __m128i bx = _mm_and_si128 (lowMask , tmp );
2118- __m128i by = _mm_loadu_si128 ((const __m128i * )y [i ].qs );
2119- bx = _mm_sub_epi8 (bx , off );
2120- const __m128i i32_0 = mul_sum_i8_pairs (bx , by );
2141+ // Sign the values of the y vectors
2142+ const __m128i sy = _mm_sign_epi8 (by , bx );
21212143
2122- bx = _mm_and_si128 (lowMask , _mm_srli_epi64 (tmp , 4 ));
2123- by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 ));
2124- bx = _mm_sub_epi8 (bx , off );
2125- const __m128i i32_1 = mul_sum_i8_pairs (bx , by );
2144+ // Perform multiplication and create 16-bit values
2145+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2146+
2147+ const __m128i ones = _mm_set1_epi16 (1 );
2148+ i32 [j ] = _mm_madd_epi16 (ones , dot );
2149+ }
21262150
21272151 // Convert int32_t to float
2128- __m256 p = _mm256_cvtepi32_ps (_mm256_set_m128i (i32_0 , i32_1 ));
2152+ __m256 p = _mm256_cvtepi32_ps ( _mm256_set_m128i ( i32 [ 0 ], i32 [ 1 ] ));
21292153 // Apply the scale, and accumulate
21302154 acc = _mm256_add_ps (_mm256_mul_ps ( d , p ), acc );
21312155 }
@@ -2472,8 +2496,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
24722496 int sumi = 0 ;
24732497
24742498 for (int j = 0 ; j < qk /2 ; ++ j ) {
2475- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2476- const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2499+ const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2500+ const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
24772501
24782502 const int32_t x0 = ((x [i ].qs [j ] & 0x0F ) | xh_0 ) - 16 ;
24792503 const int32_t x1 = ((x [i ].qs [j ] >> 4 ) | xh_1 ) - 16 ;
@@ -2698,8 +2722,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
26982722 int sumi = 0 ;
26992723
27002724 for (int j = 0 ; j < qk /2 ; ++ j ) {
2701- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2702- const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2725+ const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2726+ const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
27032727
27042728 const int32_t x0 = (x [i ].qs [j ] & 0xF ) | xh_0 ;
27052729 const int32_t x1 = (x [i ].qs [j ] >> 4 ) | xh_1 ;
0 commit comments