Skip to content

Commit 16a2f5a

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Add NEON implementation of FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf (#5115)
Summary: X-link: facebookresearch/FBGEMM#2121 Adding NEON translation of FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf, used by Ads Performance improves by an order of magnitude: Before: bit_rate rows, cols, elems_per_usec, GB/Sec 2, 100, 16, 211.26, 0.85 2, 100, 64, 210.96, 0.84 2, 100, 128, 204.26, 0.82 2, 100, 256, 200.47, 0.80 2, 100, 512, 194.19, 0.78 2, 100, 1024, 190.98, 0.76 2, 100, 2048, 186.85, 0.75 2, 120, 16, 206.88, 0.83 2, 120, 64, 211.64, 0.85 2, 120, 128, 203.97, 0.82 2, 120, 256, 200.22, 0.80 2, 120, 512, 194.97, 0.78 2, 120, 1024, 191.76, 0.77 2, 120, 2048, 187.45, 0.75 2, 1000, 16, 205.10, 0.82 2, 1000, 64, 214.15, 0.86 2, 1000, 128, 205.43, 0.82 2, 1000, 256, 200.34, 0.80 2, 1000, 512, 196.62, 0.79 2, 1000, 1024, 194.64, 0.78 2, 1000, 2048, 187.54, 0.75 4, 100, 16, 197.97, 0.79 4, 100, 64, 200.02, 0.80 4, 100, 128, 191.06, 0.76 4, 100, 256, 186.58, 0.75 4, 100, 512, 180.76, 0.72 4, 100, 1024, 176.65, 0.71 4, 100, 2048, 175.00, 0.70 4, 120, 16, 198.93, 0.80 4, 120, 64, 201.74, 0.81 4, 120, 128, 190.95, 0.76 4, 120, 256, 186.79, 0.75 4, 120, 512, 181.32, 0.73 4, 120, 1024, 177.54, 0.71 4, 120, 2048, 174.69, 0.70 4, 1000, 16, 194.63, 0.78 4, 1000, 64, 201.64, 0.81 4, 1000, 128, 191.78, 0.77 4, 1000, 256, 186.87, 0.75 4, 1000, 512, 182.91, 0.73 4, 1000, 1024, 180.66, 0.72 4, 1000, 2048, 175.04, 0.70 8, 100, 16, 171.01, 0.68 8, 100, 64, 177.53, 0.71 8, 100, 128, 168.92, 0.68 8, 100, 256, 165.23, 0.66 8, 100, 512, 162.25, 0.65 8, 100, 1024, 158.87, 0.64 8, 100, 2048, 155.39, 0.62 8, 120, 16, 173.77, 0.70 8, 120, 64, 178.34, 0.71 8, 120, 128, 168.66, 0.67 8, 120, 256, 165.60, 0.66 8, 120, 512, 162.30, 0.65 8, 120, 1024, 159.38, 0.64 8, 120, 2048, 156.17, 0.62 8, 1000, 16, 171.34, 0.69 8, 1000, 64, 178.96, 0.72 8, 1000, 128, 169.71, 0.68 8, 1000, 256, 165.62, 0.66 8, 1000, 512, 162.98, 0.65 8, 1000, 1024, 161.59, 0.65 8, 1000, 2048, 157.16, 0.63 After: bit_rate rows, cols, elems_per_usec, GB/Sec 2, 100, 16, 1006.83, 4.03 2, 100, 64, 1542.11, 6.17 2, 100, 128, 1882.99, 7.53 2, 100, 256, 2063.71, 8.25 2, 100, 512, 2232.29, 8.93 2, 100, 1024, 2298.69, 9.19 2, 100, 2048, 2333.73, 9.33 2, 120, 16, 1016.40, 4.07 2, 120, 64, 1524.36, 6.10 2, 120, 128, 1853.40, 7.41 2, 120, 256, 2158.92, 8.64 2, 120, 512, 2321.61, 9.29 2, 120, 1024, 2353.80, 9.42 2, 120, 2048, 2332.84, 9.33 2, 1000, 16, 1129.08, 4.52 2, 1000, 64, 1606.46, 6.43 2, 1000, 128, 2095.33, 8.38 2, 1000, 256, 2470.88, 9.88 2, 1000, 512, 2746.67, 10.99 2, 1000, 1024, 2882.32, 11.53 2, 1000, 2048, 2447.96, 9.79 4, 100, 16, 999.05, 4.00 4, 100, 64, 1666.00, 6.66 4, 100, 128, 2062.08, 8.25 4, 100, 256, 2226.33, 8.91 4, 100, 512, 2481.11, 9.92 4, 100, 1024, 2717.50, 10.87 4, 100, 2048, 2656.00, 10.62 4, 120, 16, 1056.31, 4.23 4, 120, 64, 1651.95, 6.61 4, 120, 128, 2058.65, 8.23 4, 120, 256, 2339.64, 9.36 4, 120, 512, 2570.03, 10.28 4, 120, 1024, 2788.24, 11.15 4, 120, 2048, 2701.20, 10.80 4, 1000, 16, 1184.28, 4.74 4, 1000, 64, 1765.47, 7.06 4, 1000, 128, 2348.17, 9.39 4, 1000, 256, 2852.72, 11.41 4, 1000, 512, 3249.46, 13.00 4, 1000, 1024, 3418.46, 13.67 4, 1000, 2048, 2841.77, 11.37 8, 100, 16, 1176.35, 4.71 8, 100, 64, 1902.76, 7.61 8, 100, 128, 2196.23, 8.78 8, 100, 256, 2596.55, 10.39 8, 100, 512, 2814.30, 11.26 8, 100, 1024, 3175.49, 12.70 8, 100, 2048, 3334.41, 13.34 8, 120, 16, 1213.55, 4.85 8, 120, 64, 1806.19, 7.22 8, 120, 128, 2390.64, 9.56 8, 120, 256, 2736.11, 10.94 8, 120, 512, 3015.86, 12.06 8, 120, 1024, 3332.53, 13.33 8, 120, 2048, 3319.50, 13.28 8, 1000, 16, 1362.12, 5.45 8, 1000, 64, 2029.25, 8.12 8, 1000, 128, 2759.50, 11.04 8, 1000, 256, 3532.71, 14.13 8, 1000, 512, 4014.48, 16.06 8, 1000, 1024, 4240.49, 16.96 8, 1000, 2048, 3440.59, 13.76 Differential Revision: D86774172
1 parent 8189ad4 commit 16a2f5a

File tree

3 files changed

+285
-4
lines changed

3 files changed

+285
-4
lines changed

include/fbgemm/QuantUtilsNeon.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
3636
int input_columns,
3737
OutputType* output);
3838

