@@ -8401,15 +8401,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
84018401 const int d_ne = ne11 * ne01 ;
84028402
84038403 size_t x_size , y_size , d_size ;
8404- float * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8405- float * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8406- float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8404+ ggml_fp16_t * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8405+ ggml_fp16_t * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8406+ float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
84078407#else
84088408 float * const wdata = params -> wdata ;
84098409#endif
84108410 for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
84118411 for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
84128412#if defined(GGML_USE_CUBLAS )
8413+ // copy src0 while converting src1
8414+ const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
8415+ CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8416+
84138417 // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
84148418 ggml_fp16_t * const wdata = (ggml_fp16_t * ) params -> wdata + (ne11 * ne10 ) * (i03 * ne02 + i02 );
84158419 {
@@ -8432,13 +8436,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
84328436#endif
84338437
84348438#if defined(GGML_USE_CUBLAS )
8435- const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
84368439 const ggml_fp16_t * y = (ggml_fp16_t * ) wdata ;
8437-
84388440 float * d = (float * ) ((char * ) dst -> data + i02 * nb2 + i03 * nb3 );
84398441
84408442 // copy data to device
8441- CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
84428443 CUDA_CHECK (cudaMemcpyAsync (d_Y , y , sizeof (ggml_fp16_t ) * y_ne , cudaMemcpyHostToDevice , g_cudaStream ));
84438444
84448445 // compute
0 commit comments