@@ -93,6 +93,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
9393
9494#define QK5_0 32
9595#define QR5_0 2
96+ #define QI5_0 4
9697typedef struct {
9798 half d; // delta
9899 uint8_t qh[4 ]; // 5-th bit of quants
@@ -102,6 +103,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
102103
103104#define QK5_1 32
104105#define QR5_1 2
106+ #define QI5_1 4
105107typedef struct {
106108 half d; // delta
107109 half m; // min
@@ -112,6 +114,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
112114
113115#define QK8_0 32
114116#define QR8_0 1
117+ #define QI8_0 4
115118typedef struct {
116119 half d; // delta
117120 int8_t qs[QK8_0]; // quants
@@ -1273,6 +1276,36 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
12731276 return sumi*d;
12741277}
12751278
// Dot product of one q5_1 quant block chunk against the matching q8_1 chunk.
// Each call handles 2*sizeof(int) = 8 of the block's low-nibble pairs, selected by iqs;
// requires __dp4a (per-byte int8 dot product intrinsic, SM61+).
1279+ static __device__ __forceinline__ float vec_dot_q5_1_q8_1 (const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
1280+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
1281+
1282+ const int qs = *((int *) &bq5_1->qs [sizeof (int ) * (iqs + 0 )]); // 4 bytes = 8 packed 4-bit quants (low+high nibbles)
1283+ const int qh0 = bq5_1->qh [iqs/2 + 0 ] >> 4 *(iqs%2 ); // 5th bits for the low-nibble group
1284+ const int qh1 = bq5_1->qh [iqs/2 + 2 ] >> 4 *(iqs%2 ); // 5th bits for the high-nibble group
1285+ const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]); // 4 int8 quants from the q8_1 block, first half
// NOTE(review): index uses QI4_0 rather than QI5_1 — correct only because both equal 4; QI5_1 would be more consistent.
1286+ const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI4_0)]); // second half, offset by one block-quarter
1287+
1288+ const float d = bq5_1->d * bq8_1->d ; // combined scale of the two blocks
1289+ const float m = bq5_1->m ; // q5_1 min (offset) term
1290+ const float s = bq8_1->s ; // q8_1 per-block sum term
1291+
// Rebuild 4 signed-ish 5-bit values in one int: low nibbles in bits 0-3 of each byte,
// then OR the 5th bit into bit 4 of each byte (shifts 4/11/18/25 land on bits 4, 12, 20, 28).
1292+ int vi0 = (qs >> 0 ) & 0x0F0F0F0F ;
1293+ vi0 |= (qh0 << 4 ) & 0x00000010 ;
1294+ vi0 |= (qh0 << 11 ) & 0x00001000 ;
1295+ vi0 |= (qh0 << 18 ) & 0x00100000 ;
1296+ vi0 |= (qh0 << 25 ) & 0x10000000 ;
1297+ int sumi = __dp4a (vi0, ui0, 0 ); // per-byte dot product with the q8 values, accumulate in int
1298+
// Same reconstruction for the high nibbles (second group of 4 quants).
1299+ int vi1 = (qs >> 4 ) & 0x0F0F0F0F ;
1300+ vi1 |= (qh1 << 4 ) & 0x00000010 ;
1301+ vi1 |= (qh1 << 11 ) & 0x00001000 ;
1302+ vi1 |= (qh1 << 18 ) & 0x00100000 ;
1303+ vi1 |= (qh1 << 25 ) & 0x10000000 ;
1304+ sumi = __dp4a (vi1, ui1, sumi);
1305+
// Precedence: only m*s is divided by QI5_1 — presumably because this function is called
// QI5_1 times per block so the m*s offset term must be added only once in total; verify against caller.
1306+ return sumi*d + m*s / QI5_1;
1307+ }
1308+
12761309template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
12771310static __global__ void dequantize_block (const void * vx, float * y, const int k) {
12781311 const int i = blockDim .x *blockIdx .x + 2 *threadIdx .x ;
@@ -1294,7 +1327,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
12941327 y[iybs + iqs + y_offset] = v.y ;
12951328}
12961329
1297- template <int qk, typename block_q_t , vec_dot_q_cuda_t vec_dot_q_cuda>
1330+ template <int qk, int qi, typename block_q_t , vec_dot_q_cuda_t vec_dot_q_cuda>
12981331static __global__ void mul_mat_vec_q (const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
12991332 const int row = blockIdx .y *blockDim .y + threadIdx .y ;
13001333
@@ -1304,7 +1337,6 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
13041337
13051338 const int blocks_per_row = ncols / qk;
13061339 const int blocks_per_warp = WARP_SIZE * sizeof (int )*2 /qk;
1307- const int ints_per_block = qk / (2 * sizeof (int ));
13081340
13091341// partial sum for each thread
13101342 float tmp = 0 .0f ;
@@ -1313,11 +1345,11 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
13131345 const block_q8_1 * y = (const block_q8_1 *) vy;
13141346
13151347 for (int i = 0 ; i < blocks_per_row; i += blocks_per_warp) {
1316- const int ibx = row*blocks_per_row + i + threadIdx .x /ints_per_block ; // x block index
1348+ const int ibx = row*blocks_per_row + i + threadIdx .x /qi ; // x block index
13171349
1318- const int iby = i + threadIdx .x /ints_per_block ;
1350+ const int iby = i + threadIdx .x /qi ;
13191351
1320- const int iqs = threadIdx .x % ints_per_block ;
1352+ const int iqs = threadIdx .x % qi ;
13211353
13221354 tmp += vec_dot_q_cuda (&x[ibx], &y[iby], iqs);
13231355 }
@@ -1812,7 +1844,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
18121844 const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1 ) / GGML_CUDA_DMMV_Y;
18131845 const dim3 block_nums (1 , block_num_y, 1 );
18141846 const dim3 block_dims (WARP_SIZE, GGML_CUDA_DMMV_Y, 1 );
1815- mul_mat_vec_q<QK4_0, block_q4_0, vec_dot_q4_0_q8_1>
1847+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
18161848 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
18171849}
18181850
@@ -1821,7 +1853,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
18211853 const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1 ) / GGML_CUDA_DMMV_Y;
18221854 const dim3 block_nums (1 , block_num_y, 1 );
18231855 const dim3 block_dims (WARP_SIZE, GGML_CUDA_DMMV_Y, 1 );
1824- mul_mat_vec_q<QK4_0, block_q4_1, vec_dot_q4_1_q8_1>
1856+ mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
18251857 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
18261858}
18271859
@@ -1830,7 +1862,16 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
18301862 const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1 ) / GGML_CUDA_DMMV_Y;
18311863 const dim3 block_nums (1 , block_num_y, 1 );
18321864 const dim3 block_dims (WARP_SIZE, GGML_CUDA_DMMV_Y, 1 );
1833- mul_mat_vec_q<QK5_0, block_q5_0, vec_dot_q5_0_q8_1>
1865+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
1866+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
1867+ }
1868+
// Host-side launcher: matrix(q5_1) x vector(q8_1) multiply producing float dst.
// Grid: 1 x ceil(nrows / GGML_CUDA_DMMV_Y) blocks; block: WARP_SIZE x GGML_CUDA_DMMV_Y threads
// (one warp per row slice). Launch is asynchronous on the given stream.
1869+ static void mul_mat_vec_q5_1_q8_1_cuda (const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1870+ GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 ); // kernel assumes whole blocks per row
1871+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1 ) / GGML_CUDA_DMMV_Y; // ceil-div over rows
1872+ const dim3 block_nums (1 , block_num_y, 1 );
1873+ const dim3 block_dims (WARP_SIZE, GGML_CUDA_DMMV_Y, 1 );
// Instantiate the generic kernel with q5_1 block size (QK5_1), ints-per-block (QI5_1) and dot kernel.
1874+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
18341875 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
18351876 }
18361877
@@ -2360,6 +2401,9 @@ inline void ggml_cuda_op_mul_mat_vec_q(
23602401 case GGML_TYPE_Q5_0:
23612402 mul_mat_vec_q5_0_q8_1_cuda (src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
23622403 break ;
2404+ case GGML_TYPE_Q5_1:
2405+ mul_mat_vec_q5_1_q8_1_cuda (src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
2406+ break ;
23632407 default :
23642408 GGML_ASSERT (false );
23652409 break ;
@@ -2916,7 +2960,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
29162960 } else if (ggml_is_quantized (src0->type ) || src0->type == GGML_TYPE_F16) {
29172961 if (src1->ne [1 ] == 1 && src0->ne [0 ] % GGML_CUDA_DMMV_X == 0 && src0->ne [1 ] % GGML_CUDA_DMMV_Y == 0 ) {
29182962 bool use_mul_mat_vec_q = false ;
2919- use_mul_mat_vec_q = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0;
2963+ use_mul_mat_vec_q = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1
2964+ || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1;
29202965 if (use_mul_mat_vec_q) {
29212966 ggml_cuda_op (src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, false , false );
29222967 } else {