@@ -260,6 +260,101 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
260260#endif
261261}
262262
263+ void ggml_vec_dot_mxfp4_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
264+ assert (nrc == 1 );
265+ UNUSED (nrc );
266+ UNUSED (bx );
267+ UNUSED (by );
268+ UNUSED (bs );
269+ assert (n % QK_MXFP4 == 0 );
270+ static_assert (QK_MXFP4 == QK8_0 , "QK_MXFP4 and QK8_0 must be the same" );
271+
272+ const int qk = QK_MXFP4 ;
273+ const int nb = n / qk ;
274+
275+ const block_mxfp4 * GGML_RESTRICT x = vx ;
276+ const block_q8_0 * GGML_RESTRICT y = vy ;
277+
278+ int ib = 0 ;
279+ float sumf = 0.0f ;
280+
281+ #if defined(__VXE__ ) || defined(__VXE2__ )
282+ const int8x16_t v_k = vec_xl (0 , kvalues_mxfp4 );
283+ const uint8x16_t v_m = vec_splats ((const uint8_t )0x0F );
284+
285+ float32x4_t v_acc = vec_splats (0.0f );
286+
287+ #pragma GCC unroll 8
288+ for (; ib + 1 < nb ; ib += 2 ) {
289+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
290+ const block_mxfp4 * GGML_RESTRICT x1 = & x [ib + 1 ];
291+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
292+ const block_q8_0 * GGML_RESTRICT y1 = & y [ib + 1 ];
293+
294+ const uint8x16_t v_x0 = vec_xl (0 , x0 -> qs );
295+ const uint8x16_t v_x1 = vec_xl (0 , x1 -> qs );
296+
297+ int8x16_t v_x0l = (int8x16_t )vec_and (v_x0 , v_m );
298+ int8x16_t v_x0h = (int8x16_t )vec_sr (v_x0 , 4 );
299+ int8x16_t v_x1l = (int8x16_t )vec_and (v_x1 , v_m );
300+ int8x16_t v_x1h = (int8x16_t )vec_sr (v_x1 , 4 );
301+
302+ v_x0l = vec_perm (v_k , v_k , (uchar8x16_t )v_x0l );
303+ v_x0h = vec_perm (v_k , v_k , (uchar8x16_t )v_x0h );
304+ v_x1l = vec_perm (v_k , v_k , (uchar8x16_t )v_x1l );
305+ v_x1h = vec_perm (v_k , v_k , (uchar8x16_t )v_x1h );
306+
307+ const int8x16_t v_y0l = vec_xl (0 , y0 -> qs );
308+ const int8x16_t v_y0h = vec_xl (QK8_0 /2 , y0 -> qs );
309+ const int8x16_t v_y1l = vec_xl (0 , y1 -> qs );
310+ const int8x16_t v_y1h = vec_xl (QK8_0 /2 , y1 -> qs );
311+
312+ const int32x4_t v_xy0 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x0l , v_y0l ), v_x0h , v_y0h );
313+ const int32x4_t v_xy1 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x1l , v_y1l ), v_x1h , v_y1h );
314+
315+ const float32x4_t v_xy0f = vec_float (v_xy0 );
316+ const float32x4_t v_xy1f = vec_float (v_xy1 );
317+
318+ const float32x4_t v_d0 = vec_splats (GGML_E8M0_TO_FP32_HALF (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ));
319+ const float32x4_t v_d1 = vec_splats (GGML_E8M0_TO_FP32_HALF (x1 -> e ) * GGML_CPU_FP16_TO_FP32 (y1 -> d ));
320+
321+ v_acc = vec_madd (v_xy0f , v_d0 , v_acc );
322+ v_acc = vec_madd (v_xy1f , v_d1 , v_acc );
323+ }
324+
325+ for (; ib < nb ; ++ ib ) {
326+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
327+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
328+
329+ const uint8x16_t v_x = vec_xl (0 , x0 -> qs );
330+
331+ int8x16_t v_xl = (int8x16_t )vec_and (v_x , v_m );
332+ int8x16_t v_xh = (int8x16_t )vec_sr (v_x , 4 );
333+
334+ v_xl = vec_perm (v_k , v_k , (uchar8x16_t )v_xl );
335+ v_xh = vec_perm (v_k , v_k , (uchar8x16_t )v_xh );
336+
337+ const int8x16_t v_yl = vec_xl (0 , y0 -> qs );
338+ const int8x16_t v_yh = vec_xl (QK8_0 /2 , y0 -> qs );
339+
340+ const int32x4_t v_xy = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_xl , v_yl ), v_xh , v_yh );
341+ const float32x4_t v_xyf = vec_float (v_xy );
342+
343+ const float32x4_t v_d = vec_splats (GGML_E8M0_TO_FP32_HALF (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ));
344+ v_acc = vec_madd (v_xyf , v_d , v_acc );
345+ }
346+
347+ sumf = vec_hsum_f32x4 (v_acc );
348+ * s = sumf ;
349+ #else
350+ UNUSED (x );
351+ UNUSED (y );
352+ UNUSED (ib );
353+ UNUSED (sumf );
354+ ggml_vec_dot_mxfp4_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
355+ #endif
356+ }
357+
263358void ggml_vec_dot_q5_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
264359 const int qk = QK8_0 ;
265360 const int nb = n / qk ;
0 commit comments