@@ -127,8 +127,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
127127#define QR8_1 1
128128#define QI8_1 (QK8_1 / (4 * QR8_1))
129129typedef struct {
130- half d; // delta
131- half s; // unquantized sum
130+ half2 ds; // ds.x = delta, ds.y = sum
132131 int8_t qs[QK8_0]; // quants
133132} block_q8_1;
134133static_assert (sizeof (block_q8_1) == 2 *sizeof (ggml_fp16_t ) + QK8_0, " wrong q8_1 block size/padding" );
@@ -1258,8 +1257,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
12581257 return ;
12591258 }
12601259
1261- y[ib].d = d;
1262- y[ib].s = sum;
1260+ y[ib].ds . x = d;
1261+ y[ib].ds . y = sum;
12631262}
12641263
12651264template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1284,18 +1283,18 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
12841283}
12851284
12861285static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl (
1287- const int & vi, const int & ui0, const int & ui1, const float & d4, const float & d8 ) {
1286+ const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8 ) {
12881287
12891288#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
12901289 // subtract 8 from each quantized value
1291- const int vi0 = __vsub4 (( vi >> 0 ) & 0x0F0F0F0F , 0x08080808 ) ;
1292- const int vi1 = __vsub4 (( vi >> 4 ) & 0x0F0F0F0F , 0x08080808 ) ;
1290+ const int vi0 = ( vi >> 0 ) & 0x0F0F0F0F ;
1291+ const int vi1 = ( vi >> 4 ) & 0x0F0F0F0F ;
12931292
12941293 // SIMD dot product of quantized values
12951294 int sumi = __dp4a (vi0, ui0, 0 );
12961295 sumi = __dp4a (vi1, ui1, sumi);
12971296
1298- return sumi*d4*d8 ;
1297+ return __half2float (d4) * ( sumi * __half2float (ds8. x ) - ( 8 /QI4_0) * __half2float (ds8. y )) ;
12991298#else
13001299 return 0 .0f ; // only to satisfy the compiler
13011300#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1311,7 +1310,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
13111310 const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
13121311 const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI4_0)]);
13131312
1314- return vec_dot_q4_0_q8_1_impl (vi, ui0, ui1, __half2float ( bq4_0->d ), __half2float ( bq8_1->d ) );
1313+ return vec_dot_q4_0_q8_1_impl (vi, ui0, ui1, bq4_0->d , bq8_1->ds );
13151314}
13161315
13171316static __device__ __forceinline__ float vec_dot_q4_1_q8_1 (
@@ -1324,9 +1323,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
13241323 const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
13251324 const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI4_1)]);
13261325
1327- const float d = __half2float (bq4_1->d ) * __half2float (bq8_1->d );
1326+ const float d = __half2float (bq4_1->d ) * __half2float (bq8_1->ds . x );
13281327 const float m = bq4_1->m ;
1329- const float s = bq8_1->s ;
1328+ const float s = bq8_1->ds . y ;
13301329
13311330 const int vi0 = (vi >> 0 ) & 0x0F0F0F0F ;
13321331 const int vi1 = (vi >> 4 ) & 0x0F0F0F0F ;
@@ -1354,7 +1353,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
13541353 const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
13551354 const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI5_0)]);
13561355
1357- const float d = __half2float (bq5_0->d ) * __half2float (bq8_1->d );
1356+ const float d = __half2float (bq5_0->d ) * __half2float (bq8_1->ds . x );
13581357
13591358 int vi0 = (qs >> 0 ) & 0x0F0F0F0F ; // lower 4 qs bits, still need qh0 as 5th bits
13601359 vi0 |= (qh0 << 4 ) & 0x00000010 ; // 1 -> 5
@@ -1390,9 +1389,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
13901389 const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
13911390 const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI5_1)]);
13921391
1393- const float d = __half2float (bq5_1->d ) * __half2float (bq8_1->d );
1392+ const float d = __half2float (bq5_1->d ) * __half2float (bq8_1->ds . x );
13941393 const float m = bq5_1->m ;
1395- const float s = bq8_1->s ;
1394+ const float s = bq8_1->ds . y ;
13961395
13971396 int vi0 = (qs >> 0 ) & 0x0F0F0F0F ; // lower 4 qs bits, still need qh0 as 5th bits
13981397 vi0 |= (qh0 << 4 ) & 0x00000010 ; // 1 -> 5
@@ -1424,7 +1423,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
14241423 memcpy (&vi, &bq8_0->qs [sizeof (int ) * (iqs + 0 )], sizeof (int ));
14251424 const int ui = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
14261425
1427- const float d = __half2float (bq8_0->d ) * __half2float (bq8_1->d );
1426+ const float d = __half2float (bq8_0->d ) * __half2float (bq8_1->ds . x );
14281427
14291428 // SIMD dot product of quantized values
14301429 int sumi = __dp4a (vi, ui, 0 );
@@ -1456,7 +1455,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
14561455 const int sc = bq2_K->scales [scale_offset + 2 *i];
14571456
14581457 const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1459- const float d8i = bq8i->d ;
1458+ const float d8i = bq8i->ds . x ;
14601459
14611460 const int vi = (v >> (2 *i)) & 0x03030303 ;
14621461 const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
@@ -1507,7 +1506,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
15071506
15081507 const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15091508 const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1510- const float d8i = bq8i->d ;
1509+ const float d8i = bq8i->ds . x ;
15111510
15121511 const int vil = (vl >> (2 *i)) & 0x03030303 ;
15131512
@@ -1548,7 +1547,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
15481547
15491548 const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15501549 const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1551- const float d8i = bq8i->d ;
1550+ const float d8i = bq8i->ds . x ;
15521551
15531552 const int vi = (v >> (4 *i)) & 0x0F0F0F0F ;
15541553
@@ -1588,7 +1587,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
15881587
15891588 const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15901589 const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1591- const float d8i = bq8i->d ;
1590+ const float d8i = bq8i->ds . x ;
15921591
15931592 const int vil = (vl >> (4 *i)) & 0x0F0F0F0F ;
15941593
@@ -1631,7 +1630,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
16311630
16321631 const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2 *i;
16331632 const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % (QI8_1))]);
1634- const float d8i = bq8i->d ;
1633+ const float d8i = bq8i->ds . x ;
16351634
16361635 const int vil = (vl >> (4 *i)) & 0x0F0F0F0F ;
16371636
@@ -1673,7 +1672,7 @@ static __global__ void mul_mat_q(
16731672 __shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1 ];
16741673 __shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
16751674 __shared__ int tile_y_qs[WARP_SIZE][2 *WARP_SIZE];
1676- __shared__ half tile_y_d [WARP_SIZE][2 *WARP_SIZE/QI8_1];
1675+ __shared__ half2 tile_y_ds [WARP_SIZE][2 *WARP_SIZE/QI8_1];
16771676 float sum[4 ] = {0 .0f };
16781677
16791678 for (int ib0 = 0 ; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
@@ -1694,12 +1693,12 @@ static __global__ void mul_mat_q(
16941693 const block_q8_1 * __restrict__ by0 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby0];
16951694
16961695 tile_y_qs[tid_y + i][tid_x] = *((int *) &by0->qs [iqsy]);
1697- tile_y_d [tid_y + i][iby0] = by0->d ;
1696+ tile_y_ds [tid_y + i][iby0] = by0->ds ;
16981697
16991698 const block_q8_1 * __restrict__ by1 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby1];
17001699
17011700 tile_y_qs[tid_y + i][tid_x + WARP_SIZE] = *((int *) &by1->qs [iqsy]);
1702- tile_y_d [tid_y + i][iby1] = by1->d ;
1701+ tile_y_ds [tid_y + i][iby1] = by1->ds ;
17031702 }
17041703
17051704 __syncthreads ();
@@ -1709,7 +1708,7 @@ static __global__ void mul_mat_q(
17091708 for (int j = 0 ; j < WARP_SIZE; j += 8 ) {
17101709 sum[j/8 ] += vec_dot_q4_0_q8_1_impl (
17111710 tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0 ], tile_y_qs[tid_y + j][iqsy + (QI8_1/2 )],
1712- tile_x_d[tid_x][k / QI4_0], tile_y_d [tid_y + j][2 * k / QI8_1]);
1711+ tile_x_d[tid_x][k / QI4_0], tile_y_ds [tid_y + j][2 * k / QI8_1]);
17131712 }
17141713 }
17151714
0 commit comments