Skip to content

Commit 35de04e

Browse files
Method for checking whether dequantize mul mat vec can be used
1 parent f0af475 commit 35de04e

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

ggml-cuda.cu

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
618618
const int nb3 = dst->nb[3];
619619
const ggml_type type = src0->type;
620620

621+
const bool can_dmmv = ggml_cuda_can_dequantize_mul_mat_vec(src0, src1, dst);
622+
621623
const float alpha = 1.0f;
622624
const float beta = 0.0f;
623625
const int x_ne = ne01 * ne00;
@@ -628,7 +630,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
628630

629631
size_t x_size, y_size, d_size, q_size;
630632
float * d_X;
631-
if (ne11 > 1) {
633+
if (!can_dmmv) {
632634
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
633635
}
634636
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -658,7 +660,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
658660
} else {
659661
GGML_ASSERT(false);
660662
}
661-
if (ne11 == 1) {
663+
if (can_dmmv) { // specialized dequantize_mul_mat_vec kernel
662664
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
663665

664666
// copy src1 to device
@@ -671,7 +673,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
671673
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
672674
CUDA_CHECK(cudaGetLastError());
673675

674-
} else {
676+
} else { // general matrix matrix multiplication
675677
float * c_X = d_X + i * x_ne;
676678

677679
// convert src0 to fp32 on device
@@ -702,7 +704,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
702704
}
703705

704706
CUDA_CHECK(cudaDeviceSynchronize());
705-
if (ne11 > 1) {
707+
if (!can_dmmv) {
706708
ggml_cuda_pool_free(d_X, x_size);
707709
}
708710
ggml_cuda_pool_free(d_Y, y_size);
@@ -720,13 +722,26 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
720722
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
721723
src1->type == GGML_TYPE_F32 &&
722724
dst->type == GGML_TYPE_F32 &&
723-
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
725+
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
724726
return true;
725727
}
726728

727729
return false;
728730
}
729731

732+
bool ggml_cuda_can_dequantize_mul_mat_vec(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
733+
const int64_t ne11 = src1->ne[1];
734+
const int64_t ne12 = src1->ne[2];
735+
const int64_t ne13 = src1->ne[3];
736+
if (ggml_is_quantized(src0->type) &&
737+
src1->type == GGML_TYPE_F32 &&
738+
dst->type == GGML_TYPE_F32 &&
739+
(ne11 == 1 && ne12 == 1 && ne13 == 1)) {
740+
return true;
741+
}
742+
return false;
743+
}
744+
730745
bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
731746
size_t src0_sz = ggml_nbytes(src0);
732747
size_t src1_sz = ggml_nbytes(src1);
@@ -743,7 +758,9 @@ bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggm
743758
}
744759

745760
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
746-
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
761+
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst) ||
762+
ggml_cuda_can_dequantize_mul_mat_vec(src0, src1, dst) ||
763+
src0->backend == GGML_BACKEND_CUDA);
747764

748765
if (src0->type == GGML_TYPE_F32) {
749766
ggml_cuda_mul_mat_f32(src0, src1, dst);

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern "C" {
77
void ggml_init_cublas(void);
88

99
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
10+
bool ggml_cuda_can_dequantize_mul_mat_vec(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
1011
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
1112
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
1213

ggml.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7799,7 +7799,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
77997799
// compute by src0 rows
78007800

78017801
#if defined(GGML_USE_CUBLAS)
7802-
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
7802+
if (ggml_cuda_can_mul_mat(src0, src1, dst) ||
7803+
ggml_cuda_can_dequantize_mul_mat_vec(src0, src1, dst) ||
7804+
src0->backend == GGML_BACKEND_CUDA) {
78037805
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
78047806
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
78057807
}

0 commit comments

Comments (0)