@@ -618,6 +618,8 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const int nb3 = dst->nb[3];
     const ggml_type type = src0->type;
 
+    const bool can_dmmv = ggml_cuda_can_dequantize_mul_mat_vec(src0, src1, dst);
+
     const float alpha = 1.0f;
     const float beta = 0.0f;
     const int x_ne = ne01 * ne00;
@@ -628,7 +630,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
 
     size_t x_size, y_size, d_size, q_size;
     float * d_X;
-    if (ne11 > 1) {
+    if (!can_dmmv) {
         d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
     }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -658,7 +660,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             } else {
                 GGML_ASSERT(false);
             }
-            if (ne11 == 1) {
+            if (can_dmmv) { // specialized dequantize_mul_mat_vec kernel
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
                 // copy src1 to device
@@ -671,7 +673,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
                 dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 
-            } else {
+            } else { // general matrix matrix multiplication
                 float * c_X = d_X + i * x_ne;
 
                 // convert src0 to fp32 on device
@@ -702,7 +704,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    if (ne11 > 1) {
+    if (!can_dmmv) {
         ggml_cuda_pool_free(d_X, x_size);
     }
     ggml_cuda_pool_free(d_Y, y_size);
@@ -720,13 +722,26 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
         return true;
     }
 
     return false;
 }
 
+bool ggml_cuda_can_dequantize_mul_mat_vec(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    if (ggml_is_quantized(src0->type) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        (ne11 == 1 && ne12 == 1 && ne13 == 1)) {
+        return true;
+    }
+    return false;
+}
+
 bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
     size_t src0_sz = ggml_nbytes(src0);
     size_t src1_sz = ggml_nbytes(src1);
@@ -743,7 +758,9 @@ bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggm
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
-    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
+    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst) ||
+        ggml_cuda_can_dequantize_mul_mat_vec(src0, src1, dst) ||
+        src0->backend == GGML_BACKEND_CUDA);
 
     if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_mul_mat_f32(src0, src1, dst);