@@ -903,13 +903,17 @@ typedef struct {
 } block_q4_0;
 static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
+#define Q4_1DM (2.0f/15.0f)
+#define Q4_1MM (2.0f)
+#define Q4_1D(x) (        (((x) & 0xFF)*Q4_1DM) / 255.0f)
+#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
+
 #define QK4_1 32
 typedef struct {
-    ggml_fp16_t d;          // delta
-    ggml_fp16_t m;          // min
-    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+    uint16_t dm;            // 8-bit delta + 8-bit min (can be adjusted easily)
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 typedef struct {
@@ -929,7 +933,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
 
 #define QK5_1 32
 typedef struct {
-    uint8_t dm;             // 4-bit delta + 4-bit min
+    uint8_t dm;             // 4-bit delta + 4-bit min (can be adjusted easily)
     uint8_t qh[4];          // 5-th bit of quants
     uint8_t qs[QK5_1 / 2];  // nibbles / quants
 } block_q5_1;
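
For reference, the `Q4_1D`/`Q4_1M` macros above decode the packed `dm` field: the high byte maps 0..255 onto a block minimum in [-1, 1] (so the scheme assumes the minimum lies in that range), and the low byte maps 0..255 onto a delta in [0, 2/15] (so a block's min..max span can be at most about 2.0 before the delta code saturates). A minimal decode-only sketch, using only the macros from this commit with made-up byte codes:

```c
#include <stdint.h>
#include <stdio.h>

#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f)
#define Q4_1D(x) ((((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)

int main(void) {
    // example codes only: high byte = min code (76), low byte = delta code (96)
    const uint16_t dm = (76u << 8) | 96u;
    printf("min = %f\n", Q4_1M(dm));  // -1 + 76*2/255  ~= -0.4039
    printf("d   = %f\n", Q4_1D(dm));  // 96*(2/15)/255  ~=  0.0502
    return 0;
}
```
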
@@ -1013,11 +1017,17 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
             if (v > max) max = v;
         }
 
-        const float d  = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
+        y[i].dm = (uint16_t)(floorf((255.0f * (min + 1.0f)) / Q4_1MM)) << 8;
 
-        y[i].d = GGML_FP32_TO_FP16(d);
-        y[i].m = GGML_FP32_TO_FP16(min);
+        min = Q4_1M(y[i].dm);
+
+        float d = (max - min) / ((1 << 4) - 1);
+
+        y[i].dm |= (uint16_t)(ceilf((255.0f * d) / Q4_1DM));
+
+        d = Q4_1D(y[i].dm);
+
+        const float id = d ? 1.0f/d : 0.0f;
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = (x[i*qk + 0 + j] - min)*id;
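
Note the ordering in the quantization hunk above: the minimum is packed first (floor-rounded), then re-read through `Q4_1M` so that the delta is sized against the value dequantization will actually see, and the delta is finally ceil-rounded so the block maximum stays reachable. A small sketch of that encode path under the same assumptions (the block values are made up):

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f)
#define Q4_1D(x) ((((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)

int main(void) {
    float min = -0.40f;              // example block minimum
    const float max = 0.35f;         // example block maximum

    // pack the min into the high byte (floor), then re-derive the snapped min
    uint16_t dm = (uint16_t)(floorf((255.0f * (min + 1.0f)) / Q4_1MM)) << 8;
    min = Q4_1M(dm);

    // size the delta from the snapped min and pack it into the low byte (ceil)
    const float d = (max - min) / ((1 << 4) - 1);
    dm |= (uint16_t)(ceilf((255.0f * d) / Q4_1DM));

    // the decoded delta is never smaller than needed, so the max stays reachable
    printf("min %.6f  d %.6f  min + 15*d = %.6f (>= %.2f)\n",
           Q4_1M(dm), Q4_1D(dm), Q4_1M(dm) + 15.0f*Q4_1D(dm), max);
    return 0;
}
```
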
@@ -1570,8 +1580,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = GGML_FP16_TO_FP32(x[i].d);
-        const float m = GGML_FP16_TO_FP32(x[i].m);
+        const float d = Q4_1D(x[i].dm);
+        const float m = Q4_1M(x[i].dm);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F);
@@ -2671,7 +2681,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const block_q8_1 * restrict y0 = &y[i + 0];
         const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
+        summs += Q4_1M(x0->dm) * y0->s + Q4_1M(x1->dm) * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -2695,8 +2705,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_1D(x0->dm)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_1D(x1->dm)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2713,8 +2723,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_1D(x0->dm)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_1D(x1->dm)*y1->d);
 #endif
     }
 
@@ -2727,10 +2737,10 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d0 = Q4_1D(x[i].dm);
         const float d1 = y[i].d;
 
-        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+        summs += Q4_1M(x[i].dm) * y[i].s;
 
         const __m256 d0v = _mm256_set1_ps( d0 );
         const __m256 d1v = _mm256_set1_ps( d1 );
@@ -2767,7 +2777,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+        sumf += (Q4_1D(x[i].dm)*y[i].d)*sumi + Q4_1M(x[i].dm)*y[i].s;
     }
 
     *s = sumf;
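
One consequence visible in the static_asserts: a q4_1 block shrinks from 2*sizeof(ggml_fp16_t) + 16 = 20 bytes to sizeof(uint16_t) + 16 = 18 bytes per 32 weights, i.e. from 5.0 to 4.5 bits per weight, at the cost of a coarser, fixed-range 8-bit encoding of the delta and min.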