@@ -1310,6 +1310,29 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
     }
 }
 
+#ifdef __AVX2__
+// There is no better way of doing this?
+// I guess not, AVX is not very good at horizontal sums.
+// The commented solution for a horizontal sum was suggested by @pubby as being slightly
+// faster than the solution below. As I don't have an AVX2 system handy right now to test,
+// keeping the original.
+// TODO: Please try, and if it does make a difference, uncomment it and remove the implementation below.
+//static inline float horizontal_sum(__m256i a) {
+//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
+//    __m256i sum = _mm256_add_epi32(a, b);
+//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
+//    sum = _mm256_add_epi32(sum, hi);
+//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
+//}
+static inline float horizontal_sum(__m256i a) {
+    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
+    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+#endif
+
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
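The helper added above reduces eight int32 lanes to a scalar by repeatedly folding halves: 8 -> 4 lanes across the 128-bit boundary, 4 -> 2 via the 64-bit unpack, and 2 -> 1 via the lane swap. A minimal standalone check of that reduction (a hypothetical test file, not part of this patch; assumes an AVX2-capable build, e.g. gcc -mavx2 -O2):

    #include <immintrin.h>
    #include <stdio.h>

    static inline float horizontal_sum(__m256i a) {
        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); // 8 -> 4 lanes
        __m128i hi64  = _mm_unpackhi_epi64(sum128, sum128);                                        // 4 -> 2 lanes
        __m128i sum64 = _mm_add_epi32(hi64, sum128);
        __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));                         // 2 -> 1 lane
        return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
    }

    int main(void) {
        int v[8] = { 1, -2, 3, -4, 5, -6, 7, -8 };
        int ref = 0;
        for (int j = 0; j < 8; ++j) ref += v[j];
        float got = horizontal_sum(_mm256_loadu_si256((const __m256i *) v));
        printf("scalar = %d, simd = %.0f\n", ref, got); // both should be -4
        return 0;
    }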
@@ -1399,14 +1422,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__AVX2__)
 
-        // Compute the sum of the quants
-        // There is not better way of doing this???
-        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
-        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
-        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-        __m128i sum64 = _mm_add_epi32(hi64, sum128);
-        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
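For reference, the scalar arithmetic this AVX2 block now performs per 32-quant block (a sketch, not part of the patch; block_q8_0_sum is a hypothetical name, and it assumes the block_q8_0 layout with the new s field):

    #include <stdint.h>

    // Hypothetical scalar reference: the per-block sum folded into y[i].s.
    static float block_q8_0_sum(const int8_t qs[32], float d) {
        int sum = 0;
        for (int j = 0; j < 32; ++j) {   // QK8_0 == 32
            sum += qs[j];                // int8 quants of this block
        }
        return d * sum;                  // stored once, reused by every dot product against this block
    }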
@@ -2411,7 +2428,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         sum8 += x0->d * y0->s + x1->d * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        //const int8x16_t s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2422,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // sub 8
-        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
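Why the "sub 8" step can be dropped: a q4_0 block dequantizes as x_j = d_x * (q_j - 8), so the block dot product factors as

    sum_j d_x*(q_j - 8) * d_y*r_j = d_x*d_y * dot(q, r) - 8 * d_x * (d_y * sum_j r_j)

and d_y * sum_j r_j is exactly the y->s precomputed during quantization, which is what the scalar sum8 accumulation above collects. A toy check of the identity (hypothetical, not in the patch; uses a block of 4 quants instead of 32, same algebra):

    #include <stdio.h>

    int main(void) {
        const int   q[4] = { 0, 5, 10, 15 };  // 4-bit x quants
        const int   r[4] = { -3, 7, -1, 2 };  // int8 y quants
        const float dx = 0.5f, dy = 0.25f;    // block scales

        float direct = 0.0f;
        int dot = 0, rsum = 0;
        for (int j = 0; j < 4; ++j) {
            direct += dx * (q[j] - 8) * dy * r[j];
            dot    += q[j] * r[j];
            rsum   += r[j];
        }
        const float ys = dy * rsum;                        // the precomputed y->s
        const float factored = dx * dy * dot - 8.0f * dx * ys;
        printf("direct = %f, factored = %f\n", direct, factored); // identical
        return 0;
    }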
@@ -2442,27 +2452,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
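On the __ARM_FEATURE_DOTPROD path, each vdotq_s32 call accumulates four independent 4-element int8 dot products into the four int32 lanes, which is why the int16 widening of the fallback path is unnecessary there. Scalar semantics of the instruction for reference (a sketch, not part of the patch; vdotq_s32_ref is a hypothetical name):

    #include <stdint.h>

    // Reference model of vdotq_s32(acc, a, b): lane k of the int32 result
    // accumulates the dot product of int8 elements 4k..4k+3 of a and b.
    static void vdotq_s32_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
        for (int lane = 0; lane < 4; ++lane) {
            for (int j = 0; j < 4; ++j) {
                acc[lane] += (int32_t) a[4*lane + j] * (int32_t) b[4*lane + j];
            }
        }
    }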
@@ -2644,19 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
-        //const int16x8_t s0i = vaddq_s16(
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
-        //const int16x8_t s1i = vaddq_s16(
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-        //        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
-        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
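The q4_1 case factors the same way, with an additive offset instead of the -8: a q4_1 block dequantizes as x_j = d_x*q_j + m, so

    sum_j (d_x*q_j + m) * d_y*r_j = d_x*d_y * dot(q, r) + m * (d_y * sum_j r_j) = d_x*d_y * dot(q, r) + m * y->s,

which is exactly the scalar summs term referenced in the deleted comment.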
@@ -2702,11 +2689,9 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2728,17 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
-
-        // We no longer need this. We have computed the sum of the y quants during quantization,
-        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
-        //// Compute sum of y values
-        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
-
-        //// Accumulate d1*m0*y
-        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
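The acc reduction mentioned in that trailing comment is a float horizontal sum rather than the integer one added earlier in this commit. One common AVX shape for it (a sketch of the usual pattern, assuming SSE3 is available; the actual code in the file may differ):

    #include <immintrin.h>

    static inline float hsum_float_8(__m256 v) {
        __m128 lo = _mm256_castps256_ps128(v);
        __m128 hi = _mm256_extractf128_ps(v, 1);
        lo = _mm_add_ps(lo, hi);            // 8 -> 4 lanes
        __m128 sh = _mm_movehdup_ps(lo);    // duplicate odd lanes
        lo = _mm_add_ps(lo, sh);            // 4 -> 2 lanes
        sh = _mm_movehl_ps(sh, lo);         // move high pair down
        lo = _mm_add_ss(lo, sh);            // 2 -> 1 lane
        return _mm_cvtss_f32(lo);
    }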