@@ -2978,77 +2978,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    float summs = 0.0f;
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
 
-    for (int i = 0; i < nb; i += 2) {
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
-        summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8(v0_0lz), vget_low_s8(v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8(v0_0hz), vget_low_s8(v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8(v0_1lz), vget_low_s8(v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8(v0_1hz), vget_low_s8(v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }
 
-    sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
+    sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
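For reference, below is a scalar sketch of what one iteration of the simplified NEON loop computes. It is an illustration only, not part of the patch, and it relies on the layout implied by the intrinsics above: each 32-element q8_0 block y[i] pairs with two 16-element q4_3 sub-blocks x[2*i] and x[2*i + 1], byte j of a q4_3 block holds element 2*j in its low nibble and element 2*j + 1 in its high nibble, and y[i].s0 / y[i].s1 are the y[i].d-scaled sums of the two q8_0 halves. The name 'acc' is a hypothetical accumulator standing in for sumv0/sumv1 plus summs0/summs1.

// scalar sketch of one loop iteration (assumptions as stated above)
float acc = 0.0f;
for (int half = 0; half < 2; ++half) {
    const block_q4_3 * restrict xb = &x[2*i + half];   // x0_0 or x0_1 above
    const float xd = GGML_FP16_TO_FP32(xb->d);
    const float xm = GGML_FP16_TO_FP32(xb->m);
    const int8_t * restrict yq = y[i].qs + 16*half;     // v1_0l or v1_0h above

    int isum = 0;
    for (int k = 0; k < 16; ++k) {
        // element k: low nibble of byte k/2 for even k, high nibble for odd k
        const int q = (k & 1) ? (xb->qs[k/2] >> 4) : (xb->qs[k/2] & 0xf);
        isum += q * yq[k];
    }

    // (xd*q + xm) * (y.d*yq) splits into the integer dot-product term
    // (sumv0/sumv1 above) and the min term (summs0/summs1 above)
    acc += xd*y[i].d*isum + xm*(half ? y[i].s1 : y[i].s0);
}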