Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/fbgemm/QuantUtilsNeon.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
int input_columns,
OutputType* output);

// Row-wise quantization of `input` to BIT_RATE-bit codes (BIT_RATE must be
// 2, 4 or 8; InputType must be float or float16 — both are enforced by
// static_asserts in the definition).
//
// `input` is read as `input_rows` rows of `input_columns` elements each.
// Every output row holds ceil(input_columns / (8 / BIT_RATE)) bytes of packed
// codes followed by two float16 values: the row scale, then the row bias
// (the row minimum). The definition returns without writing anything when
// input_rows == 0 or input_columns <= 0.
template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);

} // namespace fbgemm

#endif // __aarch64__
22 changes: 22 additions & 0 deletions src/QuantUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,26 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
throw std::runtime_error("Unsupported number of columns");
}

#if HAVE_SVE
switch (bit_rate) {
case 2:
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 2>(
input, input_rows, input_columns, output);
break;
case 4:
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 4>(
input, input_rows, input_columns, output);
break;
case 8:
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 8>(
input, input_rows, input_columns, output);
break;
default:
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
bit_rate, input, input_rows, input_columns, output);
}
#else

if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
switch (bit_rate) {
Expand All @@ -660,6 +680,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
bit_rate, input, input_rows, input_columns, output);
}

#endif
}

template <typename InputType>
Expand Down
260 changes: 256 additions & 4 deletions src/QuantUtilsNeon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,12 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) {

#if HAVE_SVE

static inline void
FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
template <typename OutType>
static inline void FindMinMaxImpl_f16(
const float16_t* m,
OutType* min,
OutType* max,
uint64_t count) {
float16_t first = *m;

float16_t tmp_min_s = first;
Expand Down Expand Up @@ -141,8 +145,8 @@ FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
tmp_max_s = vmaxh_f16(tmp_max_s, tmp);
}

*min = static_cast<float>(tmp_min_s);
*max = static_cast<float>(tmp_max_s);
*min = static_cast<OutType>(tmp_min_s);
*max = static_cast<OutType>(tmp_max_s);
}

template <typename InputType>
Expand Down Expand Up @@ -257,6 +261,236 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
} // for each row
}

template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
if (input_rows == 0 || input_columns <= 0) {
return;
}

static_assert(
std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
"Only float and float16 types are allowed.");

static_assert(
(BIT_RATE == 8) || (BIT_RATE == 4) || (BIT_RATE == 2),
"Only bit rates of 8, 4 and 2 are allowed.");

constexpr uint64_t num_elem_per_byte = 8 / BIT_RATE;
uint64_t column_count = static_cast<uint64_t>(input_columns);
const int output_columns =
(column_count + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(float16);

for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) {
const InputType* input_row = input + row * column_count;
std::uint8_t* output_row = output + row * output_columns;
float16_t* output_row_scale_bias = reinterpret_cast<float16_t*>(
output_row +
(column_count + num_elem_per_byte - 1) / num_elem_per_byte);

float minimum_element;
float maximum_element;
float16_t minimum_element_fp16;
if constexpr (std::is_same<InputType, float>()) {
FindMinMaxImpl_f32(
input_row, &minimum_element, &maximum_element, column_count);
minimum_element_fp16 = static_cast<float16_t>(minimum_element);
minimum_element = static_cast<float>(minimum_element_fp16);
} else {
float16_t maximum_element_fp16;
FindMinMaxImpl_f16(
reinterpret_cast<const float16_t*>(input_row),
&minimum_element_fp16,
&maximum_element_fp16,
column_count);
minimum_element = static_cast<float>(minimum_element_fp16);
maximum_element = static_cast<float>(maximum_element_fp16);
}

const float range = maximum_element - minimum_element;

float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
float16_t scale_fp16 = static_cast<float16_t>(scale);
scale = static_cast<float>(scale_fp16);
svfloat32_t inverse_scale_sv;
if (scale != 0.0f) {
float inverse_scale = 1.0f / scale;
inverse_scale_sv = svdup_n_f32(inverse_scale);
bool isInf = svptest_any(
svptrue_b8(),
svcmpuo_f32(
svptrue_b8(),
svsub_f32_x(svptrue_b8(), inverse_scale_sv, inverse_scale_sv),
svdup_n_f32(0.0)));
if (isInf) {
scale_fp16 = static_cast<float16_t>(1.0f);
scale = 1.0f;
inverse_scale_sv = svdup_n_f32(1.0f);
}
} else {
// Corner case handling when maximum_element == minimum_element
// Any scale would work because X - minimum_element will be 0 for all X
scale_fp16 = static_cast<float16_t>(1.0f);
scale = 1.0f;
inverse_scale_sv = svdup_n_f32(1.0f);
}

constexpr uint64_t kItemsPerIter = 8;
uint64_t loopIters = column_count / kItemsPerIter;
uint64_t loopRemainder = column_count % kItemsPerIter;

output_row_scale_bias[0] = scale_fp16;
output_row_scale_bias[1] = minimum_element_fp16;

float32x4_t inverse_scale_v = svget_neonq(inverse_scale_sv);
float32x4_t min_v = vdupq_n_f32(minimum_element);

constexpr unsigned int maxValPerBitRate = (1ul << BIT_RATE) - 1;
uint32x4_t maxval_v = vdupq_n_u32(maxValPerBitRate);

svbool_t lastPredA = svwhilelt_b32_u64(0, loopRemainder);
svbool_t lastPredB = svwhilelt_b32_u64(4, loopRemainder);

while (__builtin_expect(loopIters > 0, 1)) {
float32x4_t v0;
float32x4_t v1;

if constexpr (std::is_same<InputType, float>()) {
v0 = vld1q_f32(input_row);
v1 = vld1q_f32(input_row + 4);
} else {
float16x8_t h0 =
vld1q_f16(reinterpret_cast<const float16_t*>(input_row));
v0 = vcvt_f32_f16(vget_low_f16(h0));
v1 = vcvt_high_f32_f16(h0);
}

input_row += kItemsPerIter;
loopIters -= 1;

v0 = vsubq_f32(v0, min_v);
v1 = vsubq_f32(v1, min_v);

v0 = vmulq_f32(v0, inverse_scale_v);
v1 = vmulq_f32(v1, inverse_scale_v);

int32x4_t i0 = vcvtnq_s32_f32(v0);
int32x4_t i1 = vcvtnq_s32_f32(v1);

uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);

if constexpr (num_elem_per_byte == 1) {
svst1b_u32(
svptrue_b8(), output_row, svset_neonq_u32(svundef_u32(), u0));
svst1b_u32(
svptrue_b8(), output_row + 4, svset_neonq_u32(svundef_u32(), u1));
} else {
constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;

uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;

u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));

