@@ -8394,8 +8394,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        ggml_fp16_t * const wdata = params->wdata;
-
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne00;
@@ -8413,6 +8411,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
                 // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
+                ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne11; ++i01) {
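
Note on the two hunks above: the fp16 staging pointer moves from a single function-wide buffer into the (i03, i02) loop, with a per-slice offset of ne11*ne10 elements, so each 2D slice of src1 converts into its own region of params->wdata rather than all slices sharing one. Below is a minimal, self-contained sketch of that offset arithmetic; the dimension names mirror the diff, while the fp16 typedef is simplified and the helper name slice_wdata is hypothetical.

    #include <stdint.h>
    #include <stdio.h>

    typedef uint16_t ggml_fp16_t; // assumption: fp16 stored in 16 bits

    // slice k = i03*ne02 + i02 starts k * (ne11*ne10) elements into the buffer,
    // so no two (i03, i02) slices overlap while staging src1 as fp16
    static ggml_fp16_t * slice_wdata(ggml_fp16_t * base,
                                     int64_t ne10, int64_t ne11,
                                     int64_t ne02, int64_t i02, int64_t i03) {
        return base + (ne11 * ne10) * (i03 * ne02 + i02);
    }

    int main(void) {
        ggml_fp16_t buf[2 * 2 * 4 * 3]; // ne03=2, ne02=2, ne11=4, ne10=3
        // slice (i03=1, i02=0) is slice k=2, i.e. 2 * (4*3) = 24 elements in
        printf("%td\n", slice_wdata(buf, 3, 4, 2, 0, 1) - buf); // prints 24
        return 0;
    }
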
@@ -8688,7 +8687,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
             const float * x = wdata;
 #endif
 
-
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
@@ -11550,7 +11548,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                     node->n_tasks = 1; // TODO: this actually is doing nothing
                                        //       the threads are still spinning
-                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                    cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
                     //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
                     //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
                     //printf("cur = %zu\n", cur);
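
The sizing hunk above enlarges the scratch buffer: previously it only covered a 2D slice of src0 in F32, but with the cuBLAS path now staging all of src1 as fp16 in the same buffer, it must cover whichever operand has more elements. MAX, ggml_nelements, and GGML_TYPE_SIZE are the real ggml names used in the diff; the shapes below are hypothetical, chosen so src1 is the larger operand (a 4096x4096 weight against a 4096x8192 batch).

    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
        const long long src0 = 4096LL * 4096; // weight elements (hypothetical shape)
        const long long src1 = 4096LL * 8192; // batch elements: larger than src0 here
        const long long f32  = 4;             // size of one F32 element

        printf("old sizing: %lld bytes\n", f32 * src0);            // 64 MiB: too small to stage src1
        printf("new sizing: %lld bytes\n", f32 * MAX(src0, src1)); // 128 MiB: covers either operand
        return 0;                              // sizing in F32 units also leaves 2x headroom for fp16 staging
    }
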
@@ -11562,6 +11560,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                 } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                     cur = 0;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+                    if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                        node->n_tasks = 1;
+                    }
+#endif
                 } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                     if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
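
The last hunk brings the F32xF32 branch in line with the F16 and quantized branches around it: when a BLAS backend will execute the matmul, the node is planned as a single task (no scratch is needed, so cur stays 0). A self-contained sketch of that planning pattern follows; the struct and helper are illustrative, not ggml's actual types.

    #include <stdbool.h>
    #include <stdio.h>

    struct plan_node { int n_tasks; bool blas_handles_it; };

    static void plan(struct plan_node * n, int n_threads) {
        n->n_tasks = n_threads;     // default: split the op across worker threads
        if (n->blas_handles_it) {
            n->n_tasks = 1;         // BLAS call is issued from a single task
        }
    }

    int main(void) {
        struct plan_node blas = { 0, true }, scalar = { 0, false };
        plan(&blas, 8);
        plan(&scalar, 8);
        printf("%d %d\n", blas.n_tasks, scalar.n_tasks); // prints: 1 8
        return 0;
    }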