@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2 * sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 
 // reference implementation for deterministic creation of model files
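The new field caches, once per q8_0 block, the scale times the integer sum of the quants, growing the block from 36 to 40 bytes as the updated static_assert reflects. The payoff is in the dot products further down: with s_y precomputed, the constant per-block offsets of q4_0 and q4_1 need no vector work at all. In the notation of the structs in this file (a q8_0 element dequantizes to d_y q^y_l, a q4_0 element to d_x (q^x_l - 8)), the identity being exploited is:

    s_y = d_y \sum_l q^y_l

    \sum_l d_x (q^x_l - 8) \cdot d_y q^y_l
      = d_x d_y \sum_l q^x_l q^y_l \;-\; 8\, d_x s_y

so the "sub 8" step can be dropped from the inner loop and repaired with a single `- 8 * sum8` at the end, as the q4_0 hunks below do.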
@@ -1299,10 +1300,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         y[i].d = d;
 
+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d*sum;
     }
 }
 
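A quick property to check after this change: for every block, `s` must equal `d` times the integer sum of the stored quants. A minimal harness (hypothetical, not part of this commit; `check_q8_0_sum` is my name, and it assumes the declarations above):

    #include <assert.h>
    #include <math.h>

    // k must be a positive multiple of QK8_0, as in the real quantize functions
    static void check_q8_0_sum(const float * x, int k) {
        const int nb = k / QK8_0;
        block_q8_0 y[nb];
        quantize_row_q8_0_reference(x, y, k);
        for (int i = 0; i < nb; ++i) {
            int sum = 0;
            for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
            // s was stored as d * sum during quantization
            assert(fabsf(y[i].s - y[i].d * sum) <= 1e-5f);
        }
    }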
@@ -1332,6 +1336,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         y[i].d = d;
 
+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1346,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1398,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+
+        // Compute the sum of the quants
+        // Is there no better way of doing this???
+        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
+        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
+        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+        __m128i sum64 = _mm_add_epi32(hi64, sum128);
+        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
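On the "no better way?" question in the comment: AVX2 has no single cross-lane horizontal add for eight int32 lanes, so this extract/unpack/shuffle ladder is essentially the canonical reduction. It can at least be factored out so the intent reads at a glance; a sketch (my helper name, not part of this commit):

    #include <immintrin.h>

    // horizontal sum of the 8 int32 lanes of a __m256i
    static inline int hsum_i32_8(const __m256i a) {
        const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a),
                                             _mm256_extracti128_si256(a, 1));
        const __m128i hi64  = _mm_unpackhi_epi64(sum128, sum128);
        const __m128i sum64 = _mm_add_epi32(hi64, sum128);
        const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
        return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
    }

with the assignment above becoming `y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));`.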
@@ -1430,6 +1450,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+        for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }
 
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@@ -2372,14 +2400,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        sum8 += x0->d * y0->s + x1->d * y1->s;
+
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+        //const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2391,10 +2423,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2410,21 +2442,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
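The net effect on `ggml_vec_dot_q4_0_q8_0`: the NEON path now multiplies the raw, unsigned nibbles against the y quants and repairs the omitted "- 8" once per block through `sum8`, per the identity noted earlier. The same computation in scalar form (a sketch with my own naming, assuming this file's layouts: `block_q4_0` is `float d` plus 16 bytes of nibble pairs, low nibble first, and `n` a multiple of `QK8_0`):

    static float vec_dot_q4_0_q8_0_scalar(int n, const block_q4_0 * x, const block_q8_0 * y) {
        const int nb = n / QK8_0;
        float sumf = 0.0f;
        float sum8 = 0.0f;
        for (int i = 0; i < nb; ++i) {
            // folds the per-block "- 8" offset into one scalar multiply-add
            sum8 += x[i].d * y[i].s;
            int sumi = 0;
            for (int j = 0; j < QK8_0/2; ++j) {
                const uint8_t v0 = x[i].qs[j];
                // unsigned nibbles; the usual "- 8" is applied once via sum8
                sumi += (v0 & 0xf) * y[i].qs[2*j + 0] + (v0 >> 4) * y[i].qs[2*j + 1];
            }
            sumf += x[i].d * y[i].d * sumi;
        }
        return sumf - 8 * sum8;
    }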
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        summs += x0->m * y0->s + x1->m * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2598,16 +2644,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        const int16x8_t s0i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+        // We no longer need this. We have computed the sum of the y quants during quantization,
+        // so we get the same result via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
+        //const int16x8_t s0i = vaddq_s16(
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
 
-        const int16x8_t s1i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+        //const int16x8_t s1i = vaddq_s16(
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
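The removed `s0i`/`s1i` block recomputed \sum_l q^y_l with vector widens and adds on every call. Since quantization now stores s_y = d_y \sum_l q^y_l, the q4_1 offset term collapses to one scalar multiply-add per block:

    \sum_l \big(d_x q^x_l + m_x\big)\, d_y q^y_l
      = d_x d_y \sum_l q^x_l q^y_l \;+\; m_x\, s_y

which is exactly the `summs += x0->m * y0->s + x1->m * y1->s` accumulated at the top of the loop, and the `+ summs` added to `sumf` after it.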
@@ -2637,24 +2685,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2677,14 +2729,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
 
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256  ysum  = _mm256_cvtepi32_ps( ysumi );
+        // We no longer need this. We have computed the sum of the y quants during quantization,
+        // so we get the same result via the single scalar instruction above (summs += x[i].m * y[i].s)
+        //// Compute sum of y values
+        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        //const __m256  ysum  = _mm256_cvtepi32_ps( ysumi );
 
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+        //// Accumulate d1*m0*y
+        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
@@ -2693,7 +2747,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
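The diff is cut off inside the scalar fallback, but the identity is easy to exercise in plain C. A q4_1 x q8_0 scalar dot using the cached sum (a sketch with my own naming, same layout assumptions as the q4_0 sketch earlier, with `block_q4_1` carrying `d`, `m`, and 16 nibble bytes):

    static float vec_dot_q4_1_q8_0_scalar(int n, const block_q4_1 * x, const block_q8_0 * y) {
        const int nb = n / QK8_0;
        float sumf  = 0.0f;
        float summs = 0.0f;
        for (int i = 0; i < nb; ++i) {
            // m_x * (d_y * sum(q_y)): one scalar multiply-add per block
            summs += x[i].m * y[i].s;
            int sumi = 0;
            for (int j = 0; j < QK8_0/2; ++j) {
                const uint8_t v0 = x[i].qs[j];
                sumi += (v0 & 0xf) * y[i].qs[2*j + 0] + (v0 >> 4) * y[i].qs[2*j + 1];
            }
            sumf += x[i].d * y[i].d * sumi;
        }
        return sumf + summs;
    }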