@@ -1662,24 +1662,24 @@ static __global__ void mul_mat_q(
16621662 const int tid_x = threadIdx .x ;
16631663 const int tid_y = threadIdx .y ;
16641664
1665- const int row_dst_0 = blockIdx .x *WARP_SIZE;
1665+ const int row_dst_0 = 2 * blockIdx .x *WARP_SIZE;
16661666 const int & row_x_0 = row_dst_0;
16671667 const int row_dst = row_dst_0 + tid_x;
16681668
16691669 const int col_dst_0 = blockIdx .y *WARP_SIZE;
16701670 const int & col_y_0 = col_dst_0;
16711671
1672- __shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1 ];
1673- __shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
1672+ __shared__ int tile_x_qs[2 * WARP_SIZE][WARP_SIZE + 1 ];
1673+ __shared__ half tile_x_d[2 * WARP_SIZE][WARP_SIZE/QI4_0];
16741674 __shared__ int tile_y_qs[WARP_SIZE][2 *WARP_SIZE];
16751675 __shared__ half2 tile_y_ds[WARP_SIZE][2 *WARP_SIZE/QI8_1];
1676- float sum[4 ] = {0 .0f };
1676+ float sum[2 ][ 4 ] = {0 .0f };
16771677
16781678 for (int ib0 = 0 ; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
16791679 const int ibx = tid_x / QI4_0;
16801680 const int iqsx = sizeof (int ) * (tid_x % QI4_0);
16811681
1682- for (int j = 0 ; j < WARP_SIZE; j += 8 ) {
1682+ for (int j = 0 ; j < 2 * WARP_SIZE; j += 8 ) {
16831683 const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
16841684 memcpy (&tile_x_qs[j + tid_y][tid_x], &bx->qs [iqsx], sizeof (int ));
16851685 tile_x_d[j + tid_y][ibx] = bx->d ;
@@ -1706,9 +1706,12 @@ static __global__ void mul_mat_q(
17061706 for (int k = 0 ; k < WARP_SIZE; ++k) {
17071707 const int iqsy = k % (QI8_1/2 ) + QI8_1 * (k / (QI8_1/2 ));
17081708 for (int j = 0 ; j < WARP_SIZE; j += 8 ) {
1709- sum[j/8 ] += vec_dot_q4_0_q8_1_impl (
1709+ sum[0 ][ j/8 ] += vec_dot_q4_0_q8_1_impl (
17101710 tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0 ], tile_y_qs[tid_y + j][iqsy + (QI8_1/2 )],
17111711 tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
1712+ sum[1 ][j/8 ] += vec_dot_q4_0_q8_1_impl (
1713+ tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0 ], tile_y_qs[tid_y + j][iqsy + (QI8_1/2 )],
1714+ tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
17121715 }
17131716 }
17141717
@@ -1727,7 +1730,8 @@ static __global__ void mul_mat_q(
17271730 return ;
17281731 }
17291732
1730- dst[col_dst*nrows_dst + row_dst] = sum[j/8 ];
1733+ dst[col_dst*nrows_dst + row_dst] = sum[0 ][j/8 ];
1734+ dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1 ][j/8 ];
17311735 }
17321736}
17331737
@@ -2417,7 +2421,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
24172421}
24182422
24192423static void ggml_mul_mat_q4_0_q8_1_cuda (const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
2420- const int block_num_x = (nrows_x + WARP_SIZE - 1 ) / WARP_SIZE;
2424+ const int block_num_x = (nrows_x + 2 * WARP_SIZE - 1 ) / ( 2 * WARP_SIZE) ;
24212425 const int block_num_y = (ncols_y + WARP_SIZE - 1 ) / WARP_SIZE;
24222426 const dim3 block_nums (block_num_x, block_num_y, 1 );
24232427 const dim3 block_dims (WARP_SIZE, WARP_SIZE/4 , 1 );
0 commit comments