@@ -80,192 +80,227 @@ typedef struct {
8080} block_q8_0;
8181static_assert (sizeof (block_q8_0) == sizeof (float ) + QK8_0, " wrong q8_0 block size/padding" );
8282
83- static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
83+ static __global__ void dequantize_block_q4_0 (const void * vx, float * y, int k ) {
8484 const block_q4_0 * x = (const block_q4_0 *) vx;
8585
86- const int i = blockIdx .x ;
86+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
8787
88- const float d = x[i].d ;
88+ if (i < k) {
89+ const float d = x[i].d ;
8990
90- const uint8_t * pp = x[i].qs ;
91+ const uint8_t * pp = x[i].qs ;
9192
92- for (int l = 0 ; l < QK4_0; l += 2 ) {
93- const uint8_t vi = pp[l/2 ];
93+ for (int l = 0 ; l < QK4_0; l += 2 ) {
94+ const uint8_t vi = pp[l/2 ];
9495
95- const int8_t vi0 = vi & 0xf ;
96- const int8_t vi1 = vi >> 4 ;
96+ const int8_t vi0 = vi & 0xf ;
97+ const int8_t vi1 = vi >> 4 ;
9798
98- const float v0 = (vi0 - 8 )*d;
99- const float v1 = (vi1 - 8 )*d;
99+ const float v0 = (vi0 - 8 )*d;
100+ const float v1 = (vi1 - 8 )*d;
100101
101- y[i*QK4_0 + l + 0 ] = v0;
102- y[i*QK4_0 + l + 1 ] = v1;
102+ y[i*QK4_0 + l + 0 ] = v0;
103+ y[i*QK4_0 + l + 1 ] = v1;
104+ }
103105 }
104106}
105107
106- static __global__ void dequantize_block_q4_1 (const void * vx, float * y) {
108+ static __global__ void dequantize_block_q4_1 (const void * vx, float * y, int k ) {
107109 const block_q4_1 * x = (const block_q4_1 *) vx;
108110
109- const int i = blockIdx .x ;
111+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
110112
111- const float d = x[i].d ;
112- const float m = x[i].m ;
113+ if (i < k) {
114+ const float d = x[i].d ;
115+ const float m = x[i].m ;
113116
114- const uint8_t * pp = x[i].qs ;
117+ const uint8_t * pp = x[i].qs ;
115118
116- for (int l = 0 ; l < QK4_1; l += 2 ) {
117- const uint8_t vi = pp[l/2 ];
119+ for (int l = 0 ; l < QK4_1; l += 2 ) {
120+ const uint8_t vi = pp[l/2 ];
118121
119- const int8_t vi0 = vi & 0xf ;
120- const int8_t vi1 = vi >> 4 ;
122+ const int8_t vi0 = vi & 0xf ;
123+ const int8_t vi1 = vi >> 4 ;
121124
122- const float v0 = vi0*d + m;
123- const float v1 = vi1*d + m;
125+ const float v0 = vi0*d + m;
126+ const float v1 = vi1*d + m;
124127
125- y[i*QK4_1 + l + 0 ] = v0;
126- y[i*QK4_1 + l + 1 ] = v1;
128+ y[i*QK4_1 + l + 0 ] = v0;
129+ y[i*QK4_1 + l + 1 ] = v1;
130+ }
127131 }
128132}
129133
130- static __global__ void dequantize_block_q4_2 (const void * vx, float * y) {
134+ static __global__ void dequantize_block_q4_2 (const void * vx, float * y, int k ) {
131135 const block_q4_2 * x = (const block_q4_2 *) vx;
132136
133- const int i = blockIdx .x ;
137+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
134138
135- const float d = x[i].d ;
139+ if (i < k) {
140+ const float d = x[i].d ;
136141
137- const uint8_t * pp = x[i].qs ;
142+ const uint8_t * pp = x[i].qs ;
138143
139- for (int l = 0 ; l < QK4_2; l += 2 ) {
140- const uint8_t vi = pp[l/2 ];
144+ for (int l = 0 ; l < QK4_2; l += 2 ) {
145+ const uint8_t vi = pp[l/2 ];
141146
142- const int8_t vi0 = vi & 0xf ;
143- const int8_t vi1 = vi >> 4 ;
147+ const int8_t vi0 = vi & 0xf ;
148+ const int8_t vi1 = vi >> 4 ;
144149
145- const float v0 = (vi0 - 8 )*d;
146- const float v1 = (vi1 - 8 )*d;
150+ const float v0 = (vi0 - 8 )*d;
151+ const float v1 = (vi1 - 8 )*d;
147152
148- y[i*QK4_2 + l + 0 ] = v0;
149- y[i*QK4_2 + l + 1 ] = v1;
153+ y[i*QK4_2 + l + 0 ] = v0;
154+ y[i*QK4_2 + l + 1 ] = v1;
155+ }
150156 }
151157}
152158
153- static __global__ void dequantize_block_q5_0 (const void * vx, float * y) {
159+ static __global__ void dequantize_block_q5_0 (const void * vx, float * y, int k ) {
154160 const block_q5_0 * x = (const block_q5_0 *) vx;
155161
156- const int i = blockIdx .x ;
162+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
157163
158- const float d = x[i].d ;
164+ if (i < k) {
165+ const float d = x[i].d ;
159166
160- const uint8_t * pp = x[i].qs ;
167+ const uint8_t * pp = x[i].qs ;
161168
162- uint32_t qh;
163- memcpy (&qh, x[i].qh , sizeof (qh));
169+ uint32_t qh;
170+ memcpy (&qh, x[i].qh , sizeof (qh));
164171
165- for (int l = 0 ; l < QK5_0; l += 2 ) {
166- const uint8_t vi = pp[l/2 ];
172+ for (int l = 0 ; l < QK5_0; l += 2 ) {
173+ const uint8_t vi = pp[l/2 ];
167174
168- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
169- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
175+ const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
176+ const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
170177
171- const int8_t vi0 = ((vi & 0xf ) | vh0);
172- const int8_t vi1 = ((vi >> 4 ) | vh1);
178+ const int8_t vi0 = ((vi & 0xf ) | vh0);
179+ const int8_t vi1 = ((vi >> 4 ) | vh1);
173180
174- const float v0 = (vi0 - 16 )*d;
175- const float v1 = (vi1 - 16 )*d;
181+ const float v0 = (vi0 - 16 )*d;
182+ const float v1 = (vi1 - 16 )*d;
176183
177- y[i*QK5_0 + l + 0 ] = v0;
178- y[i*QK5_0 + l + 1 ] = v1;
184+ y[i*QK5_0 + l + 0 ] = v0;
185+ y[i*QK5_0 + l + 1 ] = v1;
186+ }
179187 }
180188}
181189
182- static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
190+ static __global__ void dequantize_block_q5_1 (const void * vx, float * y, int k ) {
183191 const block_q5_1 * x = (const block_q5_1 *) vx;
184192
185- const int i = blockIdx .x ;
193+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
186194
187- const float d = x[i].d ;
188- const float m = x[i].m ;
195+ if (i < k) {
196+ const float d = x[i].d ;
197+ const float m = x[i].m ;
189198
190- const uint8_t * pp = x[i].qs ;
199+ const uint8_t * pp = x[i].qs ;
191200
192- uint32_t qh;
193- memcpy (&qh, x[i].qh , sizeof (qh));
201+ uint32_t qh;
202+ memcpy (&qh, x[i].qh , sizeof (qh));
194203
195- for (int l = 0 ; l < QK5_1; l += 2 ) {
196- const uint8_t vi = pp[l/2 ];
204+ for (int l = 0 ; l < QK5_1; l += 2 ) {
205+ const uint8_t vi = pp[l/2 ];
197206
198- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
199- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
207+ const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
208+ const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
200209
201- const int8_t vi0 = (vi & 0xf ) | vh0;
202- const int8_t vi1 = (vi >> 4 ) | vh1;
210+ const int8_t vi0 = (vi & 0xf ) | vh0;
211+ const int8_t vi1 = (vi >> 4 ) | vh1;
203212
204- const float v0 = vi0*d + m;
205- const float v1 = vi1*d + m;
213+ const float v0 = vi0*d + m;
214+ const float v1 = vi1*d + m;
206215
207- y[i*QK5_1 + l + 0 ] = v0;
208- y[i*QK5_1 + l + 1 ] = v1;
216+ y[i*QK5_1 + l + 0 ] = v0;
217+ y[i*QK5_1 + l + 1 ] = v1;
218+ }
209219 }
210220}
211221
212- static __global__ void dequantize_block_q8_0 (const void * vx, float * y) {
222+ static __global__ void dequantize_block_q8_0 (const void * vx, float * y, int k ) {
213223 const block_q8_0 * x = (const block_q8_0 *) vx;
214224
215- const int i = blockIdx .x ;
225+ const int i = blockIdx .x * blockDim . x + threadIdx . x ;
216226
217- const float d = x[i].d ;
227+ if (i < k) {
228+ const float d = x[i].d ;
218229
219- const int8_t * pp = x[i].qs ;
230+ const int8_t * pp = x[i].qs ;
220231
221- for (int l = 0 ; l < QK8_0; l++) {
222- const int8_t vi = pp[l];
232+ for (int l = 0 ; l < QK8_0; l++) {
233+ const int8_t vi = pp[l];
223234
224- y[i*QK8_0 + l] = vi*d;
235+ y[i*QK8_0 + l] = vi*d;
236+ }
225237 }
226238}
227239
228240static void dequantize_row_q4_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
229241 const int nb = k / QK4_0;
230- dequantize_block_q4_0<<<nb, 1 , 0 , stream>>> (vx, y);
242+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
243+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q4_0, 0 , 0 ));
244+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
245+ dequantize_block_q4_0<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
231246}
232247
233248static void dequantize_row_q4_1_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
234249 const int nb = k / QK4_1;
235- dequantize_block_q4_1<<<nb, 1 , 0 , stream>>> (vx, y);
250+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
251+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q4_1, 0 , 0 ));
252+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
253+ dequantize_block_q4_1<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
236254}
237255
238256static void dequantize_row_q4_2_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
239257 const int nb = k / QK4_2;
240- dequantize_block_q4_2<<<nb, 1 , 0 , stream>>> (vx, y);
258+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
259+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q4_2, 0 , 0 ));
260+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
261+ dequantize_block_q4_2<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
241262}
242263
243264static void dequantize_row_q5_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
244265 const int nb = k / QK5_0;
245- dequantize_block_q5_0<<<nb, 1 , 0 , stream>>> (vx, y);
266+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
267+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q5_0, 0 , 0 ));
268+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
269+ dequantize_block_q5_0<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
246270}
247271
248272static void dequantize_row_q5_1_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
249273 const int nb = k / QK5_1;
250- dequantize_block_q5_1<<<nb, 1 , 0 , stream>>> (vx, y);
274+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
275+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q5_1, 0 , 0 ));
276+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
277+ dequantize_block_q5_1<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
251278}
252279
253280static void dequantize_row_q8_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
254281 const int nb = k / QK8_0;
255- dequantize_block_q8_0<<<nb, 1 , 0 , stream>>> (vx, y);
282+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
283+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, dequantize_block_q8_0, 0 , 0 ));
284+ int grid_size = (nb + block_size - 1 ) / block_size; // Round up.
285+ dequantize_block_q8_0<<<grid_size, block_size, 0 , stream>>> (vx, y, k);
256286}
257287
258288// TODO: optimize
259- static __global__ void convert_fp16_to_fp32 (const void * vx, float * y) {
289+ static __global__ void convert_fp16_to_fp32 (const void * vx, float * y, int k ) {
260290 const half * x = (const half *) vx;
261291
262292 const int i = blockIdx .x ;
263293
264- y[i] = __half2float (x[i]);
294+ if (i < k) {
295+ y[i] = __half2float (x[i]);
296+ }
265297}
266298
267299static void convert_fp16_to_fp32_cuda (const void * x, float * y, int k, cudaStream_t stream) {
268- convert_fp16_to_fp32<<<k, 1 , 0 , stream>>> (x, y);
300+ int min_grid_size, block_size = 1 ; // Initialize to suppress compiler warning.
301+ CUDA_CHECK (cudaOccupancyMaxPotentialBlockSize (&min_grid_size, &block_size, convert_fp16_to_fp32, 0 , 0 ));
302+ int grid_size = (k + block_size - 1 ) / block_size; // Round up.
303+ convert_fp16_to_fp32<<<grid_size, block_size, 0 , stream>>> (x, y, k);
269304}
270305
271306static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
0 commit comments