39+
template <typename InputType, int BIT_RATE>
40+
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
41+
const InputType* input,
42+
size_t input_rows,
43+
int input_columns,
44+
std::uint8_t* output);
45+
3946
} // namespace fbgemm
4047

4148
#endif // __aarch64__

src/QuantUtils.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,26 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
636636
throw std::runtime_error("Unsupported number of columns");
637637
}
638638

639+
#if HAVE_SVE
640+
switch (bit_rate) {
641+
case 2:
642+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 2>(
643+
input, input_rows, input_columns, output);
644+
break;
645+
case 4:
646+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 4>(
647+
input, input_rows, input_columns, output);
648+
break;
649+
case 8:
650+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 8>(
651+
input, input_rows, input_columns, output);
652+
break;
653+
default:
654+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
655+
bit_rate, input, input_rows, input_columns, output);
656+
}
657+
#else
658+
639659
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
640660
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
641661
switch (bit_rate) {
@@ -660,6 +680,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
660680
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
661681
bit_rate, input, input_rows, input_columns, output);
662682
}
683+
684+
#endif
663685
}
664686

665687
template <typename InputType>

src/QuantUtilsNeon.cc

Lines changed: 256 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) {
9595

9696
#if HAVE_SVE
9797

98-
static inline void
99-
FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
98+
template <typename OutType>
99+
static inline void FindMinMaxImpl_f16(
100+
const float16_t* m,
101+
OutType* min,
102+
OutType* max,
103+
uint64_t count) {
100104
float16_t first = *m;
101105

102106
float16_t tmp_min_s = first;
@@ -141,8 +145,8 @@ FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
141145
tmp_max_s = vmaxh_f16(tmp_max_s, tmp);
142146
}
143147

144-
*min = static_cast<float>(tmp_min_s);
145-
*max = static_cast<float>(tmp_max_s);
148+
*min = static_cast<OutType>(tmp_min_s);
149+
*max = static_cast<OutType>(tmp_max_s);
146150
}
147151

