@@ -3079,32 +3079,50 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
30793079 float32x4_t sumv0 = vdupq_n_f32 (0.0f );
30803080 float32x4_t sumv1 = vdupq_n_f32 (0.0f );
30813081
3082- for (int i = 0 ; i < nb ; ++ i ) {
3083- const block_q8_0 * restrict x0 = & x [i ];
3084- const block_q8_0 * restrict y0 = & y [i ];
3082+ for (int i = 0 ; i < nb ; i += 2 ) {
3083+ const block_q8_0 * restrict x0 = & x [i + 0 ];
3084+ const block_q8_0 * restrict x1 = & x [i + 1 ];
3085+ const block_q8_0 * restrict y0 = & y [i + 0 ];
3086+ const block_q8_0 * restrict y1 = & y [i + 1 ];
30853087
3086- const int8x16_t v0_0 = vld1q_s8 (x0 -> qs );
3087- const int8x16_t v0_1 = vld1q_s8 (x0 -> qs + 16 );
3088+ const int8x16_t x0_0 = vld1q_s8 (x0 -> qs );
3089+ const int8x16_t x0_1 = vld1q_s8 (x0 -> qs + 16 );
3090+ const int8x16_t x1_0 = vld1q_s8 (x1 -> qs );
3091+ const int8x16_t x1_1 = vld1q_s8 (x1 -> qs + 16 );
30883092
30893093 // load y
3090- const int8x16_t v1_0 = vld1q_s8 (y0 -> qs );
3091- const int8x16_t v1_1 = vld1q_s8 (y0 -> qs + 16 );
3094+ const int8x16_t y0_0 = vld1q_s8 (y0 -> qs );
3095+ const int8x16_t y0_1 = vld1q_s8 (y0 -> qs + 16 );
3096+ const int8x16_t y1_0 = vld1q_s8 (y1 -> qs );
3097+ const int8x16_t y1_1 = vld1q_s8 (y1 -> qs + 16 );
30923098
30933099#if defined(__ARM_FEATURE_DOTPROD )
30943100 sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (
3095- vdotq_s32 (vdupq_n_s32 (0 ), v0_0 , v1_0 ),
3096- vdotq_s32 (vdupq_n_s32 (0 ), v0_1 , v1_1 ))), x0 -> d * y0 -> d );
3097- #else
3098- const int16x8_t p0l = vmull_s8 (vget_low_s8 (v0_0 ), vget_low_s8 (v1_0 ));
3099- const int16x8_t p0h = vmull_s8 (vget_high_s8 (v0_0 ), vget_high_s8 (v1_0 ));
3100- const int16x8_t p1l = vmull_s8 (vget_low_s8 (v0_1 ), vget_low_s8 (v1_1 ));
3101- const int16x8_t p1h = vmull_s8 (vget_high_s8 (v0_1 ), vget_high_s8 (v1_1 ));
3101+ vdotq_s32 (vdupq_n_s32 (0 ), x0_0 , y0_0 ),
3102+ vdotq_s32 (vdupq_n_s32 (0 ), x0_1 , y0_1 ))), x0 -> d * y0 -> d );
31023103
3103- const int32x4_t pl = vaddq_s32 (vpaddlq_s16 (p0l ), vpaddlq_s16 (p0h ));
3104- const int32x4_t ph = vaddq_s32 (vpaddlq_s16 (p1l ), vpaddlq_s16 (p1h ));
3104+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (
3105+ vdotq_s32 (vdupq_n_s32 (0 ), x1_0 , y1_0 ),
3106+ vdotq_s32 (vdupq_n_s32 (0 ), x1_1 , y1_1 ))), x1 -> d * y1 -> d );
31053107
3106- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (pl ), x0 -> d * y0 -> d );
3107- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph ), x0 -> d * y0 -> d );
3108+ #else
3109+ const int16x8_t p0_0 = vmull_s8 (vget_low_s8 (x0_0 ), vget_low_s8 (y0_0 ));
3110+ const int16x8_t p0_1 = vmull_s8 (vget_high_s8 (x0_0 ), vget_high_s8 (y0_0 ));
3111+ const int16x8_t p0_2 = vmull_s8 (vget_low_s8 (x0_1 ), vget_low_s8 (y0_1 ));
3112+ const int16x8_t p0_3 = vmull_s8 (vget_high_s8 (x0_1 ), vget_high_s8 (y0_1 ));
3113+
3114+ const int16x8_t p1_0 = vmull_s8 (vget_low_s8 (x1_0 ), vget_low_s8 (y1_0 ));
3115+ const int16x8_t p1_1 = vmull_s8 (vget_high_s8 (x1_0 ), vget_high_s8 (y1_0 ));
3116+ const int16x8_t p1_2 = vmull_s8 (vget_low_s8 (x1_1 ), vget_low_s8 (y1_1 ));
3117+ const int16x8_t p1_3 = vmull_s8 (vget_high_s8 (x1_1 ), vget_high_s8 (y1_1 ));
3118+
3119+ const int32x4_t p0 = vaddq_s32 (vpaddlq_s16 (p0_0 ), vpaddlq_s16 (p0_1 ));
3120+ const int32x4_t p1 = vaddq_s32 (vpaddlq_s16 (p0_2 ), vpaddlq_s16 (p0_3 ));
3121+ const int32x4_t p2 = vaddq_s32 (vpaddlq_s16 (p1_0 ), vpaddlq_s16 (p1_1 ));
3122+ const int32x4_t p3 = vaddq_s32 (vpaddlq_s16 (p1_2 ), vpaddlq_s16 (p1_3 ));
3123+
3124+ sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vaddq_s32 (p0 , p1 )), x0 -> d * y0 -> d );
3125+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vaddq_s32 (p2 , p3 )), x1 -> d * y1 -> d );
31083126#endif
31093127 }
31103128
0 commit comments