@@ -647,9 +647,10 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+    float   s;          // d * sum(qs[i])
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 
 // reference implementation for deterministic creation of model files
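
The new `s` field caches `d * sum(qs[i])` at quantization time. Both the q4_0 and q4_1 dot products need the per-block sum of the q8_0 quants: q4_0 stores values as d_x*(q - 8) and q4_1 as d_x*q + m, so the offset term factors out into a single scalar multiply against `y->s`. A minimal scalar sketch of the q4_0 identity, assuming plain C with the quants already unpacked (none of this is part of the diff):

    #include <stdint.h>

    // x[l] = d_x*(q[l] - 8), y[l] = d_y*p[l]  =>
    //   sum_l x[l]*y[l] = d_x*d_y*sum(q*p) - 8*d_x*(d_y*sum(p))
    // The parenthesized factor is exactly what gets cached in y->s.
    static float dot_q4_0_block_ref(float d_x, const int8_t q[32],
                                    float d_y, const int8_t p[32]) {
        int dot = 0, psum = 0;
        for (int l = 0; l < 32; ++l) {
            dot  += q[l]*p[l];
            psum += p[l];
        }
        const float s = d_y*psum;        // computed once, at quantization time
        return d_x*d_y*dot - 8.0f*d_x*s; // no per-element "- 8" needed
    }
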
@@ -1247,10 +1248,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         y[i].d = d;
 
+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d*sum;
     }
 }
 
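
For context, here is how the reference routine reads with the new lines in place. The parts outside the hunk (the amax/d/id computation) are reconstructed from the usual ggml pattern and may differ in detail from the actual file; it assumes the file's `MAX` macro and `<math.h>`:

    static void quantize_row_q8_0_reference_sketch(const float * restrict x, block_q8_0 * restrict y, int k) {
        const int nb = k / QK8_0;
        for (int i = 0; i < nb; i++) {
            float amax = 0.0f; // absolute max
            for (int l = 0; l < QK8_0; l++) {
                amax = MAX(amax, fabsf(x[i*QK8_0 + l]));
            }
            const float d  = amax / ((1 << 7) - 1);
            const float id = d ? 1.0f/d : 0.0f;

            y[i].d = d;

            int sum = 0;
            for (int l = 0; l < QK8_0; ++l) {
                const float v = x[i*QK8_0 + l]*id;
                y[i].qs[l] = roundf(v);
                sum += y[i].qs[l]; // gather the quant sum as a side effect
            }
            y[i].s = d*sum;        // cached for the dot-product kernels
        }
    }
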
@@ -1280,6 +1284,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         y[i].d = d;
 
+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
@@ -1288,7 +1294,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d*sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
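
On the NEON path the sum costs almost nothing: each freshly converted `vi` vector is folded into `accv`, and one `vaddvq_s32` (an AArch64-only horizontal add) reduces the four lanes at the end. If one instead had to sum the 32 stored quants after the fact, a hedged sketch could look like this (widen int8 to int32 lanes, then the same horizontal add):

    #include <arm_neon.h>
    #include <stdint.h>

    // Assumption: AArch64 NEON; the vaddvq_* family does not exist on 32-bit ARM.
    static int32_t sum_qs_neon(const int8_t qs[32]) {
        int32x4_t accv = vdupq_n_s32(0);
        for (int l = 0; l < 32; l += 16) {
            const int8x16_t v   = vld1q_s8(qs + l);
            const int16x8_t s16 = vaddl_s8(vget_low_s8(v), vget_high_s8(v));
            accv = vaddq_s32(accv, vaddl_s16(vget_low_s16(s16), vget_high_s16(s16)));
        }
        return vaddvq_s32(accv);
    }
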
@@ -1336,6 +1346,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
+
+        // Compute the sum of the quants
+        // Is there a better way of doing this?
+        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
+        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
+        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+        __m128i sum64 = _mm_add_epi32(hi64, sum128);
+        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
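
As for the question in the comment: the unpack/shuffle ladder is one of several ways to reduce a `__m256i` of eight int32 lanes. A pair of SSSE3 horizontal adds gives the same result in fewer lines, though whether it is faster depends on the microarchitecture (the hadd instructions tend to decode into multiple uops). A possible helper, not part of the commit, assuming `<immintrin.h>`:

    #include <immintrin.h>

    // Reduce eight int32 lanes to one scalar via SSSE3 horizontal adds.
    static inline int hsum_i32_8(const __m256i a) {
        const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a),
                                             _mm256_extracti128_si256(a, 1));
        const __m128i sum64  = _mm_hadd_epi32(sum128, sum128);
        return _mm_cvtsi128_si32(_mm_hadd_epi32(sum64, sum64));
    }

    // usage in the loop above: y[i].s = d * hsum_i32_8(acc);
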
@@ -1378,6 +1398,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined(__AVX__)
+    // TODO: vectorize this
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+        for (int l = 0; l < QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d*sum;
+    }
+#endif
 }
 
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
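
Two notes on this fallback pass. First, because `__AVX__` is also defined on AVX2 targets, the pass reruns there and recomputes the value the main loop already stored; the result is identical, so this is only redundant work, and a `#if defined(__AVX__) && !defined(__AVX2__)` guard would avoid it. Second, the TODO looks straightforward to satisfy with SSE4.1, which every AVX machine has; a hedged sketch, not the commit's code:

    #include <immintrin.h>
    #include <stdint.h>

    // Sum 32 signed bytes: sign-extend 16 at a time to int16, pair-sum into
    // int32 lanes with madd, then reduce horizontally.
    static inline int sum_i8_32(const int8_t * qs) {
        const __m128i ones = _mm_set1_epi16(1);
        __m128i acc = _mm_setzero_si128();
        for (int l = 0; l < 32; l += 16) {
            const __m128i v  = _mm_loadu_si128((const __m128i *)(qs + l));
            const __m128i lo = _mm_cvtepi8_epi16(v);                    // bytes 0..7
            const __m128i hi = _mm_cvtepi8_epi16(_mm_srli_si128(v, 8)); // bytes 8..15
            acc = _mm_add_epi32(acc, _mm_madd_epi16(_mm_add_epi16(lo, hi), ones));
        }
        const __m128i s = _mm_hadd_epi32(acc, acc);
        return _mm_cvtsi128_si32(_mm_hadd_epi32(s, s));
    }
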
@@ -2282,14 +2310,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        sum8 += x0->d*y0->s + x1->d*y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
+        //const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2301,10 +2333,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2320,21 +2352,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2346,7 +2388,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8*sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
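
This is the payoff of the new field on the q4_0 path: the per-block `vsubq_s8` calls are gone, the dot products run on the raw nibbles, and the missing `- 8` is applied once per block pair as the scalar `sum8 += x->d * y->s`, then folded in at the very end as `- 8*sum8`. The integer sums are identical either way; only the order of the float multiplies changes, so results can differ in the last ulp. A scalar sketch of the equivalence (not from the diff):

    #include <stdint.h>

    // Both forms compute the same block contribution:
    //   d_x*d_y*sum((q - 8)*p)  ==  d_x*d_y*sum(q*p) - 8*d_x*(d_y*sum(p))
    static float dot_block_two_ways(float d_x, const int8_t q[32],
                                    float d_y, const int8_t p[32], int fused) {
        int dot = 0, psum = 0;
        for (int l = 0; l < 32; ++l) {
            dot  += (fused ? q[l] - 8 : q[l])*p[l];
            psum += p[l];
        }
        // with fused == 0, d_y*psum plays the role of the precomputed y->s
        return d_x*d_y*dot - (fused ? 0.0f : 8.0f*d_x*(d_y*psum));
    }
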
@@ -2479,12 +2521,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
+        summs += x0->m*y0->s + x1->m*y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
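
Same trick for q4_1, whose blocks dequantize as `x = d*q + m`: per block the dot product splits into `d_x*d_y*dot(q, p) + m*(d_y*sum(p))`, and the second term is exactly `m*y->s`. That is why the widening `s0i`/`s1i` sums and their two `vmlaq_n_f32` calls in the next hunk can be deleted in favor of one scalar accumulation per block. Sketched in scalar form (assumption: unpacked quants, not the diff's code):

    #include <stdint.h>

    // x[l] = d_x*q[l] + m, y[l] = d_y*p[l]  =>
    //   sum_l x[l]*y[l] = d_x*d_y*sum(q*p) + m*(d_y*sum(p)) = d_x*d_y*dot + m*y->s
    static float dot_q4_1_block_ref(float d_x, float m, const uint8_t q[32],
                                    float d_y, const int8_t p[32], float y_s) {
        int dot = 0;
        for (int l = 0; l < 32; ++l) {
            dot += q[l]*p[l];
        }
        return d_x*d_y*dot + m*y_s; // m*y_s is what "summs" accumulates
    }
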
@@ -2508,16 +2554,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
-        const int16x8_t s0i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+        // We no longer need these. The sum of the y quants was computed during quantization,
+        // so the scalar accumulation above (summs += x0->m*y0->s + x1->m*y1->s) gives the same result.
+        //const int16x8_t s0i = vaddq_s16(
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
 
-        const int16x8_t s1i = vaddq_s16(
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-            vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+        //const int16x8_t s1i = vaddq_s16(
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+        //    vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
@@ -2547,24 +2595,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }
 
-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        //const __m256 m0v = _mm256_broadcast_ss( m0 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytesFromNibbles( x[i].qs );
@@ -2587,14 +2639,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
 
-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+        // We no longer need these. The sum of the y quants was computed during quantization,
+        // so the single scalar instruction above (summs += x[i].m * y[i].s) gives the same result.
+        //// Compute sum of y values
+        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
 
-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+        //// Accumulate d1*m0*y
+        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }
 
     // Return horizontal sum of the acc vector
@@ -2603,7 +2657,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {