@@ -83,7 +83,8 @@ typedef struct {
8383} block_q8_0;
8484static_assert (sizeof (block_q8_0) == sizeof (float ) + QK8_0, " wrong q8_0 block size/padding" );
8585
// Threads per block used when launching the standalone dequantize / convert kernels.
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
// Threads per block used when launching the fused kernels (dmmv = dequantize_mul_mat_vec).
#define CUDA_DMMV_BLOCK_SIZE 32
8788
8889static __device__ void dequantize_q4_0 (const void * vx, const int ib, const int iqs, float & v0, float & v1){
8990 const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -170,104 +171,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
170171 v1 = __half2float (x[ib + 1 ]);
171172}
172173
// Generic dequantization kernel: fills y[0..k) with floats decoded from the data at vx.
//
// Template parameters:
//   qk                - number of output values per quantization block (block index = i/qk)
//   qr                - divisor mapping a position inside a block to a quant index;
//                       qr == 1 stores the two outputs of one call next to each other,
//                       any other value stores them qk/2 apart
//   dequantize_kernel - device function that produces two output values for the pair
//                       (block index ib, quant index iqs)
//
// Launch layout: 1D grid of 1D thread blocks. Every thread produces one pair of outputs,
// so k/2 threads are sufficient; threads whose element index is >= k return immediately.
// A grid of ceil(k / blockDim.x) blocks therefore still covers all of y.
//
// NOTE(review): when qr == 1 and k is odd, the thread with i == k - 1 hands y[k] to
// dequantize_kernel as its second output slot - confirm callers always pass an even k.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
    // Each thread owns two values, so the element index is twice the global thread index.
    // (blockDim.x*blockIdx.x + 2*threadIdx.x would make neighbouring thread blocks overlap
    // and dequantize most pairs twice.) The doubling is done in 64 bits so it cannot wrap
    // around before the bounds check for large k.
    const long long i_wide = 2*((long long) blockDim.x*blockIdx.x + threadIdx.x);

    if (i_wide >= k) {
        return;
    }

    const int i = (int) i_wide;

    const int ib = i/qk; // block index
    const int iqs = (i%qk)/qr; // quant index
    const int iybs = i - i%qk; // y block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize: the kernel writes its two results straight into y through the references
    float & v0 = y[iybs + iqs + 0];
    float & v1 = y[iybs + iqs + y_offset];
    dequantize_kernel(vx, ib, iqs, v0, v1);
}
272192
273193template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -308,29 +228,29 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
308228 }
309229}
310230
// Launches dequantize_block for q4_0 data on `stream`, producing k floats in y from vx.
// The grid is sized with a ceiling division so a partial last block is still launched.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
315235
// Launches dequantize_block for q4_1 data on `stream`, producing k floats in y from vx.
// The grid is sized with a ceiling division so a partial last block is still launched.
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
320240
// Launches dequantize_block for q5_0 data on `stream`, producing k floats in y from vx.
// The grid is sized with a ceiling division so a partial last block is still launched.
static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
325245
// Launches dequantize_block for q5_1 data on `stream`, producing k floats in y from vx.
// The grid is sized with a ceiling division so a partial last block is still launched.
static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
330250
// Launches dequantize_block for q8_0 data on `stream`, producing k floats in y from vx.
// The grid is sized with a ceiling division so a partial last block is still launched.
static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
335255
336256static void dequantize_mul_mat_vec_q4_0_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -363,17 +283,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
363283 <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0 , stream>>> (vx, y, dst, ncols);
364284}
365285
// Converts k half-precision values at vx into floats at y on `stream`, reusing the generic
// dequantize_block kernel with convert_f16 as the per-pair conversion routine.
//
// The kernel is instantiated with qk == 1 and qr == 1: convert_f16 indexes the half array
// directly with the block index (it reads x[ib + 1] for the second value), so the block
// index has to equal the element index. With a larger qk the block index would be i/qk
// and y[i] would be filled from the wrong source element.
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a grid of zero blocks is an invalid launch configuration
    }
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
378290
379291static void convert_mul_mat_vec_f16_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
0 commit comments