@@ -208,6 +208,7 @@ typedef struct {
208208static_assert (sizeof (block_q6_K) == sizeof (ggml_fp16_t ) + 13 *QK_K/16 , " wrong q6_K block size/padding" );
209209
210210#define WARP_SIZE 32
211+ #define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
211212
212213#define CUDA_ADD_BLOCK_SIZE 256
213214#define CUDA_MUL_BLOCK_SIZE 256
@@ -1171,7 +1172,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
11711172 v.y = x[ib + iqs + 1 ];
11721173}
11731174
1174- static __global__ void quantize_q8_1 (const float * __restrict__ x, void * __restrict__ vy, const int k) {
1175+ static __global__ void quantize_q8_1 (const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
11751176 const int i = blockDim .x *blockIdx .x + threadIdx .x ;
11761177
11771178 if (i >= k) {
@@ -1180,10 +1181,10 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
11801181
11811182 block_q8_1 * y = (block_q8_1 *) vy;
11821183
1183- const int ib = i / QK8_0 ; // block index
1184- const int iqs = i % QK8_0 ; // quant index
1184+ const int ib = i / QK8_1 ; // block index
1185+ const int iqs = i % QK8_1 ; // quant index
11851186
1186- const float xi = x[i];
1187+ const float xi = i < ndata ? x[i] : 0 . 0f ;
11871188 float amax = fabsf (xi);
11881189 float sum = xi;
11891190
@@ -1714,9 +1715,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
17141715 rms_norm_f32<<<nrows, block_dims, 0 , stream>>> (x, dst, ncols);
17151716}
17161717
1717- static void quantize_row_q8_1_cuda (const float * x, void * vy, const int k, cudaStream_t stream) {
1718+ static void quantize_row_q8_1_cuda (const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
17181719 const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1 ) / CUDA_QUANTIZE_BLOCK_SIZE;
1719- quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0 , stream>>> (x, vy, k);
1720+ quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0 , stream>>> (x, vy, ndata, k);
17201721}
17211722
17221723static void dequantize_row_q4_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -2359,9 +2360,11 @@ inline void ggml_cuda_op_mul_mat_vec(
23592360#endif
23602361
23612362 if (use_mul_mat_vec_q) {
2363+ int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1 ;
2364+ padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
23622365 size_t as;
2363- void * src1_q8_1 = ggml_cuda_pool_malloc (ne00 *sizeof (block_q8_1)/QK8_1, &as);
2364- quantize_row_q8_1_cuda (src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
2366+ void * src1_q8_1 = ggml_cuda_pool_malloc (padded_row_size *sizeof (block_q8_1)/QK8_1, &as);
2367+ quantize_row_q8_1_cuda (src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
23652368
23662369 switch (src0->type ) {
23672370 case GGML_TYPE_Q4_0:
@@ -3105,7 +3108,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
31053108
31063109void ggml_cuda_transform_tensor (void * data, struct ggml_tensor * tensor) {
31073110 int nrows = ggml_nrows (tensor);
3111+
3112+ const int64_t ne0 = tensor->ne [0 ];
3113+
31083114 const size_t nb1 = tensor->nb [1 ];
3115+
31093116 ggml_backend backend = tensor->backend ;
31103117 struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu ;
31113118 memset (extra, 0 , sizeof (*extra));
@@ -3134,11 +3141,24 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
31343141 int64_t nrows_split = row_high - row_low;
31353142
31363143 const size_t offset_split = row_low*nb1;
3137- const size_t size = ggml_nbytes_split (tensor, nrows_split);
3144+ size_t size = ggml_nbytes_split (tensor, nrows_split);
3145+ const size_t original_size = size;
3146+
3147+ // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
3148+ if (ne0 % MATRIX_ROW_PADDING != 0 ) {
3149+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
3150+ * ggml_type_size (tensor->type )/ggml_blck_size (tensor->type );
3151+ }
31383152
3139- void * buf;
3153+ char * buf;
31403154 CUDA_CHECK (cudaMalloc (&buf, size));
3141- void * buf_host = (char *)data + offset_split;
3155+ char * buf_host = (char *)data + offset_split;
3156+
3157+ // set padding to 0 to avoid possible NaN values
3158+ if (size > original_size) {
3159+ CUDA_CHECK (cudaMemset (buf + original_size, 0 , size - original_size));
3160+ }
3161+
31423162
31433163 cudaMemcpy (buf, buf_host, size, cudaMemcpyHostToDevice);
31443164
0 commit comments