@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
656656#define QK8_0 32
657657typedef struct {
658658 float d ; // delta: per-block scale, multiplied into the integer dot products
659- float s ; // d * sum(qs[i])
659+ float s0 ; // d * sum(qs[i]) over the low half, i in [0, QK8_0/2)
660+ float s1 ; // d * sum(qs[i]) over the high half, i in [QK8_0/2, QK8_0)
660661 int8_t qs [QK8_0 ]; // quants: QK8_0 signed 8-bit values
661662} block_q8_0 ;
662- static_assert (sizeof (block_q8_0 ) == 2 * sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" );
663+ static_assert (sizeof (block_q8_0 ) == 3 * sizeof (float ) + QK8_0 , "wrong q8_0 block size/padding" ); // three floats (d, s0, s1) plus QK8_0 one-byte quants, no padding
663664
664665
665666// reference implementation for deterministic creation of model files
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
12991300
13001301 y [i ].d = d ;
13011302
1302- int sum = 0 ;
1303- for (int l = 0 ; l < QK8_0 ; ++ l ) {
1304- const float v = x [i * QK8_0 + l ]* id ;
1305- y [i ].qs [l ] = roundf (v );
1306- sum += y [i ].qs [l ];
1303+ int sum0 = 0 ;
1304+ int sum1 = 0 ;
1305+
1306+ for (int l = 0 ; l < QK8_0 /2 ; ++ l ) {
1307+ const float v0 = x [i * QK8_0 + l ]* id ;
1308+ const float v1 = x [i * QK8_0 + QK8_0 /2 + l ]* id ;
1309+
1310+ y [i ].qs [ l ] = roundf (v0 );
1311+ y [i ].qs [QK8_0 /2 + l ] = roundf (v1 );
1312+
1313+ sum0 += y [i ].qs [ l ];
1314+ sum1 += y [i ].qs [QK8_0 /2 + l ];
13071315 }
1308- y [i ].s = d * sum ;
1316+
1317+ y [i ].s0 = d * sum0 ;
1318+ y [i ].s1 = d * sum1 ;
13091319 }
13101320}
13111321
@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13351345
13361346 y [i ].d = d ;
13371347
1338- int32x4_t accv = vdupq_n_s32 (0 );
1348+ int32x4_t accv0 = vdupq_n_s32 (0 );
1349+ int32x4_t accv1 = vdupq_n_s32 (0 );
13391350
1340- for (int l = 0 ; l < 8 ; l ++ ) {
1351+ // low half
1352+ for (int l = 0 ; l < 4 ; l ++ ) {
1353+ const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
1354+ const int32x4_t vi = vcvtnq_s32_f32 (v );
1355+
1356+ y [i ].qs [4 * l + 0 ] = vgetq_lane_s32 (vi , 0 );
1357+ y [i ].qs [4 * l + 1 ] = vgetq_lane_s32 (vi , 1 );
1358+ y [i ].qs [4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
1359+ y [i ].qs [4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
1360+
1361+ accv0 = vaddq_s32 (accv0 , vi );
1362+ }
1363+
1364+ // high half
1365+ for (int l = 4 ; l < 8 ; l ++ ) {
13411366 const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
13421367 const int32x4_t vi = vcvtnq_s32_f32 (v );
13431368
@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13461371 y [i ].qs [4 * l + 2 ] = vgetq_lane_s32 (vi , 2 );
13471372 y [i ].qs [4 * l + 3 ] = vgetq_lane_s32 (vi , 3 );
13481373
1349- accv = vaddq_s32 (accv , vi );
1374+ accv1 = vaddq_s32 (accv1 , vi );
13501375 }
1351- int32_t sum = vaddvq_s32 (accv );
1352- y [i ].s = d * sum ;
1376+
1377+ const int32_t sum0 = vaddvq_s32 (accv0 );
1378+ const int32_t sum1 = vaddvq_s32 (accv1 );
1379+
1380+ y [i ].s0 = d * sum0 ;
1381+ y [i ].s1 = d * sum1 ;
13531382 }
13541383#elif defined(__AVX2__ ) || defined(__AVX__ )
1384+ // TODO: unfinished - presumably the AVX2/AVX path still needs work for the s0/s1 split (confirm before merging)
13551385 for (int i = 0 ; i < nb ; i ++ ) {
13561386 // Load elements into 4 AVX vectors
13571387 __m256 v0 = _mm256_loadu_ps ( x );
@@ -1398,7 +1428,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13981428
13991429#if defined(__AVX2__ )
14001430 // Compute d * (sum of the quants) per half and set y[i].s0 (from i0, i1) and y[i].s1 (from i2, i3)
1401- y [i ].s = d * hsum_i32_8 (_mm256_add_epi32 (_mm256_add_epi32 (i0 , i1 ), _mm256_add_epi32 (i2 , i3 )));
1431+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1432+ y [i ].s0 = d * hsum_i32_8 (_mm256_add_epi32 (i0 , i1 ));
1433+ y [i ].s1 = d * hsum_i32_8 (_mm256_add_epi32 (i2 , i3 ));
14021434
14031435 // Convert int32 to int16
14041436 i0 = _mm256_packs_epi32 ( i0 , i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@@ -2395,7 +2427,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23952427 const block_q8_0 * restrict y0 = & y [i + 0 ];
23962428 const block_q8_0 * restrict y1 = & y [i + 1 ];
23972429
2398- sum8 += x0 -> d * y0 -> s + x1 -> d * y1 -> s ;
2430+ sum8 += x0 -> d * ( y0 -> s0 + y0 -> s1 ) + x1 -> d * ( y1 -> s0 + y1 -> s1 ) ;
23992431
24002432 const uint8x16_t m4b = vdupq_n_u8 (0xf );
24012433
@@ -2562,7 +2594,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25622594 const block_q8_0 * restrict y0 = & y [i + 0 ];
25632595 const block_q8_0 * restrict y1 = & y [i + 1 ];
25642596
2565- summs += x0 -> m * y0 -> s + x1 -> m * y1 -> s ;
2597+ summs += x0 -> m * ( y0 -> s0 + y0 -> s1 ) + x1 -> m * ( y1 -> s0 + y1 -> s1 ) ;
25662598
25672599 const uint8x16_t m4b = vdupq_n_u8 (0xf );
25682600
@@ -2575,22 +2607,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25752607 const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
25762608 const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
25772609
2610+ // interleave
2611+ const int8x16_t v0_0lz = vzip1q_s8 (v0_0l , v0_0h );
2612+ const int8x16_t v0_0hz = vzip2q_s8 (v0_0l , v0_0h );
2613+ const int8x16_t v0_1lz = vzip1q_s8 (v0_1l , v0_1h );
2614+ const int8x16_t v0_1hz = vzip2q_s8 (v0_1l , v0_1h );
2615+
25782616 // load y
25792617 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
25802618 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
25812619 const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
25822620 const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
25832621
2584- // interleave
2585- const int8x16_t v1_0ls = vuzp1q_s8 (v1_0l , v1_0h );
2586- const int8x16_t v1_0hs = vuzp2q_s8 (v1_0l , v1_0h );
2587- const int8x16_t v1_1ls = vuzp1q_s8 (v1_1l , v1_1h );
2588- const int8x16_t v1_1hs = vuzp2q_s8 (v1_1l , v1_1h );
2589-
25902622#if defined(__ARM_FEATURE_DOTPROD )
25912623 // dot product into int32x4_t
2592- const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0ls ), v0_0h , v1_0hs );
2593- const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1l , v1_1ls ), v0_1h , v1_1hs );
2624+ const int32x4_t p_0 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ), v0_0hz , v1_0h );
2625+ const int32x4_t p_1 = vdotq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ), v0_1hz , v1_1h );
25942626
25952627 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (p_0 ), x0 -> d * y0 -> d );
25962628 sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (p_1 ), x1 -> d * y1 -> d );
@@ -2627,7 +2659,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26272659 const float * d0 = & x [i ].d ;
26282660 const float * d1 = & y [i ].d ;
26292661
2630- summs += x [i ].m * y [i ].s ;
2662+ summs += x [i ].m * ( y [i ].s0 + y [ i ]. s1 ) ;
26312663
26322664 const __m256 d0v = _mm256_broadcast_ss ( d0 );
26332665 const __m256 d1v = _mm256_broadcast_ss ( d1 );
@@ -2845,88 +2877,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28452877 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
28462878 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
28472879
2848- for (int i = 0 ; i < nb ; i += 2 ) {
2880+ float summs0 = 0.0f ;
2881+ float summs1 = 0.0f ;
2882+
2883+ for (int i = 0 ; i < nb ; ++ i ) {
28492884 const block_q4_3 * restrict x0_0 = & x [2 * (i + 0 ) + 0 ];
28502885 const block_q4_3 * restrict x0_1 = & x [2 * (i + 0 ) + 1 ];
2851- const block_q4_3 * restrict x1_0 = & x [2 * (i + 1 ) + 0 ];
2852- const block_q4_3 * restrict x1_1 = & x [2 * (i + 1 ) + 1 ];
28532886
28542887 const block_q8_0 * restrict y0 = & y [i + 0 ];
2855- const block_q8_0 * restrict y1 = & y [i + 1 ];
2856-
2857- const uint8x16_t m4b = vdupq_n_u8 (0xf );
2858-
2859- const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2860- const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
2861- const float x1_0d = GGML_FP16_TO_FP32 (x1_0 -> d );
2862- const float x1_1d = GGML_FP16_TO_FP32 (x1_1 -> d );
28632888
2864- const float x0_0m = GGML_FP16_TO_FP32 (x0_0 -> m );
2865- const float x0_1m = GGML_FP16_TO_FP32 (x0_1 -> m );
2866- const float x1_0m = GGML_FP16_TO_FP32 (x1_0 -> m );
2867- const float x1_1m = GGML_FP16_TO_FP32 (x1_1 -> m );
2889+ summs0 += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 ;
2890+ summs1 += GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
28682891
28692892 const uint8x16_t v0_0 = vcombine_u8 (vld1_u8 (x0_0 -> qs ), vld1_u8 (x0_1 -> qs ));
2870- const uint8x16_t v0_1 = vcombine_u8 (vld1_u8 (x1_0 -> qs ), vld1_u8 (x1_1 -> qs ));
28712893
28722894 // 4-bit -> 8-bit
2873- const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , m4b ));
2895+ const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , vdupq_n_u8 ( 0xf ) ));
28742896 const int8x16_t v0_0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_0 , 4 ));
2875- const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
2876- const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
28772897
28782898 // interleave
28792899 const int8x16_t v0_0lz = vzip1q_s8 (v0_0l , v0_0h );
28802900 const int8x16_t v0_0hz = vzip2q_s8 (v0_0l , v0_0h );
2881- const int8x16_t v0_1lz = vzip1q_s8 (v0_1l , v0_1h );
2882- const int8x16_t v0_1hz = vzip2q_s8 (v0_1l , v0_1h );
28832901
28842902 // load y
28852903 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
28862904 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
2887- const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
2888- const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
2889-
2890- const int16x8_t sy0_0 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0l )), vmovl_s8 (vget_high_s8 (v1_0l )));
2891- const int16x8_t sy0_1 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_0h )), vmovl_s8 (vget_high_s8 (v1_0h )));
28922905
2893- const int16x8_t sy1_0 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1l )), vmovl_s8 (vget_high_s8 (v1_1l )));
2894- const int16x8_t sy1_1 = vaddq_s16 (vmovl_s8 (vget_low_s8 (v1_1h )), vmovl_s8 (vget_high_s8 (v1_1h )));
2895-
2896- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy0_0 ), vget_high_s16 (sy0_0 ))), x0_0m * y0 -> d );
2897- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy0_1 ), vget_high_s16 (sy0_1 ))), x0_1m * y0 -> d );
2898- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy1_0 ), vget_high_s16 (sy1_0 ))), x1_0m * y1 -> d );
2899- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddl_s16 (vget_low_s16 (sy1_1 ), vget_high_s16 (sy1_1 ))), x1_1m * y1 -> d );
2906+ const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2907+ const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
29002908
29012909#if defined(__ARM_FEATURE_DOTPROD )
29022910 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), x0_0d * y0 -> d );
2903- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
2904- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l )), x1_0d * y1 -> d );
2905- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h )), x1_1d * y1 -> d );
2911+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
29062912#else
29072913 const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
29082914 const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
29092915 const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
29102916 const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
29112917
2912- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
2913- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
2914- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
2915- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
2916-
29172918 const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
29182919 const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
2919- const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
2920- const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
29212920
29222921 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (pl0 ), x0_0d * y0 -> d );
2923- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
2924- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (pl1 ), x1_0d * y1 -> d );
2925- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph1 ), x1_1d * y1 -> d );
2922+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
29262923#endif
29272924 }
29282925
2929- * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 ( sumv1 ) ;
2926+ * s = vaddvq_f32 (vaddq_f32 ( sumv0 , sumv1 )) + summs0 + summs1 ;
29302927#elif defined(__AVX2__ )
29312928 // Initialize accumulator with zeros
29322929 __m256 acc = _mm256_setzero_ps ();
@@ -2971,9 +2968,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29712968 const float d1 = GGML_FP16_TO_FP32 (x [2 * i + 1 ].d );
29722969 const float m1 = GGML_FP16_TO_FP32 (x [2 * i + 1 ].m );
29732970
2974- int sy_0 = 0 ;
2975- int sy_1 = 0 ;
2976-
29772971 int sxy_0 = 0 ;
29782972 int sxy_1 = 0 ;
29792973
@@ -2993,15 +2987,11 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29932987 const int y0_1 = y0 [2 * (j + QK8_0 /4 ) + 0 ];
29942988 const int y1_1 = y0 [2 * (j + QK8_0 /4 ) + 1 ];
29952989
2996- sy_0 += y0_0 + y1_0 ;
2997- sy_1 += y0_1 + y1_1 ;
2998-
29992990 sxy_0 += x0_0 * y0_0 + x1_0 * y1_0 ;
30002991 sxy_1 += x0_1 * y0_1 + x1_1 * y1_1 ;
30012992 }
30022993
3003- sumf += (d0 * sxy_0 + m0 * sy_0 )* y [i ].d ;
3004- sumf += (d1 * sxy_1 + m1 * sy_1 )* y [i ].d ;
2994+ sumf += (d0 * sxy_0 + d1 * sxy_1 )* y [i ].d + m0 * y [i ].s0 + m1 * y [i ].s1 ;
30052995 }
30062996 * s = sumf ;
30072997#endif
0 commit comments