@@ -4003,42 +4003,141 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
40034003 float sumf = 0;
40044004
40054005#if defined(__ARM_FEATURE_SVE)
4006- if (ggml_sve_cnt_b == QK8_0) {
4007- const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
4008- const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
4006+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
4007+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
40094008
4010- svfloat32_t sumv0 = svdup_n_f32(0.0f);
4011- svfloat32_t sumv1 = svdup_n_f32(0.0f);
4009+ const int vector_length = ggml_sve_cnt_b*8;
40124010
4013- for (; ib + 1 < nb; ib += 2) {
4014- const block_q4_0 * restrict x0 = &x[ib + 0];
4015- const block_q4_0 * restrict x1 = &x[ib + 1];
4016- const block_q8_0 * restrict y0 = &y[ib + 0];
4017- const block_q8_0 * restrict y1 = &y[ib + 1];
4018-
4019- // load x
4020- const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4021- const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4022-
4023- // 4-bit -> 8-bit
4024- const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
4025- const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
4026-
4027- // sub 8
4028- const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
4029- const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4011+ // VLA Implementation using switch case
4012+ switch (vector_length) {
4013+ case 128:
4014+ {
4015+ // predicate for activating higher lanes for 4 float32 elements
4016+ const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
4017+
4018+ for (; ib + 1 < nb; ib += 2) {
4019+ const block_q4_0 * restrict x0 = &x[ib + 0];
4020+ const block_q4_0 * restrict x1 = &x[ib + 1];
4021+ const block_q8_0 * restrict y0 = &y[ib + 0];
4022+ const block_q8_0 * restrict y1 = &y[ib + 1];
4023+
4024+ // load x
4025+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4026+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4027+
4028+ // 4-bit -> 8-bit
4029+ const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
4030+ const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
4031+ const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
4032+ const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
4033+
4034+ // sub 8
4035+ const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
4036+ const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
4037+ const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
4038+ const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
4039+
4040+ // load y
4041+ const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
4042+ const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
4043+ const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
4044+ const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
4045+
4046+ // dot product
4047+ sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4048+ svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
4049+ svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
4050+ sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4051+ svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
4052+ svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
4053+ }
40304054
4031- // load y
4032- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
4033- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4055+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4056+ } break;
4057+ case 256:
4058+ {
4059+ // predicate for activating higher lanes for 16 int8 elements
4060+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4061+ // predicate for activating lower lanes for 16 int8 elements
4062+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
4063+
4064+ for (; ib + 1 < nb; ib += 2) {
4065+ const block_q4_0 * restrict x0 = &x[ib + 0];
4066+ const block_q4_0 * restrict x1 = &x[ib + 1];
4067+ const block_q8_0 * restrict y0 = &y[ib + 0];
4068+ const block_q8_0 * restrict y1 = &y[ib + 1];
4069+
4070+ // load x
4071+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4072+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4073+
4074+ // 4-bit -> 8-bit
4075+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4076+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4077+
4078+ // sub 8
4079+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
4080+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4081+
4082+ // load y
4083+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
4084+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4085+
4086+ // dot product
4087+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
4088+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
4089+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
4090+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
4091+ }
40344092
4035- // dot product
4036- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
4037- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
4038- }
4093+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4094+ } break;
4095+ case 512:
4096+ {
4097+ // predicate for activating higher lanes for 32 int8 elements
4098+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
4099+
4100+ // predicate for activating higher lanes for 16 int8 elements
4101+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4102+ // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
4103+ const svbool_t pl16 = svnot_b_z(ph32, ph16);
4104+
4105+ for (; ib + 1 < nb; ib += 2) {
4106+ const block_q4_0 * restrict x0 = &x[ib + 0];
4107+ const block_q4_0 * restrict x1 = &x[ib + 1];
4108+ const block_q8_0 * restrict y0 = &y[ib + 0];
4109+ const block_q8_0 * restrict y1 = &y[ib + 1];
4110+
4111+ // load x
4112+ const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
4113+ const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
4114+
4115+ // 4-bit -> 8-bit
4116+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4117+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4118+
4119+ // sub 8
4120+ const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
4121+ const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
4122+
4123+ // load y
4124+ const svint8_t qy0 = svld1_s8(ph32, y0->qs);
4125+ const svint8_t qy1 = svld1_s8(ph32, y1->qs);
4126+
4127+ // dot product
4128+ sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
4129+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
4130+ sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
4131+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
4132+ }
40394133
4040- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4134+ sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
4135+ } break;
4136+ default:
4137+ assert(false && "Unsupported vector length");
4138+ break;
40414139 }
4140+
40424141#elif defined(__ARM_NEON)
40434142 float32x4_t sumv0 = vdupq_n_f32(0.0f);
40444143 float32x4_t sumv1 = vdupq_n_f32(0.0f);
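
The hunk above replaces the previous SVE path, which only ran when ggml_sve_cnt_b == QK8_0 (i.e. 256-bit registers), with a vector-length-agnostic dispatch: ggml_sve_cnt_b holds the SVE register width in bytes, so ggml_sve_cnt_b*8 is the width in bits, and the switch picks a kernel specialized for 128-, 256- or 512-bit registers. The 256-bit case keeps the original trick of unpacking both nibble halves inside one register (svand on the first 16 byte lanes, svlsr on the next 16 via complementary predicates); the 128-bit case unpacks low and high nibbles into separate registers and issues two svdot calls per block; the 512-bit case restricts every operation to the first 32 lanes with ph32. All three variants compute the same q4_0 x q8_0 block dot product. For reference, a minimal scalar sketch of that computation, assuming the standard ggml block layouts (fp16_to_fp32 is a stand-in for GGML_FP16_TO_FP32 and is not part of the patch):

    #include <stdint.h>

    #define QK4_0 32
    #define QK8_0 32

    typedef uint16_t ggml_half;                                        // fp16 bit pattern

    typedef struct { ggml_half d; uint8_t qs[QK4_0/2]; } block_q4_0;   // scale + 16 packed nibble pairs
    typedef struct { ggml_half d; int8_t  qs[QK8_0];   } block_q8_0;   // scale + 32 int8 values

    float fp16_to_fp32(ggml_half h);                                   // placeholder for GGML_FP16_TO_FP32

    float vec_dot_q4_0_q8_0_ref(int n, const block_q4_0 *x, const block_q8_0 *y) {
        const int nb = n / QK8_0;
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            int sumi = 0;
            for (int j = 0; j < QK4_0/2; ++j) {
                const int v0 = (x[ib].qs[j] & 0x0F) - 8;               // low nibble  -> element j
                const int v1 = (x[ib].qs[j] >>   4) - 8;               // high nibble -> element j + 16
                sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + QK4_0/2];
            }
            sumf += fp16_to_fp32(x[ib].d) * fp16_to_fp32(y[ib].d) * sumi;
        }
        return sumf;
    }

The SVE paths above must produce the same sumf; they differ only in how many blocks and lanes are processed per iteration.
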
@@ -5488,29 +5587,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
54885587 float sumf = 0;
54895588
54905589#if defined(__ARM_FEATURE_SVE)
5491- if (ggml_sve_cnt_b == QK8_0) {
5492- svfloat32_t sumv0 = svdup_n_f32(0.0f);
5493- svfloat32_t sumv1 = svdup_n_f32(0.0f);
5590+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
5591+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
54945592
5495- for (; ib + 1 < nb; ib += 2) {
5496- const block_q8_0 * restrict x0 = &x[ib + 0];
5497- const block_q8_0 * restrict x1 = &x[ib + 1];
5498- const block_q8_0 * restrict y0 = &y[ib + 0];
5499- const block_q8_0 * restrict y1 = &y[ib + 1];
5593+ const int vector_length = ggml_sve_cnt_b*8;
55005594
5501- // load x
5502- const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5503- const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5595+    // VLA Implementation for SVE
5596+ switch (vector_length) {
5597+ case 128:
5598+ {
5599+                // predicates for 16 int8 lanes (ph16) and for the 4 float32 accumulator lanes (pl16)
5600+ const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
5601+ const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
5602+
5603+ for (; ib + 1 < nb; ib += 2) {
5604+ const block_q8_0 * restrict x0 = &x[ib + 0];
5605+ const block_q8_0 * restrict x1 = &x[ib + 1];
5606+ const block_q8_0 * restrict y0 = &y[ib + 0];
5607+ const block_q8_0 * restrict y1 = &y[ib + 1];
5608+
5609+ // load x
5610+ const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
5611+ const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
5612+ const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
5613+ const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
5614+
5615+ // load y
5616+ const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
5617+ const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
5618+ const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
5619+ const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
5620+
5621+ sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5622+ svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
5623+ svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
5624+ sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5625+ svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
5626+ svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
5627+ }
55045628
5505- // load y
5506- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5507- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5629+ sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
5630+ } break;
5631+ case 256:
5632+ {
5633+                // a full q8_0 block (32 x int8) fits in one 256-bit register, so no extra predicates are needed
5634+ for (; ib + 1 < nb; ib += 2) {
5635+ const block_q8_0 * restrict x0 = &x[ib + 0];
5636+ const block_q8_0 * restrict x1 = &x[ib + 1];
5637+ const block_q8_0 * restrict y0 = &y[ib + 0];
5638+ const block_q8_0 * restrict y1 = &y[ib + 1];
5639+
5640+ // load x
5641+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5642+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5643+
5644+ // load y
5645+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5646+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5647+
5648+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
5649+ svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
5650+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
5651+ svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
5652+ }
55085653
5509- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
5510- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
5511- }
5654+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5655+ } break;
5656+ case 512:
5657+ {
5658+ // predicate for activating high 256 bit
5659+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
5660+ // predicate for activating low 256 bit
5661+ const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
5662+
5663+ // predicate for activating high lanes for 8 float32 elements
5664+ const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
5665+ // predicate for activating low lanes for 8 float32 elements
5666+ const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
5667+
5668+ svfloat32_t sumv00 = svdup_n_f32(0.0f);
5669+
5670+ for (; ib + 1 < nb; ib += 2) {
5671+ const block_q8_0 * restrict x0 = &x[ib + 0];
5672+ const block_q8_0 * restrict x1 = &x[ib + 1];
5673+ const block_q8_0 * restrict y0 = &y[ib + 0];
5674+ const block_q8_0 * restrict y1 = &y[ib + 1];
5675+
5676+                    // load 32 int8_t of x0 into the lower 32 lanes and 32 int8_t of x1 into the upper 32 lanes
5677+                    // (the +2 byte base offset skips x1's fp16 delta), then merge the zeroed halves with an add into one 64-element vector
5678+ // load x
5679+ const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
5680+ svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
5681+
5682+ qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
55125683
5513- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5684+ // load y
5685+ const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
5686+ svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
5687+
5688+ qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
5689+
5690+ // scale creation
5691+ const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
5692+ const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
5693+
5694+ // duplicate deq1 in first half of vector and deq2 in second half of vector
5695+ const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
5696+
5697+ const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
5698+
5699+ sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
5700+ }
5701+
5702+ sumf = svaddv_f32(svptrue_b32(), sumv00);
5703+ break;
5704+ }
5705+ default:
5706+ assert(false && "Unsupported vector length");
5707+ break;
55145708 }
55155709#elif defined(__ARM_NEON)
55165710 float32x4_t sumv0 = vdupq_n_f32(0.0f);
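
The ggml_vec_dot_q8_0_q8_0 hunk follows the same pattern. At 128 bits each 32-byte block is handled as two 16-byte loads whose svdot results are summed into 4 float32 lanes; at 256 bits a whole block fits one register, so the body matches the removed code; at 512 bits the lower 32 lanes are loaded from x[ib] and the upper 32 lanes from x[ib + 1] (the x0->qs + 2 base plus the 32-lane offset of the pl32 predicate lands exactly on x1->qs, skipping x1's 2-byte fp16 delta), the zeroed halves are merged with an add, and a single svdot then covers two blocks, with the two per-block scales duplicated into the two halves of a float32 vector. The dispatch itself reduces to querying the runtime register width once and branching on it. A hedged sketch of that pattern (dot_sve128/256/512 are hypothetical names standing in for the case bodies above):

    #include <arm_sve.h>
    #include <assert.h>

    // hypothetical per-width kernels; each corresponds to one case body of the switches above
    float dot_sve128(int nb, const void *x, const void *y);
    float dot_sve256(int nb, const void *x, const void *y);
    float dot_sve512(int nb, const void *x, const void *y);

    float dot_dispatch(int nb, const void *x, const void *y) {
        // svcntb() reports the SVE register width in bytes at run time;
        // ggml keeps the same byte count in ggml_sve_cnt_b, hence vector_length = ggml_sve_cnt_b*8
        const int vl_bits = (int) svcntb() * 8;
        switch (vl_bits) {
            case 128: return dot_sve128(nb, x, y);
            case 256: return dot_sve256(nb, x, y);
            case 512: return dot_sve512(nb, x, y);
            default:  assert(!"Unsupported SVE vector length"); return 0.0f;
        }
    }

The switch runs once per call, outside the block loop, so the per-block inner loops stay branch-free.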