@@ -931,6 +931,101 @@ inline static float vaddvq_f32(float32x4_t v) {
931931 #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
932932#endif
933933
934+ #elif defined(__AVX512F__)
935+ 
936+ #define GGML_SIMD
937+ 
938+ // F32 AVX512
939+ 
940+ #define GGML_F32_STEP 64
941+ #define GGML_F32_EPR 16
942+ 
943+ #define GGML_F32x16 __m512
944+ #define GGML_F32x16_ZERO _mm512_setzero_ps()
945+ #define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
946+ #define GGML_F32x16_LOAD _mm512_loadu_ps
947+ #define GGML_F32x16_STORE _mm512_storeu_ps
948+ // _mm512_fmadd_ps is defined in AVX512F so no guard is required
949+ #define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
950+ #define GGML_F32x16_ADD _mm512_add_ps
951+ #define GGML_F32x16_MUL _mm512_mul_ps
// NOTE(review): pairwise tree reduction over the GGML_F32_ARR accumulator
// registers, finished with a horizontal sum of x[0] via _mm512_reduce_add_ps.
// GGML_F32_ARR is defined elsewhere; if it equals
// GGML_F32_STEP/GGML_F32_EPR == 4, offset goes 2 -> 1 -> 0 and the third
// loop body never executes (harmless, but dead) — TODO confirm.
952+ #define GGML_F32x16_REDUCE(res, x) \
953+ do { \
954+ int offset = GGML_F32_ARR >> 1; \
955+ for (int i = 0; i < offset; ++i) { \
956+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
957+ } \
958+ offset >>= 1; \
959+ for (int i = 0; i < offset; ++i) { \
960+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
961+ } \
962+ offset >>= 1; \
963+ for (int i = 0; i < offset; ++i) { \
964+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
965+ } \
966+ res = _mm512_reduce_add_ps(x[0]); \
967+ } while (0)
968+ 
969+ // TODO: is this optimal ?
970+ 
// Map the generic GGML_F32_VEC_* names used by the compute kernels onto the
// AVX512 implementations above.
971+ #define GGML_F32_VEC GGML_F32x16
972+ #define GGML_F32_VEC_ZERO GGML_F32x16_ZERO
973+ #define GGML_F32_VEC_SET1 GGML_F32x16_SET1
974+ #define GGML_F32_VEC_LOAD GGML_F32x16_LOAD
975+ #define GGML_F32_VEC_STORE GGML_F32x16_STORE
976+ #define GGML_F32_VEC_FMA GGML_F32x16_FMA
977+ #define GGML_F32_VEC_ADD GGML_F32x16_ADD
978+ #define GGML_F32_VEC_MUL GGML_F32x16_MUL
979+ #define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
980+ 
981+ // F16 AVX512
982+ 
983+ // (F16 is emulated with F32 vectors — see the note below — so these mirror the F32 ops)
984+ 
985+ #define GGML_F16_STEP 64
986+ #define GGML_F16_EPR 16
987+ 
988+ // AVX512 has an FP16 extension (AVX512_FP16), but it is not assumed to be available here, so F16 data is converted to F32 for all arithmetic
989+ 
990+ #define GGML_F32Cx16 __m512
991+ #define GGML_F32Cx16_ZERO _mm512_setzero_ps()
992+ #define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x)
993+ 
994+ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
995+ // so F16C guard isn't required
996+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
// NOTE(review): the imm8 `0` passed to _mm512_cvtps_ph is
// _MM_FROUND_TO_NEAREST_INT; using the named _MM_FROUND_* constant would be
// clearer — confirm round-to-nearest is the intended F32->F16 rounding mode.
997+ #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
998+ 
999+ #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
1000+ #define GGML_F32Cx16_ADD _mm512_add_ps
1001+ #define GGML_F32Cx16_MUL _mm512_mul_ps
// NOTE(review): identical to GGML_F32x16_REDUCE, including the reuse of
// GGML_F32_ARR rather than an F16-specific count — fine as long as
// GGML_F16_STEP/GGML_F16_EPR == GGML_F32_STEP/GGML_F32_EPR (both 64/16 == 4
// here), but worth revisiting if the F16 geometry ever diverges.
1002+ #define GGML_F32Cx16_REDUCE(res, x) \
1003+ do { \
1004+ int offset = GGML_F32_ARR >> 1; \
1005+ for (int i = 0; i < offset; ++i) { \
1006+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
1007+ } \
1008+ offset >>= 1; \
1009+ for (int i = 0; i < offset; ++i) { \
1010+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
1011+ } \
1012+ offset >>= 1; \
1013+ for (int i = 0; i < offset; ++i) { \
1014+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
1015+ } \
1016+ res = _mm512_reduce_add_ps(x[0]); \
1017+ } while (0)
1018+ 
// Map the generic GGML_F16_VEC_* names onto the F32-backed implementations.
// LOAD ignores the register index i; STORE uses it to pick the register to
// write back.
1019+ #define GGML_F16_VEC GGML_F32Cx16
1020+ #define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO
1021+ #define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1
1022+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p)
1023+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
1024+ #define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
1025+ #define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
1026+ #define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
1027+ #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
1028+ 
9341029#elif defined(__AVX__)
9351030
9361031#define GGML_SIMD
0 commit comments