@@ -487,6 +487,15 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487487 return bytes ;
488488}
489489
490+ // horizontally add 8 floats: returns the scalar sum of all eight lanes of x
491+ static inline float hsum_float_8 (const __m256 x ) {
492+ __m128 res = _mm256_extractf128_ps (x , 1 ); // upper four floats of x
493+ res = _mm_add_ps (res , _mm256_castps256_ps128 (x )); // add the lower four floats -> 4 partial sums
494+ res = _mm_add_ps (res , _mm_movehl_ps (res , res )); // fold the high pair onto the low pair -> partial sums in lanes 0 and 1
495+ res = _mm_add_ss (res , _mm_movehdup_ps (res )); // add lane 1 into lane 0 -> total in lane 0
496+ return _mm_cvtss_f32 (res ); // extract lane 0 as a scalar float
497+ }
498+
490499#if __AVX2__ || __AVX512F__
491500// Unpack 32 4-bit fields into 32 bytes
492501// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -507,6 +516,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507516 return bytes ;
508517}
509518
519+ // add adjacent int16_t pairs of x (16 lanes -> 8 int32 sums) and return the sums as a float vector
520+ static inline __m256 sum_i16_pairs_float (const __m256i x ) {
521+ const __m256i ones = _mm256_set1_epi16 (1 ); // multiplier of 1 in every 16-bit lane
522+ const __m256i summed_pairs = _mm256_madd_epi16 (ones , x ); // multiply by 1, then add adjacent products into 32-bit lanes
523+ return _mm256_cvtepi32_ps (summed_pairs ); // convert the 8 int32 sums to floats
524+ }
525+
526+ // multiply int8_t, add results pairwise twice and return as float vector (each output float is the sum of 4 adjacent x*y products)
527+ static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
528+ // Get absolute values of x vectors (maddubs below treats its first operand as unsigned bytes)
529+ const __m256i ax = _mm256_sign_epi8 (x , x );
530+ // Sign the values of the y vectors: negate y where x is negative and zero it where x is zero, intended so that ax*sy matches x*y
531+ const __m256i sy = _mm256_sign_epi8 (y , x );
532+ // Perform multiplication and create 16-bit values (adjacent byte products are added in pairs)
533+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
534+ return sum_i16_pairs_float (dot ); // second pairwise add into 32-bit lanes, then conversion to float
535+ }
536+
510537static inline __m128i packNibbles ( __m256i bytes )
511538{
512539 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -2366,8 +2393,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23662393 const block_q4_0 * restrict x = vx ;
23672394 const block_q8_0 * restrict y = vy ;
23682395
2369- float sumf = 0.0 ;
2370-
23712396#if defined(__ARM_NEON )
23722397 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
23732398 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2436,7 +2461,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24362461#endif
24372462 }
24382463
2439- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2464+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
24402465#elif defined(__AVX2__ )
24412466 // Initialize accumulator with zeros
24422467 __m256 acc = _mm256_setzero_ps ();
@@ -2454,32 +2479,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24542479
24552480 __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
24562481
2457- // Get absolute values of x vectors
2458- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2459-
2460- // Sign the values of the y vectors
2461- const __m256i sy = _mm256_sign_epi8 (by , bx );
2462-
2463- // Perform multiplication and create 16-bit values
2464- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2465-
2466- const __m256i ones = _mm256_set1_epi16 (1 );
2467- __m256i xy_q = _mm256_madd_epi16 (ones , dot );
2468-
2469- /* Convert to vectore of 8 int32_t to 8 floats */
2470- __m256 q = _mm256_cvtepi32_ps ( xy_q );
2482+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
24712483
24722484 /* Multiply q with scale and accumulate */
24732485 acc = _mm256_fmadd_ps ( d , q , acc );
24742486 }
24752487
2476- // Return horizontal sum of the acc vector
2477- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2478- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2479- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2480- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2481-
2482- sumf = _mm_cvtss_f32 ( res );
2488+ * s = hsum_float_8 (acc );
24832489#elif defined(__AVX__ )
24842490 // Initialize accumulator with zeros
24852491 __m256 acc = _mm256_setzero_ps ();
@@ -2518,15 +2524,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25182524 acc = _mm256_add_ps (_mm256_mul_ps ( d , p ), acc );
25192525 }
25202526
2521- // Return horizontal sum of the acc vector
2522- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2523- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2524- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2525- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2526-
2527- sumf = _mm_cvtss_f32 ( res );
2527+ * s = hsum_float_8 (acc );
25282528#else
25292529 // scalar
2530+ float sumf = 0.0 ;
25302531 for (int i = 0 ; i < nb ; i ++ ) {
25312532 const float d0 = x [i ].d ;
25322533 const float d1 = y [i ].d ;
@@ -2548,9 +2549,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25482549 }
25492550 sumf += d0 * d1 * sumi ;
25502551 }
2551- #endif
2552-
25532552 * s = sumf ;
2553+ #endif
25542554}
25552555
25562556static void ggml_vec_dot_q4_1_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2562,8 +2562,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25622562 const block_q4_1 * restrict x = vx ;
25632563 const block_q8_0 * restrict y = vy ;
25642564
2565- float sumf = 0.0 ;
2566-
25672565 // TODO: add AVX / WASM SIMD / etc
25682566#if defined(__ARM_NEON )
25692567 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
@@ -2637,7 +2635,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26372635#endif
26382636 }
26392637
2640- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2638+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
26412639#elif defined(__AVX2__ )
26422640 // Initialize accumulator with zeros
26432641 __m256 acc = _mm256_setzero_ps ();
@@ -2660,42 +2658,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26602658 const __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
26612659 const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
26622660
2663- // Get absolute values of x vectors
2664- const __m256i ax = _mm256_sign_epi8 ( bx , bx );
2665-
2666- // Sign the values of the y vectors
2667- const __m256i sy = _mm256_sign_epi8 ( by , bx );
2668-
2669- // Perform multiplication and create 16-bit values
2670- const __m256i dot = _mm256_maddubs_epi16 ( ax , sy );
2671- const __m256i ones = _mm256_set1_epi16 ( 1 );
2672- const __m256i xy_q = _mm256_madd_epi16 ( ones , dot );
2673-
2674- // Convert to vector of 8 int32_t to 8 floats
2675- const __m256 xy = _mm256_cvtepi32_ps ( xy_q );
2661+ const __m256 xy = mul_sum_i8_pairs_float (bx , by );
26762662
26772663 // Accumulate d0*d1*x*y
26782664 acc = _mm256_fmadd_ps ( d0d1 , xy , acc );
26792665
26802666 // Compute sum of y values
26812667 const __m256i y16_l = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
26822668 const __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
2683- const __m256i ysumi = _mm256_madd_epi16 ( _mm256_add_epi16 (y16_l , y16_h ), ones );
2684- const __m256 ysum = _mm256_cvtepi32_ps ( ysumi );
2669+ const __m256 ysum = sum_i16_pairs_float (_mm256_add_epi16 (y16_l , y16_h ));
26852670
26862671 // Accumulate d1*m0*y
26872672 acc = _mm256_fmadd_ps ( d1m0 , ysum , acc );
26882673 }
26892674
2690- // Return horizontal sum of the acc vector
2691- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2692- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2693- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2694- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2695-
2696- sumf = _mm_cvtss_f32 ( res );
2675+ * s = hsum_float_8 (acc );
26972676#else
26982677 // scalar
2678+ float sumf = 0.0 ;
26992679 for (int i = 0 ; i < nb ; i ++ ) {
27002680 const float d0 = x [i ].d ;
27012681 const float m0 = x [i ].m ;
@@ -2717,9 +2697,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
27172697 sumf += f0 * f2 + f1 * f3 ;
27182698 }
27192699 }
2720- #endif
2721-
27222700 * s = sumf ;
2701+ #endif
27232702}
27242703
27252704static void ggml_vec_dot_q4_2_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2732,8 +2711,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
27322711 const block_q4_2 * restrict x = vx ;
27332712 const block_q8_0 * restrict y = vy ;
27342713
2735- float sumf = 0.0 ;
2736-
27372714#if defined(__ARM_NEON )
27382715 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
27392716 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2811,7 +2788,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28112788#endif
28122789 }
28132790
2814- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2791+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
28152792#elif defined(__AVX2__ )
28162793 // Initialize accumulator with zeros
28172794 __m256 acc = _mm256_setzero_ps ();
@@ -2833,32 +2810,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28332810
28342811 __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
28352812
2836- // Get absolute values of x vectors
2837- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2838- // Sign the values of the y vectors
2839- const __m256i sy = _mm256_sign_epi8 (by , bx );
2840- // Perform multiplication and create 16-bit values
2841- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2842-
2843- const __m256i ones = _mm256_set1_epi16 (1 );
2844- __m256i xy_q = _mm256_madd_epi16 (ones , dot );
2845-
2846- /* Convert to vectore of 8 int32_t to 8 floats */
2847- __m256 q = _mm256_cvtepi32_ps (xy_q );
2813+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
28482814
28492815 /* Multiply q with scale and accumulate */
28502816 acc = _mm256_fmadd_ps (d , q , acc );
28512817 }
28522818
2853- // Return horizontal sum of the acc vector
2854- __m128 res = _mm256_extractf128_ps (acc , 1 );
2855- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2856- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2857- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2858-
2859- sumf = _mm_cvtss_f32 (res );
2819+ * s = hsum_float_8 (acc );
28602820#else
28612821 // scalar
2822+ float sumf = 0.0 ;
28622823 for (int i = 0 ; i < nb ; i ++ ) {
28632824 const uint8_t * restrict x0 = x [2 * i + 0 ].qs ;
28642825 const uint8_t * restrict x1 = x [2 * i + 1 ].qs ;
@@ -2893,9 +2854,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28932854 sumf += (d0 * y [i ].d ) * sumi_0 ;
28942855 sumf += (d1 * y [i ].d ) * sumi_1 ;
28952856 }
2896- #endif
2897-
28982857 * s = sumf ;
2858+ #endif
28992859}
29002860
29012861static void ggml_vec_dot_q4_3_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2908,8 +2868,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29082868 const block_q4_3 * restrict x = vx ;
29092869 const block_q8_0 * restrict y = vy ;
29102870
2911- float sumf = 0.0 ;
2912-
29132871#if defined(__ARM_NEON )
29142872 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
29152873 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2995,9 +2953,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29952953#endif
29962954 }
29972955
2998- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2956+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2957+ #elif defined(__AVX2__ )
2958+ // Initialize accumulator with zeros
2959+ __m256 acc = _mm256_setzero_ps ();
2960+
2961+ // Main loop
2962+ for (int i = 0 ; i < nb ; i ++ ) {
2963+ const __m128 d0 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 0 ].d ));
2964+ const __m128 d1 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 1 ].d ));
2965+ const __m256 dx = _mm256_set_m128 (d1 , d0 );
2966+
2967+ const __m128 m0 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 0 ].m ));
2968+ const __m128 m1 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 1 ].m ));
2969+ const __m256 mx = _mm256_set_m128 (m1 , m0 );
2970+
2971+ const __m128i bx0 = bytes_from_nibbles_16 (x [2 * i + 0 ].qs );
2972+ const __m128i bx1 = bytes_from_nibbles_16 (x [2 * i + 1 ].qs );
2973+ const __m256i bx = _mm256_set_m128i (bx1 , bx0 );
2974+
2975+ const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
2976+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2977+
2978+ const __m256i syi = _mm256_maddubs_epi16 (_mm256_set1_epi8 (1 ), by );
2979+ const __m256 syf = sum_i16_pairs_float (syi );
2980+
2981+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
2982+
2983+ const __m256 sxy = _mm256_fmadd_ps (q , dx , _mm256_mul_ps (mx , syf ));
2984+ acc = _mm256_fmadd_ps (sxy , dy , acc );
2985+ }
2986+
2987+ * s = hsum_float_8 (acc );
29992988#else
30002989 // scalar
2990+ float sumf = 0.0 ;
30012991 for (int i = 0 ; i < nb ; i ++ ) {
30022992 const uint8_t * restrict x0 = x [2 * i + 0 ].qs ;
30032993 const uint8_t * restrict x1 = x [2 * i + 1 ].qs ;
@@ -3040,9 +3030,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
30403030 sumf += (d0 * sxy_0 + m0 * sy_0 )* y [i ].d ;
30413031 sumf += (d1 * sxy_1 + m1 * sy_1 )* y [i ].d ;
30423032 }
3043- #endif
3044-
30453033 * s = sumf ;
3034+ #endif
30463035}
30473036
30483037
0 commit comments