From b1ad9848522857dfe0115df0536621ab8eb231fe Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Tue, 18 Nov 2025 09:45:20 -0800 Subject: [PATCH] Add NEON implementation of FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf (#5115) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2121 Adding NEON translation of FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf, used by Ads Performance improves by an order of magnitude: Before: bit_rate rows, cols, elems_per_usec, GB/Sec 2, 100, 16, 211.26, 0.85 2, 100, 64, 210.96, 0.84 2, 100, 128, 204.26, 0.82 2, 100, 256, 200.47, 0.80 2, 100, 512, 194.19, 0.78 2, 100, 1024, 190.98, 0.76 2, 100, 2048, 186.85, 0.75 2, 120, 16, 206.88, 0.83 2, 120, 64, 211.64, 0.85 2, 120, 128, 203.97, 0.82 2, 120, 256, 200.22, 0.80 2, 120, 512, 194.97, 0.78 2, 120, 1024, 191.76, 0.77 2, 120, 2048, 187.45, 0.75 2, 1000, 16, 205.10, 0.82 2, 1000, 64, 214.15, 0.86 2, 1000, 128, 205.43, 0.82 2, 1000, 256, 200.34, 0.80 2, 1000, 512, 196.62, 0.79 2, 1000, 1024, 194.64, 0.78 2, 1000, 2048, 187.54, 0.75 4, 100, 16, 197.97, 0.79 4, 100, 64, 200.02, 0.80 4, 100, 128, 191.06, 0.76 4, 100, 256, 186.58, 0.75 4, 100, 512, 180.76, 0.72 4, 100, 1024, 176.65, 0.71 4, 100, 2048, 175.00, 0.70 4, 120, 16, 198.93, 0.80 4, 120, 64, 201.74, 0.81 4, 120, 128, 190.95, 0.76 4, 120, 256, 186.79, 0.75 4, 120, 512, 181.32, 0.73 4, 120, 1024, 177.54, 0.71 4, 120, 2048, 174.69, 0.70 4, 1000, 16, 194.63, 0.78 4, 1000, 64, 201.64, 0.81 4, 1000, 128, 191.78, 0.77 4, 1000, 256, 186.87, 0.75 4, 1000, 512, 182.91, 0.73 4, 1000, 1024, 180.66, 0.72 4, 1000, 2048, 175.04, 0.70 8, 100, 16, 171.01, 0.68 8, 100, 64, 177.53, 0.71 8, 100, 128, 168.92, 0.68 8, 100, 256, 165.23, 0.66 8, 100, 512, 162.25, 0.65 8, 100, 1024, 158.87, 0.64 8, 100, 2048, 155.39, 0.62 8, 120, 16, 173.77, 0.70 8, 120, 64, 178.34, 0.71 8, 120, 128, 168.66, 0.67 8, 120, 256, 165.60, 0.66 8, 120, 512, 162.30, 0.65 8, 120, 1024, 159.38, 0.64 8, 120, 2048, 156.17, 0.62 8, 1000, 16, 171.34, 0.69 8, 1000, 64, 178.96, 0.72 8, 1000, 128, 169.71, 0.68 8, 1000, 256, 165.62, 0.66 8, 1000, 512, 162.98, 0.65 8, 1000, 1024, 161.59, 0.65 8, 1000, 2048, 157.16, 0.63 After: bit_rate rows, cols, elems_per_usec, GB/Sec 2, 100, 16, 1006.83, 4.03 2, 100, 64, 1542.11, 6.17 2, 100, 128, 1882.99, 7.53 2, 100, 256, 2063.71, 8.25 2, 100, 512, 2232.29, 8.93 2, 100, 1024, 2298.69, 9.19 2, 100, 2048, 2333.73, 9.33 2, 120, 16, 1016.40, 4.07 2, 120, 64, 1524.36, 6.10 2, 120, 128, 1853.40, 7.41 2, 120, 256, 2158.92, 8.64 2, 120, 512, 2321.61, 9.29 2, 120, 1024, 2353.80, 9.42 2, 120, 2048, 2332.84, 9.33 2, 1000, 16, 1129.08, 4.52 2, 1000, 64, 1606.46, 6.43 2, 1000, 128, 2095.33, 8.38 2, 1000, 256, 2470.88, 9.88 2, 1000, 512, 2746.67, 10.99 2, 1000, 1024, 2882.32, 11.53 2, 1000, 2048, 2447.96, 9.79 4, 100, 16, 999.05, 4.00 4, 100, 64, 1666.00, 6.66 4, 100, 128, 2062.08, 8.25 4, 100, 256, 2226.33, 8.91 4, 100, 512, 2481.11, 9.92 4, 100, 1024, 2717.50, 10.87 4, 100, 2048, 2656.00, 10.62 4, 120, 16, 1056.31, 4.23 4, 120, 64, 1651.95, 6.61 4, 120, 128, 2058.65, 8.23 4, 120, 256, 2339.64, 9.36 4, 120, 512, 2570.03, 10.28 4, 120, 1024, 2788.24, 11.15 4, 120, 2048, 2701.20, 10.80 4, 1000, 16, 1184.28, 4.74 4, 1000, 64, 1765.47, 7.06 4, 1000, 128, 2348.17, 9.39 4, 1000, 256, 2852.72, 11.41 4, 1000, 512, 3249.46, 13.00 4, 1000, 1024, 3418.46, 13.67 4, 1000, 2048, 2841.77, 11.37 8, 100, 16, 1176.35, 4.71 8, 100, 64, 1902.76, 7.61 8, 100, 128, 2196.23, 8.78 8, 100, 256, 2596.55, 10.39 8, 100, 512, 2814.30, 11.26 8, 100, 1024, 3175.49, 12.70 8, 100, 2048, 3334.41, 13.34 8, 120, 16, 1213.55, 4.85 8, 120, 64, 1806.19, 7.22 8, 120, 128, 2390.64, 9.56 8, 120, 256, 2736.11, 10.94 8, 120, 512, 3015.86, 12.06 8, 120, 1024, 3332.53, 13.33 8, 120, 2048, 3319.50, 13.28 8, 1000, 16, 1362.12, 5.45 8, 1000, 64, 2029.25, 8.12 8, 1000, 128, 2759.50, 11.04 8, 1000, 256, 3532.71, 14.13 8, 1000, 512, 4014.48, 16.06 8, 1000, 1024, 4240.49, 16.96 8, 1000, 2048, 3440.59, 13.76 Differential Revision: D86774172 --- include/fbgemm/QuantUtilsNeon.h | 7 + src/QuantUtils.cc | 22 +++ src/QuantUtilsNeon.cc | 260 +++++++++++++++++++++++++++++++- 3 files changed, 285 insertions(+), 4 deletions(-) diff --git a/include/fbgemm/QuantUtilsNeon.h b/include/fbgemm/QuantUtilsNeon.h index 13169c8a05..69845ec67a 100644 --- a/include/fbgemm/QuantUtilsNeon.h +++ b/include/fbgemm/QuantUtilsNeon.h @@ -36,6 +36,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( int input_columns, OutputType* output); +template +void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output); + } // namespace fbgemm #endif // __aarch64__ diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 5301909193..06030df46b 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -636,6 +636,26 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( throw std::runtime_error("Unsupported number of columns"); } +#if HAVE_SVE + switch (bit_rate) { + case 2: + FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + input, input_rows, input_columns, output); + break; + case 4: + FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + input, input_rows, input_columns, output); + break; + case 8: + FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + input, input_rows, input_columns, output); + break; + default: + FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef( + bit_rate, input, input_rows, input_columns, output); + } +#else + if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 switch (bit_rate) { @@ -660,6 +680,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef( bit_rate, input, input_rows, input_columns, output); } + +#endif } template diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index 8fef86b94f..7868fbb519 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -95,8 +95,12 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) { #if HAVE_SVE -static inline void -FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) { +template +static inline void FindMinMaxImpl_f16( + const float16_t* m, + OutType* min, + OutType* max, + uint64_t count) { float16_t first = *m; float16_t tmp_min_s = first; @@ -141,8 +145,8 @@ FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) { tmp_max_s = vmaxh_f16(tmp_max_s, tmp); } - *min = static_cast(tmp_min_s); - *max = static_cast(tmp_max_s); + *min = static_cast(tmp_min_s); + *max = static_cast(tmp_max_s); } template @@ -257,6 +261,236 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( } // for each row } +template +void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output) { + if (input_rows == 0 || input_columns <= 0) { + return; + } + + static_assert( + std::is_same() || std::is_same(), + "Only float and float16 types are allowed."); + + static_assert( + (BIT_RATE == 8) || (BIT_RATE == 4) || (BIT_RATE == 2), + "Only bit rates of 8, 4 and 2 are allowed."); + + constexpr uint64_t num_elem_per_byte = 8 / BIT_RATE; + uint64_t column_count = static_cast(input_columns); + const int output_columns = + (column_count + num_elem_per_byte - 1) / num_elem_per_byte + + 2 * sizeof(float16); + + for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) { + const InputType* input_row = input + row * column_count; + std::uint8_t* output_row = output + row * output_columns; + float16_t* output_row_scale_bias = reinterpret_cast( + output_row + + (column_count + num_elem_per_byte - 1) / num_elem_per_byte); + + float minimum_element; + float maximum_element; + float16_t minimum_element_fp16; + if constexpr (std::is_same()) { + FindMinMaxImpl_f32( + input_row, &minimum_element, &maximum_element, column_count); + minimum_element_fp16 = static_cast(minimum_element); + minimum_element = static_cast(minimum_element_fp16); + } else { + float16_t maximum_element_fp16; + FindMinMaxImpl_f16( + reinterpret_cast(input_row), + &minimum_element_fp16, + &maximum_element_fp16, + column_count); + minimum_element = static_cast(minimum_element_fp16); + maximum_element = static_cast(maximum_element_fp16); + } + + const float range = maximum_element - minimum_element; + + float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1); + float16_t scale_fp16 = static_cast(scale); + scale = static_cast(scale_fp16); + svfloat32_t inverse_scale_sv; + if (scale != 0.0f) { + float inverse_scale = 1.0f / scale; + inverse_scale_sv = svdup_n_f32(inverse_scale); + bool isInf = svptest_any( + svptrue_b8(), + svcmpuo_f32( + svptrue_b8(), + svsub_f32_x(svptrue_b8(), inverse_scale_sv, inverse_scale_sv), + svdup_n_f32(0.0))); + if (isInf) { + scale_fp16 = static_cast(1.0f); + scale = 1.0f; + inverse_scale_sv = svdup_n_f32(1.0f); + } + } else { + // Corner case handling when maximum_element == minimum_element + // Any scale would work because X - minimum_element will be 0 for all X + scale_fp16 = static_cast(1.0f); + scale = 1.0f; + inverse_scale_sv = svdup_n_f32(1.0f); + } + + constexpr uint64_t kItemsPerIter = 8; + uint64_t loopIters = column_count / kItemsPerIter; + uint64_t loopRemainder = column_count % kItemsPerIter; + + output_row_scale_bias[0] = scale_fp16; + output_row_scale_bias[1] = minimum_element_fp16; + + float32x4_t inverse_scale_v = svget_neonq(inverse_scale_sv); + float32x4_t min_v = vdupq_n_f32(minimum_element); + + constexpr unsigned int maxValPerBitRate = (1ul << BIT_RATE) - 1; + uint32x4_t maxval_v = vdupq_n_u32(maxValPerBitRate); + + svbool_t lastPredA = svwhilelt_b32_u64(0, loopRemainder); + svbool_t lastPredB = svwhilelt_b32_u64(4, loopRemainder); + + while (__builtin_expect(loopIters > 0, 1)) { + float32x4_t v0; + float32x4_t v1; + + if constexpr (std::is_same()) { + v0 = vld1q_f32(input_row); + v1 = vld1q_f32(input_row + 4); + } else { + float16x8_t h0 = + vld1q_f16(reinterpret_cast(input_row)); + v0 = vcvt_f32_f16(vget_low_f16(h0)); + v1 = vcvt_high_f32_f16(h0); + } + + input_row += kItemsPerIter; + loopIters -= 1; + + v0 = vsubq_f32(v0, min_v); + v1 = vsubq_f32(v1, min_v); + + v0 = vmulq_f32(v0, inverse_scale_v); + v1 = vmulq_f32(v1, inverse_scale_v); + + int32x4_t i0 = vcvtnq_s32_f32(v0); + int32x4_t i1 = vcvtnq_s32_f32(v1); + + uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v); + uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v); + + if constexpr (num_elem_per_byte == 1) { + svst1b_u32( + svptrue_b8(), output_row, svset_neonq_u32(svundef_u32(), u0)); + svst1b_u32( + svptrue_b8(), output_row + 4, svset_neonq_u32(svundef_u32(), u1)); + } else { + constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30; + + uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar; + uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar; + + u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0)); + u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1)); + + if constexpr (num_elem_per_byte == 2) { + svst1b_u64( + svptrue_b8(), output_row, svset_neonq_u64(svundef_u64(), u2)); + svst1b_u64( + svptrue_b8(), output_row + 2, svset_neonq_u64(svundef_u64(), u3)); + + } else if constexpr (num_elem_per_byte == 4) { + auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8); + auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8); + + u4 = u4 << 4; + u5 = u5 << 4; + + u4 = veor_u8(u4, vget_low_u8(u2)); + u5 = veor_u8(u5, vget_low_u8(u3)); + + vst1_lane_u8(output_row, u4, 0); + vst1_lane_u8(output_row + 1, u5, 0); + } + } + + constexpr uint64_t bytesStored = kItemsPerIter / num_elem_per_byte; + output_row += bytesStored; + } + + if (loopRemainder > 0) { + float32x4_t v0; + float32x4_t v1; + + if constexpr (std::is_same()) { + v0 = svget_neonq(svld1_f32(lastPredA, input_row)); + v1 = svget_neonq(svld1_f32(lastPredB, input_row + 4)); + } else { + auto h0 = svld1uh_u32( + lastPredA, reinterpret_cast(input_row)); + auto h1 = svld1uh_u32( + lastPredB, reinterpret_cast(input_row + 4)); + v0 = svget_neonq( + svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h0))); + v1 = svget_neonq( + svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h1))); + } + + v0 = vsubq_f32(v0, min_v); + v1 = vsubq_f32(v1, min_v); + + v0 = vmulq_f32(v0, inverse_scale_v); + v1 = vmulq_f32(v1, inverse_scale_v); + + int32x4_t i0 = vcvtnq_s32_f32(v0); + int32x4_t i1 = vcvtnq_s32_f32(v1); + + uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v); + uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v); + + if constexpr (num_elem_per_byte == 1) { + svst1b_u32(lastPredA, output_row, svset_neonq_u32(svundef_u32(), u0)); + svst1b_u32( + lastPredB, output_row + 4, svset_neonq_u32(svundef_u32(), u1)); + } else { + constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30; + + uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar; + uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar; + + u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0)); + u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1)); + + if constexpr (num_elem_per_byte == 2) { + svst1b_u64(lastPredA, output_row, svset_neonq_u64(svundef_u64(), u2)); + svst1b_u64( + lastPredB, output_row + 2, svset_neonq_u64(svundef_u64(), u3)); + + } else if constexpr (num_elem_per_byte == 4) { + auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8); + auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8); + + u4 = u4 << 4; + u5 = u5 << 4; + + u4 = veor_u8(u4, vget_low_u8(u2)); + u5 = veor_u8(u5, vget_low_u8(u3)); + + vst1_lane_u8(output_row, u4, 0); + if (loopRemainder > 4) { + vst1_lane_u8(output_row + 1, u5, 0); + } + } + } + } + } // for each row +} + template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, @@ -372,6 +606,24 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16) // clang-format on #undef INSTANTIATE_QuantizationNeonFunctions8Bits +#define INSTANTIATE_QuantizationNeonFunctionsNBits(type, bit_rate) \ + template void \ + FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( \ + const type* input, \ + size_t input_rows, \ + int input_columns, \ + std::uint8_t* output); + + // clang-format off +INSTANTIATE_QuantizationNeonFunctionsNBits(float, 2) +INSTANTIATE_QuantizationNeonFunctionsNBits(float, 4) +INSTANTIATE_QuantizationNeonFunctionsNBits(float, 8) +INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 2) +INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 4) +INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 8) +// clang-format on +#undef INSTANTIATE_QuantizationNeonFunctionsNBits + #endif // HAVE_SVE } // namespace fbgemm