@@ -225,16 +225,18 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+template <int reduce_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const int row = blockIdx.x*2 + threadIdx.y;
     const int tid = threadIdx.x;
 
-    float tmp = 0;
+    __shared__ float full_tmp[reduce_size*2]; // separate sum for each thread
+    float * tmp = full_tmp + reduce_size*threadIdx.y;
+    tmp[tid] = 0;
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
+    for (int i = 0; i < ncols/reduce_size; i += 2) {
+        const int col = i*reduce_size + 2*tid;
 
         // dequantize
         const float d = x[(row*ncols + col)/QK4_0].d;
@@ -250,16 +252,18 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
         const float v1 = (vi1 - 8)*d;
 
         // matrix multiplication
-        tmp += v0 * y[col + 0];
-        tmp += v1 * y[col + 1];
+        tmp[tid] += v0 * y[col + 0];
+        tmp[tid] += v1 * y[col + 1];
     }
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int s = reduce_size/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
     }
     if (tid == 0) {
-        dst[row] = tmp;
+        dst[row] = tmp[0];
     }
 }
 
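For reference, the kernel's indexing (row = blockIdx.x*2 + threadIdx.y, tid = threadIdx.x, and the full_tmp[reduce_size*2] shared buffer) implies a launch with blockDim.x == reduce_size and blockDim.y == 2, i.e. one block per pair of rows. A minimal host-side launch sketch under those assumptions; the wrapper name, the reduce_size value of 32, and the stream handling are illustrative, not taken from this change:

// Hypothetical launch wrapper; only the launch geometry follows from the kernel above.
static void dequantize_mul_mat_q4_0_cuda(
        const void * vx, const float * y, float * dst,
        const int ncols, const int nrows, cudaStream_t stream) {
    constexpr int reduce_size = 32;            // must match the kernel's template argument
    const dim3 block_dims(reduce_size, 2, 1);  // x: one partial sum per thread, y: two rows per block
    const int num_blocks = nrows/2;            // nrows assumed even; each block covers two rows
    dequantize_mul_mat_q4_0<reduce_size><<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}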