@@ -1539,7 +1539,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
     int8_t * restrict qs = vy;
     float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
 
-#if __AVX512F__
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        // pairwise reduction to the absolute max of the 32-value block
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        ds[i] = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v); // round to nearest
+
+            qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0);
+            qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1);
+            qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2);
+            qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX512F__)
     for (int i = 0; i < nb; i++) {
         const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
         const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2 );
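
For reference, this is roughly what the NEON path above computes per 32-value block: take the absolute maximum, derive the scale d = amax/127, store d, and round each scaled value to the nearest int8. A minimal scalar sketch follows, assuming QK8_0C == 32 and the buffer layout set up at the top of the function (all int8 quants first, then one float scale per block). Note that vcvtnq_s32_f32 rounds ties to even while roundf rounds ties away from zero, so the two can differ on exact .5 ties.

// Hypothetical scalar reference, not part of the patch.
#include <math.h>
#include <stdint.h>

static void quantize_row_q8_0c_ref(const float * x, int8_t * qs, float * ds, int nb) {
    for (int i = 0; i < nb; i++) {
        // absolute max of the 32-value block
        float amax = 0.0f;
        for (int l = 0; l < 32; l++) {
            const float v = fabsf(x[i*32 + l]);
            if (v > amax) amax = v;
        }

        // scale so that amax maps to the int8 extreme 127
        const float d  = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        ds[i] = d;

        for (int l = 0; l < 32; l++) {
            qs[i*32 + l] = (int8_t) roundf(x[i*32 + l]*id);
        }
    }
}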
@@ -2817,7 +2847,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
 
     float sumf = 0.0;
 
-#if __AVX512F__
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb/2; i++) {
+        const int dst0 = i + i/2*2;     // 0, 1, 4, 5, 8, 9, ...
+        const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11, ...
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]);
+        const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]);
+
+        // 4-bit -> 8-bit: low nibbles are block dst0, high nibbles block dst1
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_01l, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8  (v0_01h, m4b));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]);
+        const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]);
+        const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]);
+        const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]);
+#else
+        // widening int8 multiplies, then pairwise adds down to int32 lanes
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+
+#elif defined(__AVX512F__)
     // Initialize accumulator with zeros
     __m512 acc = _mm512_setzero_ps();
     for (int i = 0; i < nb; i += 4) {
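
For reference, the NEON loop above walks two q4_0c blocks per iteration: judging from the loads, the 32 bytes at xqs[i*QK4_0] carry block dst0 in their low nibbles and block dst1 in their high nibbles, each nibble biased by 8, and the per-block integer dot products are scaled by xds[...]*yds[...] before accumulation. A hypothetical scalar equivalent under those assumptions (QK4_0 == QK8_0C == 32, packing as inferred) would be:

// Hypothetical scalar reference, not part of the patch.
#include <stdint.h>

static float vec_dot_q4_0c_q8_0c_ref(int nb,
        const uint8_t * xqs, const float * xds,
        const int8_t  * yqs, const float * yds) {
    float sumf = 0.0f;
    for (int i = 0; i < nb/2; i++) {
        const int dst0 = i + i/2*2;     // 0, 1, 4, 5, 8, 9, ...
        const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11, ...

        int sum0 = 0;
        int sum1 = 0;
        for (int l = 0; l < 32; l++) {
            const uint8_t b = xqs[i*32 + l];
            sum0 += ((int)(b & 0xf) - 8) * yqs[dst0*32 + l]; // low nibble, block dst0
            sum1 += ((int)(b >> 4)  - 8) * yqs[dst1*32 + l]; // high nibble, block dst1
        }
        sumf += sum0*xds[dst0]*yds[dst0] + sum1*xds[dst1]*yds[dst1];
    }
    return sumf;
}

Without __ARM_FEATURE_DOTPROD, the intrinsic path reaches the same int32 sums by widening the int8 products to int16 with vmull_s8 and folding them down with vpaddlq_s16/vaddq_s32; vdotq_s32 performs that multiply-and-reduce in a single instruction.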