From 101e980e4490a12e8885e5e257bc2bba77b58b7a Mon Sep 17 00:00:00 2001
From: Alex Oleinikov
Date: Fri, 8 Nov 2024 09:55:49 -0800
Subject: [PATCH] 8-bit packing support

Summary:
Add support for 8-bit quantization for Llama in torchchat. 8-bit packing
behaves exactly like 1-7 bit packing, but the implementation differs
slightly: 8-bit chunks are used as-is, with no shift into the nonnegative
range. A round-trip usage sketch follows the diff.

Reviewed By: metascroy

Differential Revision: D65570988
---
 .../kernels/cpu/aarch64/bitpacking/bitpack.h  | 160 +++++++++++++-----
 .../cpu/aarch64/tests/test_bitpacking.cpp     |  33 +---
 2 files changed, 123 insertions(+), 70 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
index 5442f8319b..0f7cd3ffd4 100644
--- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
+++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
@@ -77,13 +77,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
     uint8_t* packed,
     const int8x16_t& unpacked0,
     const int8x16_t& unpacked1) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  }
 
   switch (nbit) {
     case 1:
@@ -151,6 +156,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16);
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      break;
     default:
       assert(false);
   }
@@ -161,7 +170,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
     int8x16_t& unpacked0,
     int8x16_t& unpacked1,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -234,14 +243,21 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
       shifted0 = vld1q_u8(buffer7);
       shifted1 = vld1q_u8(buffer7 + 16);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  }
 }
 
 template <int nbit>
@@ -251,15 +267,23 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
     const int8x16_t& unpacked1,
     const int8x16_t& unpacked2,
     const int8x16_t& unpacked3) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -292,6 +316,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
       torchao::bitpacking::internal::vec_pack_64_uint7_values(
           packed, shifted0, shifted1, shifted2, shifted3);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      break;
     default:
       assert(false);
   }
 }
@@ -304,7 +334,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
     int8x16_t& unpacked2,
     int8x16_t& unpacked3,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -343,16 +373,25 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
       torchao::bitpacking::internal::vec_unpack_64_uint7_values(
          shifted0, shifted1, shifted2, shifted3, packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  }
 }
 
 template <int nbit>
@@ -366,19 +405,31 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
     const int8x16_t& unpacked5,
     const int8x16_t& unpacked6,
     const int8x16_t& unpacked7) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
-  uint8x16_t shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
-  uint8x16_t shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
-  uint8x16_t shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
-  uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  uint8x16_t shifted4;
+  uint8x16_t shifted5;
+  uint8x16_t shifted6;
+  uint8x16_t shifted7;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+    shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
+    shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
+    shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
+    shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -451,6 +502,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
           shifted6,
           shifted7);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      vst1q_u8(packed + 64, vreinterpretq_u8_s8(unpacked4));
+      vst1q_u8(packed + 80, vreinterpretq_u8_s8(unpacked5));
+      vst1q_u8(packed + 96, vreinterpretq_u8_s8(unpacked6));
+      vst1q_u8(packed + 112, vreinterpretq_u8_s8(unpacked7));
+      break;
     default:
       assert(false);
   }
@@ -467,7 +528,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
     int8x16_t& unpacked6,
     int8x16_t& unpacked7,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -550,20 +611,33 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
           shifted7,
           packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      unpacked4 = vreinterpretq_s8_u8(vld1q_u8(packed + 64));
+      unpacked5 = vreinterpretq_s8_u8(vld1q_u8(packed + 80));
+      unpacked6 = vreinterpretq_s8_u8(vld1q_u8(packed + 96));
+      unpacked7 = vreinterpretq_s8_u8(vld1q_u8(packed + 112));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
-  unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
-  unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
-  unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
-  unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+    unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
+    unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
+    unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
+    unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  }
 }
 
 } // namespace bitpacking
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp
index c891bdcef3..20b62b11dc 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp
@@ -674,15 +674,7 @@ template <int nbit>
 void test_bitpacking_32_lowbit_values() {
   int unpacked_bytes = 32;
   int packed_bytes = unpacked_bytes * nbit / 8;
-  auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
-  std::vector<int8_t> input(unpacked_bytes, 0);
-  int8_t low = -(1 << (nbit - 1));
-  int8_t high = (1 << (nbit - 1));
-  for (int i = 0; i < unpacked_bytes; ++i) {
-    input[i] = (int8_t)(input_shifted[i]) + low;
-    assert(input[i] >= low);
-    assert(input[i] <= high);
-  }
+  auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
 
   std::vector<uint8_t> packed(packed_bytes, 0);
   int8x16_t input0;
@@ -706,15 +698,7 @@ template <int nbit>
 void test_bitpacking_64_lowbit_values() {
   int unpacked_bytes = 64;
   int packed_bytes = unpacked_bytes * nbit / 8;
-  auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
-  std::vector<int8_t> input(unpacked_bytes, 0);
-  int8_t low = -(1 << (nbit - 1));
-  int8_t high = (1 << (nbit - 1));
-  for (int i = 0; i < unpacked_bytes; ++i) {
-    input[i] = (int8_t)(input_shifted[i]) + low;
-    assert(input[i] >= low);
-    assert(input[i] <= high);
-  }
+  auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
 
   std::vector<uint8_t> packed(packed_bytes, 0);
   int8x16_t input0;
@@ -746,15 +730,7 @@ template <int nbit>
 void test_bitpacking_128_lowbit_values() {
   int unpacked_bytes = 128;
   int packed_bytes = unpacked_bytes * nbit / 8;
-  auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
-  std::vector<int8_t> input(unpacked_bytes, 0);
-  int8_t low = -(1 << (nbit - 1));
-  int8_t high = (1 << (nbit - 1));
-  for (int i = 0; i < unpacked_bytes; ++i) {
-    input[i] = (int8_t)(input_shifted[i]) + low;
-    assert(input[i] >= low);
-    assert(input[i] <= high);
-  }
+  auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
 
   std::vector<uint8_t> packed(packed_bytes, 0);
   int8x16_t input0;
@@ -836,6 +812,7 @@ TEST_BITPACKING_32_LOWBIT_VALUES(4);
 TEST_BITPACKING_32_LOWBIT_VALUES(5);
 TEST_BITPACKING_32_LOWBIT_VALUES(6);
 TEST_BITPACKING_32_LOWBIT_VALUES(7);
+TEST_BITPACKING_32_LOWBIT_VALUES(8);
 
 TEST_BITPACKING_64_LOWBIT_VALUES(1);
 TEST_BITPACKING_64_LOWBIT_VALUES(2);
@@ -844,6 +821,7 @@ TEST_BITPACKING_64_LOWBIT_VALUES(4);
 TEST_BITPACKING_64_LOWBIT_VALUES(5);
 TEST_BITPACKING_64_LOWBIT_VALUES(6);
 TEST_BITPACKING_64_LOWBIT_VALUES(7);
+TEST_BITPACKING_64_LOWBIT_VALUES(8);
 
 TEST_BITPACKING_128_LOWBIT_VALUES(1);
 TEST_BITPACKING_128_LOWBIT_VALUES(2);
@@ -852,5 +830,6 @@ TEST_BITPACKING_128_LOWBIT_VALUES(4);
 TEST_BITPACKING_128_LOWBIT_VALUES(5);
 TEST_BITPACKING_128_LOWBIT_VALUES(6);
 TEST_BITPACKING_128_LOWBIT_VALUES(7);
+TEST_BITPACKING_128_LOWBIT_VALUES(8);
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
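
Usage sketch for the new nbit = 8 path: a minimal round trip, assuming an
aarch64/NEON toolchain with bitpack.h on the include path. The harness below
(main, the sample values) is illustrative only and not part of the patch; it
shows that for nbit = 8 the bytes are stored verbatim (no sign shift), so pack
followed by unpack reproduces the input exactly.

#include <arm_neon.h>
#include <cassert>
#include <cstdint>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>

int main() {
  // With nbit = 8 the full int8 range is representable, because the
  // +/- (1 << (nbit - 1)) shift into the nonnegative range is skipped.
  int8_t input[32];
  for (int i = 0; i < 32; ++i) {
    input[i] = static_cast<int8_t>(i * 7 - 128);
  }
  int8x16_t in0 = vld1q_s8(input);
  int8x16_t in1 = vld1q_s8(input + 16);

  // For nbit = 8, packed size equals unpacked size: 32 * 8 / 8 = 32 bytes.
  uint8_t packed[32];
  torchao::bitpacking::vec_pack_32_lowbit_values<8>(packed, in0, in1);

  int8x16_t out0;
  int8x16_t out1;
  torchao::bitpacking::vec_unpack_32_lowbit_values<8>(out0, out1, packed);

  // The 8-bit case is a plain reinterpreting store/load, so the round trip
  // must be bit-exact.
  int8_t output[32];
  vst1q_s8(output, out0);
  vst1q_s8(output + 16, out1);
  for (int i = 0; i < 32; ++i) {
    assert(output[i] == input[i]);
  }
  return 0;
}

The same round trip also holds for nbit < 8, where values are first shifted
into the nonnegative range [0, 2^nbit) before packing and unshifted on unpack.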