@@ -2875,77 +2875,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28752875 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
28762876 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
28772877
2878- float summs = 0.0f ;
2878+ float summs0 = 0.0f ;
2879+ float summs1 = 0.0f ;
28792880
2880- for (int i = 0 ; i < nb ; i += 2 ) {
2881+ for (int i = 0 ; i < nb ; ++ i ) {
28812882 const block_q4_3 * restrict x0_0 = & x [2 * (i + 0 ) + 0 ];
28822883 const block_q4_3 * restrict x0_1 = & x [2 * (i + 0 ) + 1 ];
2883- const block_q4_3 * restrict x1_0 = & x [2 * (i + 1 ) + 0 ];
2884- const block_q4_3 * restrict x1_1 = & x [2 * (i + 1 ) + 1 ];
28852884
28862885 const block_q8_0 * restrict y0 = & y [i + 0 ];
2887- const block_q8_0 * restrict y1 = & y [i + 1 ];
28882886
2889- summs += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 + GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
2890- summs += GGML_FP16_TO_FP32 (x1_0 -> m ) * y1 -> s0 + GGML_FP16_TO_FP32 (x1_1 -> m ) * y1 -> s1 ;
2891-
2892- const uint8x16_t m4b = vdupq_n_u8 (0xf );
2893-
2894- const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2895- const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
2896- const float x1_0d = GGML_FP16_TO_FP32 (x1_0 -> d );
2897- const float x1_1d = GGML_FP16_TO_FP32 (x1_1 -> d );
2887+ summs0 += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 ;
2888+ summs1 += GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
28982889
28992890 const uint8x16_t v0_0 = vcombine_u8 (vld1_u8 (x0_0 -> qs ), vld1_u8 (x0_1 -> qs ));
2900- const uint8x16_t v0_1 = vcombine_u8 (vld1_u8 (x1_0 -> qs ), vld1_u8 (x1_1 -> qs ));
29012891
29022892 // 4-bit -> 8-bit
2903- const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , m4b ));
2893+ const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , vdupq_n_u8 ( 0xf ) ));
29042894 const int8x16_t v0_0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_0 , 4 ));
2905- const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
2906- const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
29072895
29082896 // interleave
29092897 const int8x16_t v0_0lz = vzip1q_s8 (v0_0l , v0_0h );
29102898 const int8x16_t v0_0hz = vzip2q_s8 (v0_0l , v0_0h );
2911- const int8x16_t v0_1lz = vzip1q_s8 (v0_1l , v0_1h );
2912- const int8x16_t v0_1hz = vzip2q_s8 (v0_1l , v0_1h );
29132899
29142900 // load y
29152901 const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
29162902 const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
2917- const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
2918- const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
2903+
2904+ const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2905+ const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
29192906
29202907#if defined(__ARM_FEATURE_DOTPROD )
29212908 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), x0_0d * y0 -> d );
2922- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
2923- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l )), x1_0d * y1 -> d );
2924- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h )), x1_1d * y1 -> d );
2909+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
29252910#else
29262911 const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
29272912 const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
29282913 const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
29292914 const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
29302915
2931- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
2932- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
2933- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
2934- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
2935-
29362916 const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
29372917 const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
2938- const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
2939- const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
29402918
29412919 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (pl0 ), x0_0d * y0 -> d );
2942- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
2943- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (pl1 ), x1_0d * y1 -> d );
2944- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph1 ), x1_1d * y1 -> d );
2920+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
29452921#endif
29462922 }
29472923
2948- sumf = vaddvq_f32 (vaddq_f32 (sumv0 , sumv1 )) + summs ;
2924+ * s = vaddvq_f32 (vaddq_f32 (sumv0 , sumv1 )) + summs0 + summs1 ;
29492925#elif defined(__AVX2__ )
29502926 // Initialize accumulator with zeros
29512927 __m256 acc = _mm256_setzero_ps ();
0 commit comments