@@ -207,6 +207,8 @@ typedef struct {
207207static_assert (sizeof (block_q5_K) == 2 *sizeof (ggml_fp16_t ) + K_SCALE_SIZE + QK_K/2 + QK_K/8 , " wrong q5_K block size/padding" );
208208#endif
209209
210+ #define QR6_K 2
211+ #define QI6_K (QK_K / (4 *QR6_K))
210212typedef struct {
211213 uint8_t ql[QK_K/2 ]; // quants, lower 4 bits
212214 uint8_t qh[QK_K/4 ]; // quants, upper 2 bits
@@ -1577,6 +1579,47 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
15771579#endif // __CUDA_ARCH__ >= 610
15781580}
15791581
1582+ static __device__ __forceinline__ float vec_dot_q6_K_q8_1 (
1583+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
1584+
1585+ #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
1586+ const block_q6_K * bq6_K = (const block_q6_K *) vbq;
1587+
1588+ const int bq8_offset = 4 * (iqs / 16 ) + (iqs % 16 ) / 8 ;
1589+
1590+ float sumf = 0 .0f ;
1591+
1592+ const float d = bq6_K->d ;
1593+
1594+ int vil;
1595+ memcpy (&vil, &bq6_K->ql [sizeof (int ) * iqs], sizeof (int ));
1596+
1597+ int vih;
1598+ memcpy (&vih, &bq6_K->qh [sizeof (int ) * (8 * (iqs / 16 ) + iqs % 8 )], sizeof (int ));
1599+
1600+ for (int i = 0 ; i < 2 ; ++i) {
1601+ const int isc = 8 * (iqs / 16 ) + (iqs % 16 ) / 4 + 4 *i;
1602+ const int sc = bq6_K->scales [isc];
1603+
1604+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2 *i;
1605+ const int uii = *((int *) &bq8i->qs [sizeof (int ) * (iqs%8 )]);
1606+ const float d8i = bq8i->d ;
1607+
1608+ const int viil = (vil >> (4 *i)) & 0x0F0F0F0F ;
1609+
1610+ const int viih = ((vih >> (2 *((iqs%16 )/8 ) + 4 *i)) << 4 ) & 0x30303030 ;
1611+
1612+ const int vii = __vsubss4 ((viil | viih), 0x20202020 );
1613+
1614+ sumf += d8i * (__dp4a (vii, uii, 0 ) * sc);
1615+ }
1616+
1617+ return d*sumf;
1618+ #else
1619+ return 0 .0f ; // only to satisfy the compiler
1620+ #endif // __CUDA_ARCH__ >= 610
1621+ }
1622+
15801623template <int qk, int qi, typename block_q_t , vec_dot_q_cuda_t vec_dot_q_cuda>
15811624static __global__ void mul_mat_vec_q (const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
15821625 const int row = blockIdx .y *blockDim .y + threadIdx .y ;
@@ -2171,6 +2214,15 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
21712214 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
21722215}
21732216
2217+ static void mul_mat_vec_q6_K_q8_1_cuda (const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
2218+ GGML_ASSERT (ncols % QK_K == 0 );
2219+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
2220+ const dim3 block_nums (1 , block_num_y, 1 );
2221+ const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
2222+ mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
2223+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
2224+ }
2225+
21742226static void convert_fp16_to_fp32_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
21752227 const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
21762228 dequantize_block<1 , 1 , convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
@@ -2637,8 +2689,8 @@ inline void ggml_cuda_op_mul_mat_vec(
26372689 src0->type == GGML_TYPE_Q2_K ||
26382690 src0->type == GGML_TYPE_Q3_K ||
26392691 src0->type == GGML_TYPE_Q4_K ||
2640- src0->type == GGML_TYPE_Q5_K;
2641- // src0->type == GGML_TYPE_Q5_K ;
2692+ src0->type == GGML_TYPE_Q5_K ||
2693+ src0->type == GGML_TYPE_Q6_K ;
26422694
26432695 const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
26442696#endif
@@ -2678,6 +2730,9 @@ inline void ggml_cuda_op_mul_mat_vec(
26782730 case GGML_TYPE_Q5_K:
26792731 mul_mat_vec_q5_K_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
26802732 break ;
2733+ case GGML_TYPE_Q6_K:
2734+ mul_mat_vec_q6_K_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
2735+ break ;
26812736 default :
26822737 GGML_ASSERT (false );
26832738 break ;
0 commit comments