@@ -77,13 +77,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
     uint8_t* packed,
     const int8x16_t& unpacked0,
     const int8x16_t& unpacked1) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  }
 
   switch (nbit) {
     case 1:
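For intuition, here is a minimal scalar analogue of the bias trick this hunk vectorizes (illustrative sketch, not part of the patch): for nbit < 8, a signed value in [-2^(nbit-1), 2^(nbit-1) - 1] is shifted by 2^(nbit-1) into [0, 2^nbit - 1], so it fits in nbit unsigned bits; for nbit = 8 an int8 already spans the full byte, so a plain reinterpret suffices.

```cpp
#include <cstdint>

// Scalar analogue of the NEON bias trick (illustrative only).
template <int nbit>
uint8_t to_unsigned(int8_t v) {
  if constexpr (nbit < 8) {
    // e.g. nbit = 4: v in [-8, 7] maps to v + 8 in [0, 15]
    return static_cast<uint8_t>(v + (1 << (nbit - 1)));
  } else {
    // nbit = 8: int8 already occupies the full byte; reinterpret only
    return static_cast<uint8_t>(v);
  }
}
```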
@@ -151,6 +156,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16);
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      break;
     default:
       assert(false);
   }
@@ -161,7 +170,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
     int8x16_t& unpacked0,
     int8x16_t& unpacked1,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -234,14 +243,21 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
       shifted0 = vld1q_u8(buffer7);
       shifted1 = vld1q_u8(buffer7 + 16);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  }
 }
 
 template <int nbit>
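A minimal usage sketch for the 32-value round trip (assumes an aarch64 build with this header on the include path; `round_trip_32` is a hypothetical helper, and for nbit < 8 the inputs must already lie in the signed nbit range):

```cpp
#include <arm_neon.h>
#include <cstdint>

// Hypothetical helper: pack 32 int8 values at compile-time nbit, unpack back.
template <int nbit>
void round_trip_32(const int8_t* src, int8_t* dst) {
  int8x16_t in0 = vld1q_s8(src);
  int8x16_t in1 = vld1q_s8(src + 16);

  // 32 values at nbit bits each occupy 32 * nbit / 8 packed bytes
  uint8_t packed[32 * nbit / 8];
  torchao::bitpacking::vec_pack_32_lowbit_values<nbit>(packed, in0, in1);

  int8x16_t out0;
  int8x16_t out1;
  torchao::bitpacking::vec_unpack_32_lowbit_values<nbit>(out0, out1, packed);

  vst1q_s8(dst, out0);
  vst1q_s8(dst + 16, out1);
}
```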
@@ -251,15 +267,23 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
     const int8x16_t& unpacked1,
     const int8x16_t& unpacked2,
     const int8x16_t& unpacked3) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -292,6 +316,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
       torchao::bitpacking::internal::vec_pack_64_uint7_values(
           packed, shifted0, shifted1, shifted2, shifted3);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      break;
     default:
       assert(false);
   }
@@ -304,7 +334,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
     int8x16_t& unpacked2,
     int8x16_t& unpacked3,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -343,16 +373,25 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
       torchao::bitpacking::internal::vec_unpack_64_uint7_values(
          shifted0, shifted1, shifted2, shifted3, packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  }
 }
 
 template <int nbit>
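Across all three variants, the packed buffer size follows the same rule, count * nbit / 8 bytes; a compile-time sanity check (illustrative, `packed_bytes` is a hypothetical helper, not part of the patch):

```cpp
// Packed bytes for `count` values at `nbit` bits each.
constexpr int packed_bytes(int count, int nbit) {
  return count * nbit / 8;
}
// With the new case 8, packing becomes a verbatim byte copy:
static_assert(packed_bytes(32, 8) == 32, "32 values @ 8 bits");
static_assert(packed_bytes(64, 8) == 64, "64 values @ 8 bits");
static_assert(packed_bytes(128, 8) == 128, "128 values @ 8 bits");
static_assert(packed_bytes(32, 3) == 12, "32 values @ 3 bits");
```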
@@ -366,19 +405,31 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
     const int8x16_t& unpacked5,
     const int8x16_t& unpacked6,
     const int8x16_t& unpacked7) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
-  uint8x16_t shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
-  uint8x16_t shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
-  uint8x16_t shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
-  uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  uint8x16_t shifted4;
+  uint8x16_t shifted5;
+  uint8x16_t shifted6;
+  uint8x16_t shifted7;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+    shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
+    shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
+    shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
+    shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -451,6 +502,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
           shifted6,
           shifted7);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      vst1q_u8(packed + 64, vreinterpretq_u8_s8(unpacked4));
+      vst1q_u8(packed + 80, vreinterpretq_u8_s8(unpacked5));
+      vst1q_u8(packed + 96, vreinterpretq_u8_s8(unpacked6));
+      vst1q_u8(packed + 112, vreinterpretq_u8_s8(unpacked7));
+      break;
     default:
       assert(false);
   }
@@ -467,7 +528,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
     int8x16_t& unpacked6,
     int8x16_t& unpacked7,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -550,20 +611,33 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
           shifted7,
           packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      unpacked4 = vreinterpretq_s8_u8(vld1q_u8(packed + 64));
+      unpacked5 = vreinterpretq_s8_u8(vld1q_u8(packed + 80));
+      unpacked6 = vreinterpretq_s8_u8(vld1q_u8(packed + 96));
+      unpacked7 = vreinterpretq_s8_u8(vld1q_u8(packed + 112));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
-  unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
-  unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
-  unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
-  unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+    unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
+    unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
+    unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
+    unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  }
 }
 
 } // namespace bitpacking
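Since the new case 8 branches are plain vst1q_u8 stores and vld1q_u8 loads, the nbit = 8 packed layout is byte-identical to the input. A quick self-check sketch (hypothetical test, aarch64 only; the function name is an assumption, not from the patch):

```cpp
#include <arm_neon.h>
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical check: at nbit = 8 the pack is a verbatim byte copy.
void check_8bit_pack_is_identity() {
  int8_t src[128];
  for (int i = 0; i < 128; i++) {
    src[i] = static_cast<int8_t>(i - 64);  // exercise negative values too
  }

  int8x16_t v[8];
  for (int i = 0; i < 8; i++) {
    v[i] = vld1q_s8(src + 16 * i);
  }

  uint8_t packed[128];
  torchao::bitpacking::vec_pack_128_lowbit_values<8>(
      packed, v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);

  assert(std::memcmp(packed, src, 128) == 0);
}
```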