@@ -934,6 +934,30 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }
 
+bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+    size_t src0_sz = ggml_nbytes(src0);
+    size_t src1_sz = ggml_nbytes(src1);
+
+    // mul_mat_q: src0 is converted to fp32 on device
+    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+
+    // mul_mat_f16: src1 is converted to fp16 on cpu
+    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
+
+    // choose the smaller one to transfer to the device
+    // TODO: this is not always the best choice due to the overhead of converting to fp16
+    return mul_mat_f16_transfer < mul_mat_q_transfer;
+}
+
+size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
+        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    }
+    else {
+        return 0;
+    }
+}
+
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
 
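For intuition on the heuristic in ggml_cuda_mul_mat_use_f16 above: both candidate transfer sizes include src0, so the comparison effectively reduces to copying src1 to the device as fp32 versus as fp16. Below is a minimal standalone sketch of the same arithmetic with made-up shapes (a 4096 x 4096 F16 weight and a 4096 x 512 F32 activation); the shapes and the program itself are illustrative only and are not part of the patch.

// same comparison as ggml_cuda_mul_mat_use_f16, with hypothetical shapes
#include <cstdio>
#include <cstdint>

int main() {
    const int64_t ne00 = 4096, ne01 = 4096;  // src0 (weights), stored as fp16
    const int64_t ne10 = 4096, ne11 = 512;   // src1 (activations), stored as fp32

    const size_t src0_sz = (size_t)(ne00*ne01)*sizeof(uint16_t); // fp16 bytes
    const size_t src1_sz = (size_t)(ne10*ne11)*sizeof(float);    // fp32 bytes

    const size_t mul_mat_q_transfer   = src0_sz + src1_sz;                              // ship src1 as fp32
    const size_t mul_mat_f16_transfer = src0_sz + (size_t)(ne10*ne11)*sizeof(uint16_t); // ship src1 as fp16

    printf("q path transfer:   %zu MiB\n", mul_mat_q_transfer   >> 20); // 40 MiB
    printf("f16 path transfer: %zu MiB\n", mul_mat_f16_transfer >> 20); // 36 MiB
    printf("use f16 path:      %s\n", mul_mat_f16_transfer < mul_mat_q_transfer ? "yes" : "no");
    return 0;
}

Since src1 reaches this code path as F32 (ggml_cuda_mul_mat_f16 below reads it as float), the fp16 copy is always the smaller of the two byte counts; as the TODO already notes, this comparison ignores the host-side conversion cost, which is the real trade-off.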
@@ -950,6 +974,99 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }
 
+static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const int n_mm = ne03 * ne02;
+
+    size_t x_size, y_size, d_size;
+    half  * d_X = (half  *) ggml_cuda_pool_malloc(n_mm * sizeof(half)  * x_ne, &x_size);
+    half  * d_Y = (half  *) ggml_cuda_pool_malloc(n_mm * sizeof(half)  * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+
+    bool src1_cont_rows = nb10 == sizeof(float);
+    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            int i = i03*ne02 + i02;
+            cudaStream_t cudaStream = g_cudaStreams_main[0][i % GGML_CUDA_MAX_STREAMS];
+
+            half  * c_X = d_X + i * x_ne;
+            half  * c_Y = d_Y + i * y_ne;
+            float * c_D = d_D + i * d_ne;
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, 0, ne01, cudaStream));
+
+            // convert src1 to fp16
+            // TODO: use multiple threads
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
+            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            if (src1_cont_rows) {
+                if (src1_cont_cols) {
+                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+                }
+                else {
+                    for (int64_t i01 = 0; i01 < ne11; i01++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    }
+                }
+            }
+            else {
+                for (int64_t i01 = 0; i01 < ne11; i01++) {
+                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                        // very slow due to no inlining
+                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                    }
+                }
+            }
+
+            // copy src1 to device
+            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublas_handles[0], cudaStream));
+            CUBLAS_CHECK(
+                cublasGemmEx(g_cublas_handles[0], CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, CUDA_R_16F, ne00,
+                                c_Y, CUDA_R_16F, ne10,
+                        &beta,  c_D, CUDA_R_32F, ne01,
+                        CUBLAS_COMPUTE_32F_FAST_16F,
+                        CUBLAS_GEMM_DEFAULT));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
     FILE * fp = fopen(fname, "rb");
     int nrows = ggml_nrows(tensor);
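The cublasGemmEx call in ggml_cuda_mul_mat_f16 above packs the whole layout story into one statement: the row-major src0 slice is handed to cuBLAS as a column-major ne00 x ne01 matrix and transposed back with CUBLAS_OP_T, the converted src1 slice is used as-is with CUBLAS_OP_N, and the product is accumulated in fp32 with CUBLAS_COMPUTE_32F_FAST_16F. The self-contained sketch below reproduces that operand pattern with tiny made-up matrices; the sizes, fill values, and omitted error checks are illustrative and not taken from the patch.

// build (assumption): nvcc gemm_f16_sketch.cu -lcublas
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

int main() {
    // row-major layout as in ggml:
    //   A (src0-like): M rows x K cols, fp16
    //   B (src1-like): N rows x K cols, fp16 (already converted on the host)
    //   C (dst-like) : N rows x M cols, fp32
    const int M = 4, N = 3, K = 8;

    std::vector<half>  hA(M * K), hB(N * K);
    std::vector<float> hC(N * M, 0.0f);
    for (int i = 0; i < M * K; i++) hA[i] = __float2half(0.01f * i);
    for (int i = 0; i < N * K; i++) hB[i] = __float2half(0.02f * i);

    half * dA; half * dB; float * dC;
    cudaMalloc((void **) &dA, hA.size() * sizeof(half));
    cudaMalloc((void **) &dB, hB.size() * sizeof(half));
    cudaMalloc((void **) &dC, hC.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(half), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    // same operand pattern as the patch: cuBLAS sees the row-major A as a
    // column-major K x M matrix, so OP_T turns it into M x K; the row-major B
    // is seen as column-major K x N and used directly with OP_N; the result is
    // a column-major M x N fp32 matrix, i.e. a row-major N x M block.
    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                 M, N, K,
                 &alpha, dA, CUDA_R_16F, K,
                         dB, CUDA_R_16F, K,
                 &beta,  dC, CUDA_R_32F, M,
                 CUBLAS_COMPUTE_32F_FAST_16F,
                 CUBLAS_GEMM_DEFAULT);

    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0][0] = %f (dot of A row 0 with B row 0)\n", hC[0]);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

That result layout is what lets the patch copy the fp32 output of each (i03, i02) slice straight into the corresponding ne11 x ne01 block of dst with a single cudaMemcpyAsync.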
@@ -1054,6 +1171,19 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
                 return false;
             }
+            if (g_device_count == 1 && tensor->src0->type == GGML_TYPE_F16 &&
+                ggml_cuda_mul_mat_use_f16(tensor->src0, tensor->src1, tensor)) {
+
+                if (params->ith != 0) {
+                    return true;
+                }
+                if (params->type == GGML_TASK_COMPUTE) {
+                    ggml_cuda_mul_mat_f16(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
+                    return true;
+                }
+
+                return false;
+            }
             func = ggml_cuda_mul_mat;
             break;
         default:
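One piece that is not visible in this diff is who sizes the fp16 staging buffer that arrives here as params->wdata; that is what ggml_cuda_mul_mat_get_wsize is for. A rough sketch of how a CPU-side graph planner could account for it is below; the helper name, the included headers, and the node->src0/src1 access are assumptions about the surrounding ggml code, not something this patch adds.

// hypothetical helper mirroring a planner's work-size estimate for one
// GGML_OP_MUL_MAT node; only the two ggml_cuda_* calls come from this patch
#include "ggml.h"
#include "ggml-cuda.h"

static size_t mul_mat_node_wsize(struct ggml_tensor * node) {
    size_t cur = 0;
#if defined(GGML_USE_CUBLAS)
    if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
        // one ggml_fp16_t per src1 element when the f16 path will be taken,
        // zero when ggml_cuda_mul_mat handles the node without a work buffer
        cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
    }
#endif
    return cur;
}

Whatever buffer the planner allocates from that size is then handed back through params->wdata, which is where the GGML_TASK_COMPUTE branch added above picks it up.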