@@ -81,49 +81,39 @@ typedef struct {
8181static_assert (sizeof (block_q8_0) == sizeof (float ) + QK8_0, " wrong q8_0 block size/padding" );
8282
8383static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
84+ static const int qk = QK4_0;
85+
8486 const block_q4_0 * x = (const block_q4_0 *) vx;
8587
8688 const int i = blockIdx .x ;
8789
8890 const float d = x[i].d ;
8991
90- const uint8_t * pp = x[i].qs ;
91-
92- for (int l = 0 ; l < QK4_0; l += 2 ) {
93- const uint8_t vi = pp[l/2 ];
94-
95- const int8_t vi0 = vi & 0xf ;
96- const int8_t vi1 = vi >> 4 ;
92+ for (int j = 0 ; j < qk/2 ; ++j) {
93+ const int x0 = (x[i].qs [j] & 0xf ) - 8 ;
94+ const int x1 = (x[i].qs [j] >> 4 ) - 8 ;
9795
98- const float v0 = (vi0 - 8 )*d;
99- const float v1 = (vi1 - 8 )*d;
100-
101- y[i*QK4_0 + l + 0 ] = v0;
102- y[i*QK4_0 + l + 1 ] = v1;
96+ y[i*qk + j + 0 ] = x0*d;
97+ y[i*qk + j + qk/2 ] = x1*d;
10398 }
10499}
105100
106101static __global__ void dequantize_block_q4_1 (const void * vx, float * y) {
102+ static const int qk = QK4_1;
103+
107104 const block_q4_1 * x = (const block_q4_1 *) vx;
108105
109106 const int i = blockIdx .x ;
110107
111108 const float d = x[i].d ;
112109 const float m = x[i].m ;
113110
114- const uint8_t * pp = x[i].qs ;
115-
116- for (int l = 0 ; l < QK4_1; l += 2 ) {
117- const uint8_t vi = pp[l/2 ];
118-
119- const int8_t vi0 = vi & 0xf ;
120- const int8_t vi1 = vi >> 4 ;
111+ for (int j = 0 ; j < qk/2 ; ++j) {
112+ const int x0 = (x[i].qs [j] & 0xf );
113+ const int x1 = (x[i].qs [j] >> 4 );
121114
122- const float v0 = vi0*d + m;
123- const float v1 = vi1*d + m;
124-
125- y[i*QK4_1 + l + 0 ] = v0;
126- y[i*QK4_1 + l + 1 ] = v1;
115+ y[i*qk + j + 0 ] = x0*d + m;
116+ y[i*qk + j + qk/2 ] = x1*d + m;
127117 }
128118}
129119
@@ -151,61 +141,51 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
151141}
152142
153143static __global__ void dequantize_block_q5_0 (const void * vx, float * y) {
144+ static const int qk = QK5_0;
145+
154146 const block_q5_0 * x = (const block_q5_0 *) vx;
155147
156148 const int i = blockIdx .x ;
157149
158150 const float d = x[i].d ;
159151
160- const uint8_t * pp = x[i].qs ;
161-
162152 uint32_t qh;
163153 memcpy (&qh, x[i].qh , sizeof (qh));
164154
165- for (int l = 0 ; l < QK5_0; l += 2 ) {
166- const uint8_t vi = pp[l/2 ];
167-
168- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
169- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
155+ for (int j = 0 ; j < qk/2 ; ++j) {
156+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
157+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
170158
171- const int8_t vi0 = ((vi & 0xf ) | vh0) ;
172- const int8_t vi1 = ((vi >> 4 ) | vh1) ;
159+ const int32_t x0 = ((x[i]. qs [j] & 0xf ) | xh_0) - 16 ;
160+ const int32_t x1 = ((x[i]. qs [j] >> 4 ) | xh_1) - 16 ;
173161
174- const float v0 = (vi0 - 16 )*d;
175- const float v1 = (vi1 - 16 )*d;
176-
177- y[i*QK5_0 + l + 0 ] = v0;
178- y[i*QK5_0 + l + 1 ] = v1;
162+ y[i*qk + j + 0 ] = x0*d;
163+ y[i*qk + j + qk/2 ] = x1*d;
179164 }
180165}
181166
182167static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
168+ static const int qk = QK5_1;
169+
183170 const block_q5_1 * x = (const block_q5_1 *) vx;
184171
185172 const int i = blockIdx .x ;
186173
187174 const float d = x[i].d ;
188175 const float m = x[i].m ;
189176
190- const uint8_t * pp = x[i].qs ;
191-
192177 uint32_t qh;
193178 memcpy (&qh, x[i].qh , sizeof (qh));
194179
195- for (int l = 0 ; l < QK5_1; l += 2 ) {
196- const uint8_t vi = pp[l/2 ];
197-
198- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
199- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
200-
201- const int8_t vi0 = (vi & 0xf ) | vh0;
202- const int8_t vi1 = (vi >> 4 ) | vh1;
180+ for (int j = 0 ; j < qk/2 ; ++j) {
181+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4 ;
182+ const uint8_t xh_1 = ((qh & (1u << (j + 16 ))) >> (j + 12 ));
203183
204- const float v0 = vi0*d + m ;
205- const float v1 = vi1*d + m ;
184+ const int x0 = (x[i]. qs [j] & 0xf ) | xh_0 ;
185+ const int x1 = (x[i]. qs [j] >> 4 ) | xh_1 ;
206186
207- y[i*QK5_1 + l + 0 ] = v0 ;
208- y[i*QK5_1 + l + 1 ] = v1 ;
187+ y[i*qk + j + 0 ] = x0*d + m ;
188+ y[i*qk + j + qk/ 2 ] = x1*d + m ;
209189 }
210190}
211191
0 commit comments