@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
25222522static __device__ __forceinline__ void mul_mat_q_process_tile (
25232523 const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
25242524 const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525- const int nrows_x, const int ncols_y , const int stride_row_x , const int stride_col_dst,
2525+ const int nrows_x, const int stride_row_x , const int ncols_y , const int stride_col_dst,
25262526 const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
25272527
25282528 constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
26062606static __global__ void mul_mat_q (
26072607 const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
26082608 const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609- const int ncols_x, const int nrows_x, const int ncols_y , const int stride_row_x, const int stride_col_dst,
2609+ const int ncols_x, const int nrows_x, const int ncols_dst , const int stride_row_x, const int ncols_y , const int stride_col_dst,
26102610 const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
26112611 const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
26122612
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
26192619 constexpr int qk = ggml_cuda_type_traits<type>::qk;
26202620 constexpr int mmq_y = get_mmq_y_device ();
26212621
2622- const int ntx = (ncols_y + mmq_x - 1 ) / mmq_x; // Number of tiles x
2623- const int nty = (nrows_x + mmq_y - 1 ) / mmq_y; // Number of tiles y
2622+ const int ntx = (ncols_dst + mmq_x - 1 ) / mmq_x; // Number of tiles x
2623+ const int nty = (nrows_x + mmq_y - 1 ) / mmq_y; // Number of tiles y
26242624
26252625 // Initialize the ids for writing back data with just the index.
26262626 // For regular matrix multiplications this is never changed.
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
26482648
26492649 // Defaults for regular matrix multiplication:
26502650 int col_low = 0 ;
2651- int col_high = ncols_y ;
2652- int col_diff = ncols_y ;
2651+ int col_high = ncols_dst ;
2652+ int col_diff = ncols_dst ;
26532653 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
26542654 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
26552655
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
26892689
26902690 constexpr bool fixup = false ;
26912691 mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2692- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2692+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
26932693 tile_x_max_i, tile_y_max_j, 0 , ncols_x/qk);
26942694 return ;
26952695 }
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
27202720
27212721 // Defaults for regular matrix multiplication:
27222722 int col_low = 0 ;
2723- int col_high = ncols_y ;
2724- int col_diff = ncols_y ;
2723+ int col_high = ncols_dst ;
2724+ int col_diff = ncols_dst ;
27252725 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
27262726 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
27272727
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
27672767
27682768 constexpr bool fixup = false ; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
27692769 mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2770- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2770+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
27712771 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
27722772
27732773 kbc += blocks_per_ne00;
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
27922792
27932793 // Defaults for regular matrix multiplication:
27942794 int col_low = 0 ;
2795- int col_high = ncols_y ;
2796- int col_diff = ncols_y ;
2795+ int col_high = ncols_dst ;
2796+ int col_diff = ncols_dst ;
27972797 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
27982798 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
27992799
@@ -2834,15 +2834,15 @@ static __global__ void mul_mat_q(
28342834
28352835 constexpr bool fixup = true ; // Last index writes its data to fixup buffer to avoid data races with other blocks.
28362836 mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2837- (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x , stride_col_dst,
2837+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y , stride_col_dst,
28382838 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
28392839}
28402840
28412841
28422842template <ggml_type type, int mmq_x, int nwarps, bool need_check>
28432843static __global__ void mul_mat_q_stream_k_fixup (
28442844 const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2845- const int ncols_x, const int nrows_x, const int ncols_y , const int stride_col_dst,
2845+ const int ncols_x, const int nrows_x, const int ncols_dst , const int stride_col_dst,
28462846 const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
28472847 constexpr int mmq_y = get_mmq_y_device ();
28482848 constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
28512851
28522852 float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0 .0f };
28532853
2854- const int ntx = (ncols_y + mmq_x - 1 ) / mmq_x;
2855- const int nty = (nrows_x + mmq_y - 1 ) / mmq_y;
2854+ const int ntx = (ncols_dst + mmq_x - 1 ) / mmq_x;
2855+ const int nty = (nrows_x + mmq_y - 1 ) / mmq_y;
28562856
28572857 const int bidx0 = blockIdx .x ;
28582858
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
29252925 const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
29262926 dst += offset_dst;
29272927
2928- const int i_max = nrows_x - it*mmq_y - 1 ;
2929- const int j_max = ncols_y - jt*mmq_x - 1 ;
2928+ const int i_max = nrows_x - it*mmq_y - 1 ;
2929+ const int j_max = ncols_dst - jt*mmq_x - 1 ;
29302930
29312931#pragma unroll
29322932 for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps) {
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
29892989
29902990struct mmq_args {
29912991 const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
2992- int64_t ncols_x; int64_t nrows_x; int64_t ncols_y ; int64_t stride_row_x; int64_t nrows_dst;
2992+ int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst ; int64_t stride_row_x; int64_t ncols_y ; int64_t nrows_dst;
29932993 int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
29942994 int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
29952995 bool use_stream_k;
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30253025 }
30263026#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
30273027
3028- const int nty = (args.nrows_x + mmq_y - 1 ) / mmq_y;
3029- const int ntx = (args.ncols_y + mmq_x - 1 ) / mmq_x;
3028+ const int nty = (args.nrows_x + mmq_y - 1 ) / mmq_y;
3029+ const int ntx = (args.ncols_dst + mmq_x - 1 ) / mmq_x;
30303030 const int ntzw = args.nchannels_y * args.nsamples_y ;
30313031 const dim3 block_nums_xy_tiling (nty, ntx, ntzw);
30323032
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30403040 constexpr bool need_check = false ;
30413041 mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
30423042 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
3043- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3043+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
30443044 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
30453045 sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
30463046 } else {
30473047 constexpr bool need_check = true ;
30483048 mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
30493049 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , nullptr ,
3050- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3050+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
30513051 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
30523052 sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
30533053 }
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30683068
30693069 mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
30703070 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
3071- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3071+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
30723072 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
30733073 sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
30743074
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30773077 }
30783078
30793079 mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3080- (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_y ,
3080+ (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
30813081 args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
30823082 } else {
30833083 constexpr bool need_check = true ;
30843084
30853085 mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
30863086 (args.x , args.y , args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr ,
3087- args.ncols_x , args.nrows_x , args.ncols_y , args.stride_row_x , args.nrows_dst ,
3087+ args.ncols_x , args.nrows_x , args.ncols_dst , args.stride_row_x , args. ncols_y , args.nrows_dst ,
30883088 channel_ratio, args.nchannels_y , args.stride_channel_x , args.stride_channel_y , args.stride_channel_dst ,
30893089 sample_ratio, args.nsamples_y , args.stride_sample_x , args.stride_sample_y , args.stride_sample_dst );
30903090
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30933093 }
30943094
30953095 mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0 , stream>>>
3096- (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_y ,
3096+ (args.ids_dst , args.expert_bounds , args.dst , tmp_fixup.ptr , args.ncols_x , args.nrows_x , args.ncols_dst ,
30973097 args.nrows_dst , args.nchannels_y , args.stride_channel_dst , args.nsamples_y , args.stride_sample_dst );
30983098 }
30993099}
0 commit comments