@@ -186,6 +186,8 @@ typedef struct {
186186static_assert (sizeof (block_q4_K) == 2 *sizeof (ggml_fp16_t ) + 3 *QK_K/64 + QK_K/2 , " wrong q4_K block size/padding" );
187187#endif
188188
189+ #define QR5_K 2
190+ #define QI5_K (QK_K / (4 *QR5_K))
189191#ifdef GGML_QKK_64
190192typedef struct {
191193 half d; // super-block scale
@@ -1531,6 +1533,50 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
15311533#endif // __CUDA_ARCH__ >= 610
15321534}
15331535
// Dot product of one q5_K super-block slice with the matching q8_1 quants,
// used by the mul_mat_vec_q kernel.
//   vbq   - pointer to the block_q5_K being processed (generic quant pointer)
//   bq8_1 - array of q8_1 blocks holding the quantized activation vector
//   iqs   - index of the int-sized (4-quant) group handled by this thread
// Returns d*sum(q5*q8*scale) - dmin*sum(q8*min), i.e. the partial dot product
// contributed by this thread's quants.
// Requires compute capability 6.1+ for the __dp4a integer SIMD intrinsic;
// on older architectures the body compiles to a stub returning 0.0f.
static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {

#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
    const block_q5_K * bq5_K = (const block_q5_K *) vbq;

    // Each thread covers 2 (== QR5_K) consecutive q8_1 sub-blocks.
    const int bq8_offset = 2 * (iqs / 8);

    float sumf_d = 0.0f;
    float sumf_m = 0.0f; // float literal for consistency with sumf_d (was int 0)

    const float d    = bq5_K->d;
    const float dmin = bq5_K->dmin;

    // Low 4 bits of 8 quants (two nibbles per byte), packed into one int.
    const int vil = *((int *) &bq5_K->qs[sizeof(int) * iqs]);

    // High (5th) bits of the same quants; shifting by bq8_offset aligns bit i
    // of each byte with sub-block bq8_offset + i.
    const int vih = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % 8)])) >> bq8_offset;

    for (int i = 0; i < 2; ++i) {
        const int isc = bq8_offset + i;

        // Unpack the 6-bit scale and min for this sub-block.
        uint8_t sc, m;
        get_scale_min_k4(isc, bq5_K->scales, sc, m);

        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
        const int   uii = *((int *) &bq8i->qs[sizeof(int) * (iqs % 8)]);
        const float d8i = bq8i->d;

        const int viil = (vil >> (4*i)) & 0x0F0F0F0F; // low nibble of each of 4 quants

        const int viih = ((vih >> i) << 4) & 0x10101010; // matching high bit placed at bit 4

        const int vii = viil | viih; // reassembled 5-bit quant values

        sumf_d += d8i * (__dp4a(vii, uii, 0) * sc);        // 4-way q5*q8 dot, scaled
        sumf_m += d8i * (__dp4a(0x01010101, uii, 0) * m);  // sum of q8 quants for the min term
    }

    return d*sumf_d - dmin*sumf_m;
#else
    return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 610
}
1579+
15341580template <int qk, int qi, typename block_q_t , vec_dot_q_cuda_t vec_dot_q_cuda>
15351581static __global__ void mul_mat_vec_q (const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
15361582 const int row = blockIdx .y *blockDim .y + threadIdx .y ;
@@ -2116,6 +2162,15 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
21162162 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
21172163}
21182164
// Host-side launcher for the q5_K x q8_1 matrix-vector product.
// Grid: one thread block column, ceil(nrows / GGML_CUDA_MMV_Y) blocks in y;
// each block is WARP_SIZE x GGML_CUDA_MMV_Y threads (one warp per row).
// Precondition: ncols must be a multiple of QK_K.
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    const int  nblocks_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; // ceil-div over rows
    const dim3 block_nums(1, nblocks_y, 1);
    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
2173+
21192174static void convert_fp16_to_fp32_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
21202175 const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
21212176 dequantize_block<1 , 1 , convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
@@ -2581,8 +2636,8 @@ inline void ggml_cuda_op_mul_mat_vec(
25812636 src0->type == GGML_TYPE_Q8_0 ||
25822637 src0->type == GGML_TYPE_Q2_K ||
25832638 src0->type == GGML_TYPE_Q3_K ||
2584- src0->type == GGML_TYPE_Q4_K;
2585- // src0->type == GGML_TYPE_Q5_K ||
2639+ src0->type == GGML_TYPE_Q4_K ||
2640+ src0->type == GGML_TYPE_Q5_K;
25862641 // src0->type == GGML_TYPE_Q5_K;
25872642
25882643 const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
@@ -2620,6 +2675,9 @@ inline void ggml_cuda_op_mul_mat_vec(
26202675 case GGML_TYPE_Q4_K:
26212676 mul_mat_vec_q4_K_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
26222677 break ;
2678+ case GGML_TYPE_Q5_K:
2679+ mul_mat_vec_q5_K_q8_1_cuda (src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
2680+ break ;
26232681 default :
26242682 GGML_ASSERT (false );
26252683 break ;
0 commit comments