148152
template <typename InputType>
@@ -257,6 +261,236 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
257261
} // for each row
258262
}
259263

264+
template <typename InputType, int BIT_RATE>
265+
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
266+
const InputType* input,
267+
size_t input_rows,
268+
int input_columns,
269+
std::uint8_t* output) {
270+
if (input_rows == 0 || input_columns <= 0) {
271+
return;
272+
}
273+
274+
static_assert(
275+
std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
276+
"Only float and float16 types are allowed.");
277+
278+
static_assert(
279+
(BIT_RATE == 8) || (BIT_RATE == 4) || (BIT_RATE == 2),
280+
"Only bit rates of 8, 4 and 2 are allowed.");
281+
282+
constexpr uint64_t num_elem_per_byte = 8 / BIT_RATE;
283+
uint64_t column_count = static_cast<uint64_t>(input_columns);
284+
const int output_columns =
285+
(column_count + num_elem_per_byte - 1) / num_elem_per_byte +
286+
2 * sizeof(float16);
287+
288+
for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) {
289+
const InputType* input_row = input + row * column_count;
290+
std::uint8_t* output_row = output + row * output_columns;
291+
float16_t* output_row_scale_bias = reinterpret_cast<float16_t*>(
292+
output_row +
293+
(column_count + num_elem_per_byte - 1) / num_elem_per_byte);
294+
295+
float minimum_element;
296+
float maximum_element;
297+
float16_t minimum_element_fp16;
298+
if constexpr (std::is_same<InputType, float>()) {
299+
FindMinMaxImpl_f32(
300+
input_row, &minimum_element, &maximum_element, column_count);
301+
minimum_element_fp16 = static_cast<float16_t>(minimum_element);
302+
minimum_element = static_cast<float>(minimum_element_fp16);
303+
} else {
304+
float16_t maximum_element_fp16;
305+
FindMinMaxImpl_f16(
306+
reinterpret_cast<const float16_t*>(input_row),
307+
&minimum_element_fp16,
308+
&maximum_element_fp16,
309+
column_count);
310+
minimum_element = static_cast<float>(minimum_element_fp16);
311+
maximum_element = static_cast<float>(maximum_element_fp16);
312+
}
313+
314+
const float range = maximum_element - minimum_element;
315+
316+
float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
317+
float16_t scale_fp16 = static_cast<float16_t>(scale);
318+
scale = static_cast<float>(scale_fp16);
319+
svfloat32_t inverse_scale_sv;
320+
if (scale != 0.0f) {
321+
float inverse_scale = 1.0f / scale;
322+
inverse_scale_sv = svdup_n_f32(inverse_scale);
323+
bool isInf = svptest_any(
324+
svptrue_b8(),
325+
svcmpuo_f32(
326+
svptrue_b8(),
327+
svsub_f32_x(svptrue_b8(), inverse_scale_sv, inverse_scale_sv),
328+
svdup_n_f32(0.0)));
329+
if (isInf) {
330+
scale_fp16 = static_cast<float16_t>(1.0f);
331+
scale = 1.0f;
332+
inverse_scale_sv = svdup_n_f32(1.0f);
333+
}
334+
} else {
335+
// Corner case handling when maximum_element == minimum_element
336+
// Any scale would work because X - minimum_element will be 0 for all X
337+
scale_fp16 = static_cast<float16_t>(1.0f);
338+
scale = 1.0f;
339+
inverse_scale_sv = svdup_n_f32(1.0f);
340+
}
341+
342+
constexpr uint64_t kItemsPerIter = 8;
343+
uint64_t loopIters = column_count / kItemsPerIter;
344+
uint64_t loopRemainder = column_count % kItemsPerIter;
345+
346+
output_row_scale_bias[0] = scale_fp16;
347+
output_row_scale_bias[1] = minimum_element_fp16;
348+
349+
float32x4_t inverse_scale_v = svget_neonq(inverse_scale_sv);
350+
float32x4_t min_v = vdupq_n_f32(minimum_element);
351+
352+
constexpr unsigned int maxValPerBitRate = (1ul << BIT_RATE) - 1;
353+
uint32x4_t maxval_v = vdupq_n_u32(maxValPerBitRate);
354+
355+
svbool_t lastPredA = svwhilelt_b32_u64(0, loopRemainder);
356+
svbool_t lastPredB = svwhilelt_b32_u64(4, loopRemainder);
357+
358+
while (__builtin_expect(loopIters > 0, 1)) {
359+
float32x4_t v0;
360+
float32x4_t v1;
361+
362+
if constexpr (std::is_same<InputType, float>()) {
363+
v0 = vld1q_f32(input_row);
364+
v1 = vld1q_f32(input_row + 4);
365+
} else {
366+
float16x8_t h0 =
367+
vld1q_f16(reinterpret_cast<const float16_t*>(input_row));
368+
v0 = vcvt_f32_f16(vget_low_f16(h0));
369+
v1 = vcvt_high_f32_f16(h0);
370+
}
371+
372+
input_row += kItemsPerIter;
373+
loopIters -= 1;
374+
375+
v0 = vsubq_f32(v0, min_v);
376+
v1 = vsubq_f32(v1, min_v);
377+
378+
v0 = vmulq_f32(v0, inverse_scale_v);
379+
v1 = vmulq_f32(v1, inverse_scale_v);
380+
381+
int32x4_t i0 = vcvtnq_s32_f32(v0);
382+
int32x4_t i1 = vcvtnq_s32_f32(v1);
383+
384+
uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
385+
uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);
386+
387+
if constexpr (num_elem_per_byte == 1) {
388+
svst1b_u32(
389+
svptrue_b8(), output_row, svset_neonq_u32(svundef_u32(), u0));
390+
svst1b_u32(
391+
svptrue_b8(), output_row + 4, svset_neonq_u32(svundef_u32(), u1));
392+
} else {
393+
constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;
394+
395+
uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
396+
uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;
397+
398+
u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
399+
u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));
400+
401+
if constexpr (num_elem_per_byte == 2) {
402+
svst1b_u64(
403+
svptrue_b8(), output_row, svset_neonq_u64(svundef_u64(), u2));
404+
svst1b_u64(
405+
svptrue_b8(), output_row + 2, svset_neonq_u64(svundef_u64(), u3));
406+
407+
} else if constexpr (num_elem_per_byte == 4) {
408+
auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
409+
auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);
410+
411+
u4 = u4 << 4;
412+
u5 = u5 << 4;
413+
414+
u4 = veor_u8(u4, vget_low_u8(u2));
415+
u5 = veor_u8(u5, vget_low_u8(u3));
416+
417+
vst1_lane_u8(output_row, u4, 0);
418+
vst1_lane_u8(output_row + 1, u5, 0);
419+
}
420+
}
421+
422+
constexpr uint64_t bytesStored = kItemsPerIter / num_elem_per_byte;
423+
output_row += bytesStored;
424+
}
425+
426+
if (loopRemainder > 0) {
427+
float32x4_t v0;
428+
float32x4_t v1;
429+
430+
if constexpr (std::is_same<InputType, float>()) {
431+
v0 = svget_neonq(svld1_f32(lastPredA, input_row));
432+
v1 = svget_neonq(svld1_f32(lastPredB, input_row + 4));
433+
} else {
434+
auto h0 = svld1uh_u32(
435+
lastPredA, reinterpret_cast<const uint16_t*>(input_row));
436+
auto h1 = svld1uh_u32(
437+
lastPredB, reinterpret_cast<const uint16_t*>(input_row + 4));
438+
v0 = svget_neonq(
439+
svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h0)));
440+
v1 = svget_neonq(
441+
svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h1)));
442+
}
443+
444+
v0 = vsubq_f32(v0, min_v);
445+
v1 = vsubq_f32(v1, min_v);
446+
447+
v0 = vmulq_f32(v0, inverse_scale_v);
448+
v1 = vmulq_f32(v1, inverse_scale_v);
449+
450+
int32x4_t i0 = vcvtnq_s32_f32(v0);
451+
int32x4_t i1 = vcvtnq_s32_f32(v1);
452+
453+
uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
454+
uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);
455+
456+
if constexpr (num_elem_per_byte == 1) {
457+
svst1b_u32(lastPredA, output_row, svset_neonq_u32(svundef_u32(), u0));
458+
svst1b_u32(
459+
lastPredB, output_row + 4, svset_neonq_u32(svundef_u32(), u1));
460+
} else {
461+
constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;
462+
463+
uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
464+
uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;
465+
466+
u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
467+
u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));
468+
469+
if constexpr (num_elem_per_byte == 2) {
470+
svst1b_u64(lastPredA, output_row, svset_neonq_u64(svundef_u64(), u2));
471+
svst1b_u64(
472+
lastPredB, output_row + 2, svset_neonq_u64(svundef_u64(), u3));
473+
474+
} else if constexpr (num_elem_per_byte == 4) {
475+
auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
476+
auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);
477+
478+
u4 = u4 << 4;
479+
u5 = u5 << 4;
480+
481+
u4 = veor_u8(u4, vget_low_u8(u2));
482+
u5 = veor_u8(u5, vget_low_u8(u3));
483+
484+
vst1_lane_u8(output_row, u4, 0);
485+
if (loopRemainder > 4) {
486+
vst1_lane_u8(output_row + 1, u5, 0);
487+
}
488+
}
489+
}
490+
}
491+
} // for each row
492+
}
493+
260494
template <typename OutputType>
261495
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
262496
const std::uint8_t* input,
@@ -372,6 +606,24 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
372606
// clang-format on
373607
#undef INSTANTIATE_QuantizationNeonFunctions8Bits
374608

609+
#define INSTANTIATE_QuantizationNeonFunctionsNBits(type, bit_rate) \
610+
template void \
611+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<type, bit_rate>( \
612+
const type* input, \
613+
size_t input_rows, \
614+
int input_columns, \
615+
std::uint8_t* output);
616+
617+
// clang-format off
618+
INSTANTIATE_QuantizationNeonFunctionsNBits(float, 2)
619+
INSTANTIATE_QuantizationNeonFunctionsNBits(float, 4)
620+
INSTANTIATE_QuantizationNeonFunctionsNBits(float, 8)
621+
INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 2)
622+
INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 4)
623+
INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 8)
624+
// clang-format on
625+
#undef INSTANTIATE_QuantizationNeonFunctionsNBits
626+
375627
#endif // HAVE_SVE
376628

377629
} // namespace fbgemm

0 commit comments

Comments
 (0)