@@ -59,8 +59,8 @@ typedef float2 dfloat2;
 #endif // GGML_CUDA_DMMV_F16
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
@@ -131,7 +131,7 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
 
 // ================================= k-quants
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 // ================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -440,7 +440,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -504,7 +504,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -544,7 +544,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -590,7 +590,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
@@ -634,7 +634,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -742,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -846,7 +846,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -949,7 +949,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
 
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
@@ -1053,7 +1053,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -1171,7 +1171,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -1207,7 +1207,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
     if (i >= k) {
@@ -1227,7 +1227,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
@@ -1252,7 +1252,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1277,7 +1277,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1312,7 +1312,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1346,7 +1346,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
@@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -1404,7 +1404,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1471,7 +1471,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1518,7 +1518,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 }
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int channel_stride_x) {
 
     const half * x = (const half *) vx;
@@ -2355,10 +2355,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
 
-    // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
-    // However, they have bad performance with Pascal cards.
-    // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
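
For context, here is a minimal sketch (not taken from this patch) of the pattern the diff applies throughout: marking kernel pointer parameters that never alias with `__restrict__`, which lets the compiler cache loads in registers and reorder memory accesses more freely. The kernel name and parameters below are hypothetical.

```cuda
#include <cuda_runtime.h>

// Hypothetical example kernel, not from ggml-cuda.cu: __restrict__ tells the
// compiler that x and y point to disjoint memory, so a store through y cannot
// invalidate a previously loaded value of x.
static __global__ void scale_array(const float * __restrict__ x,
                                   float * __restrict__ y,
                                   const float scale, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    // With the no-alias guarantee the compiler may keep x[i] in a register
    // instead of re-reading it after the write to y[i].
    y[i] = scale * x[i];
}
```

The same qualifier is what the patch adds to the dequantize, quantize, and vec_dot kernel signatures shown above.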