@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
+#endif
 
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
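Note on the new defines: the #ifndef guards make GGML_CUDA_DMMV_X and GGML_CUDA_DMMV_Y compile-time tunables that a build can override (e.g. nvcc -DGGML_CUDA_DMMV_X=64). A minimal standalone sketch (not code from this diff) of the launch geometry they imply, using the defaults above and a hypothetical row count:

// Sketch only: each thread block holds GGML_CUDA_DMMV_Y warps of WARP_SIZE
// threads, each warp produces one output row, so nrows/GGML_CUDA_DMMV_Y
// blocks cover the whole matrix.
#include <cstdio>

#define WARP_SIZE 32
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32 // default, as above
#endif
#ifndef GGML_CUDA_DMMV_Y
#define GGML_CUDA_DMMV_Y 1  // default, as above
#endif

int main() {
    const int nrows = 4096; // hypothetical row count
    printf("block: %d x %d threads, grid: %d blocks\n",
           WARP_SIZE, GGML_CUDA_DMMV_Y, nrows/GGML_CUDA_DMMV_Y);
    return 0;
}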
@@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
-    __shared__ float tmp[block_size]; // separate sum for each thread
-    tmp[tid] = 0;
+    float tmp = 0; // partial sum for thread in warp
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
     __syncthreads();
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
+
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = tmp;
     }
 }
 
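The shared-memory tree reduction is gone: each thread now keeps its partial sum in a register, and the warp combines them with an XOR butterfly of __shfl_xor_sync calls. Each step halves the number of distinct partial sums, so after log2(32) = 5 steps every lane holds the complete dot product for its row, with no shared memory or per-step __syncthreads() required. A self-contained sketch of the pattern (an illustrative kernel, not code from this diff):

#include <cstdio>

__global__ void warp_reduce_demo(float * out) {
    float v = (float) threadIdx.x; // each lane contributes its lane id
    // XOR butterfly: lanes exchange values at distance 16, 8, 4, 2, 1
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    if (threadIdx.x == 0) {
        *out = v; // 0 + 1 + ... + 31 = 496
    }
}

int main() {
    float * d_out;
    float h_out = 0.0f;
    cudaMalloc(&d_out, sizeof(float));
    warp_reduce_demo<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.1f\n", h_out); // expected: 496.0
    cudaFree(d_out);
    return 0;
}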
@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
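Each launcher now asserts that ncols and nrows divide evenly by the tile sizes and launches nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads. To make the kernel's index arithmetic concrete, here is a small host-side trace (a hypothetical standalone snippet, not part of this diff) for q4_0, where QK4_0 = 32 weights per block and QR4_0 = 2 weights per packed byte:

#include <cstdio>

int main() {
    const int qk = 32, qr = 2;   // QK4_0, QR4_0
    const int ncols = 4096;
    const int row = 0, col = 40; // hypothetical position
    const int ib       = (row*ncols + col)/qk; // x block index       -> 1
    const int iqs      = (col % qk)/qr;        // byte within block   -> 4
    const int iybs     = col - col % qk;       // y block start index -> 32
    const int y_offset = qk/2;                 // qr == 2 path        -> 16
    // byte iqs of block ib holds the weights for columns iybs + iqs and
    // iybs + iqs + y_offset, i.e. y[36] and y[52] here
    printf("ib=%d iqs=%d -> y[%d], y[%d]\n",
           ib, iqs, iybs + iqs, iybs + iqs + y_offset);
    return 0;
}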
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
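For the f16 path the template arguments change from <32, 1> to <1, 1>: with qk = 1 there are no quantization blocks, so ib degenerates to the element index, iqs to 0, and convert_f16 simply widens consecutive half values. A usage sketch with assumed names (d_x, d_y, d_dst, and stream are hypothetical device pointers and a stream owned by the caller):

// Hypothetical call site: compute dst = x * y for an f16 matrix x of
// nrows x ncols and an f32 vector y. The asserts above require
// ncols % GGML_CUDA_DMMV_X == 0 and nrows % GGML_CUDA_DMMV_Y == 0.
const int ncols = 4096;
const int nrows = 4096;
convert_mul_mat_vec_f16_cuda(d_x, d_y, d_dst, ncols, nrows, stream);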