if constexpr (num_elem_per_byte == 2) {
svst1b_u64(
svptrue_b8(), output_row, svset_neonq_u64(svundef_u64(), u2));
svst1b_u64(
svptrue_b8(), output_row + 2, svset_neonq_u64(svundef_u64(), u3));

} else if constexpr (num_elem_per_byte == 4) {
auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);

u4 = u4 << 4;
u5 = u5 << 4;

u4 = veor_u8(u4, vget_low_u8(u2));
u5 = veor_u8(u5, vget_low_u8(u3));

vst1_lane_u8(output_row, u4, 0);
vst1_lane_u8(output_row + 1, u5, 0);
}
}

constexpr uint64_t bytesStored = kItemsPerIter / num_elem_per_byte;
output_row += bytesStored;
}

if (loopRemainder > 0) {
float32x4_t v0;
float32x4_t v1;

if constexpr (std::is_same<InputType, float>()) {
v0 = svget_neonq(svld1_f32(lastPredA, input_row));
v1 = svget_neonq(svld1_f32(lastPredB, input_row + 4));
} else {
auto h0 = svld1uh_u32(
lastPredA, reinterpret_cast<const uint16_t*>(input_row));
auto h1 = svld1uh_u32(
lastPredB, reinterpret_cast<const uint16_t*>(input_row + 4));
v0 = svget_neonq(
svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h0)));
v1 = svget_neonq(
svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h1)));
}

v0 = vsubq_f32(v0, min_v);
v1 = vsubq_f32(v1, min_v);

v0 = vmulq_f32(v0, inverse_scale_v);
v1 = vmulq_f32(v1, inverse_scale_v);

int32x4_t i0 = vcvtnq_s32_f32(v0);
int32x4_t i1 = vcvtnq_s32_f32(v1);

uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);

if constexpr (num_elem_per_byte == 1) {
svst1b_u32(lastPredA, output_row, svset_neonq_u32(svundef_u32(), u0));
svst1b_u32(
lastPredB, output_row + 4, svset_neonq_u32(svundef_u32(), u1));
} else {
constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;

uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;

u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));

if constexpr (num_elem_per_byte == 2) {
svst1b_u64(lastPredA, output_row, svset_neonq_u64(svundef_u64(), u2));
svst1b_u64(
lastPredB, output_row + 2, svset_neonq_u64(svundef_u64(), u3));

} else if constexpr (num_elem_per_byte == 4) {
auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);

u4 = u4 << 4;
u5 = u5 << 4;

u4 = veor_u8(u4, vget_low_u8(u2));
u5 = veor_u8(u5, vget_low_u8(u3));

vst1_lane_u8(output_row, u4, 0);
if (loopRemainder > 4) {
vst1_lane_u8(output_row + 1, u5, 0);
}
}
}
}
} // for each row
}

template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
const std::uint8_t* input,
Expand Down Expand Up @@ -372,6 +606,24 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
// clang-format on
#undef INSTANTIATE_QuantizationNeonFunctions8Bits

// Explicit instantiations of the N-bit rowwise quantizer for each supported
// input type (float, float16) at each supported bit rate (2, 4, 8).
#define INSTANTIATE_QuantizationNeonFunctionsNBits(elem_t, bits)   \
  template void                                                    \
  FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<elem_t, bits>(  \
      const elem_t*, size_t, int, std::uint8_t*);

// Expands to the instantiations for all three bit rates of one input type.
#define INSTANTIATE_QuantizationNeonFunctionsNBitsForType(elem_t) \
  INSTANTIATE_QuantizationNeonFunctionsNBits(elem_t, 2)           \
  INSTANTIATE_QuantizationNeonFunctionsNBits(elem_t, 4)           \
  INSTANTIATE_QuantizationNeonFunctionsNBits(elem_t, 8)

// clang-format off
INSTANTIATE_QuantizationNeonFunctionsNBitsForType(float)
INSTANTIATE_QuantizationNeonFunctionsNBitsForType(float16)
// clang-format on
#undef INSTANTIATE_QuantizationNeonFunctionsNBitsForType
#undef INSTANTIATE_QuantizationNeonFunctionsNBits

#endif // HAVE_SVE

} // namespace fbgemm
Expand Down
Loading