#include <ATen/ATen.h>
#include <ATen/native/quantized/affine_quantizer.h>
#include <torch/library.h>

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <limits>

#include "../../cpu/roi_align_common.h"
@@ -9,6 +8,90 @@ namespace ops {
98
109namespace {
1110
// BEGIN copy-pasted code from pytorch core
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
// We're vendoring the quantize_val() and dequantize_val() functions here. The
// reason is that these functions belong in at::native, which is incompatible
// with android xplat support.

// FIXME: Remove this section once we can use at::native for android xplat
// builds, or when quantize_val() and dequantize_val() aren't in at::native

#ifdef USE_FBGEMM
// Quantize a float `value` into the quantized type T (e.g. c10::qint8),
// delegating to fbgemm so results match fbgemm-accelerated kernels.
//
// Internally, fbgemm::Quantize uses std::nearbyint.
// std::nearbyint results in the nearest integer value according to the current
// rounding mode and the default rounding mode rounds to even in half-way
// cases in most popular processor architectures like x86 and ARM. This is
// typically faster than alternatives like std::round that rounds half-way
// cases away from zero, and can be consistent with SIMD implementations for
// example in x86 using _mm512_cvtps_epi32 or _mm512_round_ps with
// _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int32_t qvalue;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  qvalue = fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
      value,
      static_cast<int32_t>(zero_point),
      static_cast<float>(scale),
      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
  return static_cast<T>(qvalue);
}

// Dequantize a quantized value back to float via fbgemm:
// (value - zero_point) * scale.
template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  fbgemm::TensorQuantizationParams qparams;
  qparams.scale = static_cast<float>(scale);
  qparams.zero_point = static_cast<int32_t>(zero_point);
  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
#else // USE_FBGEMM

#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
// Old Android NDKs lack std::nearbyint; fall back to the C library versions.
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif

// Quantize a float `value` into the quantized type T (e.g. c10::qint8):
// round(value / scale) + zero_point, clamped to T's representable range.
//
// std::nearbyint results in the nearest integer value according to the current
// rounding mode and the default rounding mode rounds to even in half-way
// cases in most popular processor architectures like x86 and ARM. This is
// typically faster than alternatives like std::round that rounds half-way
// cases away from zero, and can be consistent with SIMD implementations for
// example in x86 using _mm512_cvtps_epi32 or _mm512_round_ps with
// _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  int64_t qvalue;
  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
  float inv_scale = 1.0f / static_cast<float>(scale);
  qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
  qvalue = std::max<int64_t>(qvalue, qmin);
  qvalue = std::min<int64_t>(qvalue, qmax);
  return static_cast<T>(qvalue);
}

// Dequantize a quantized value back to float: (value - zero_point) * scale.
template <typename T>
float dequantize_val(double scale, int64_t zero_point, T value) {
  // We need to convert the qint8 value to float to ensure the subtraction
  // subexpression returns a float
  return (static_cast<float>(value.val_) - zero_point) * scale;
}
#endif // USE_FBGEMM
// END copy-pasted code from pytorch core

1295template <typename T>
1396void qroi_align_forward_kernel_impl (
1497 int n_rois,
@@ -46,19 +129,19 @@ void qroi_align_forward_kernel_impl(
46129 // Do not using rounding; this implementation detail is critical
47130 float offset = aligned ? 0.5 : 0 .;
48131 float roi_start_w =
49- at::native:: dequantize_val (rois_scale, rois_zp, offset_rois[1 ]) *
132+ dequantize_val (rois_scale, rois_zp, offset_rois[1 ]) *
50133 spatial_scale -
51134 offset;
52135 float roi_start_h =
53- at::native:: dequantize_val (rois_scale, rois_zp, offset_rois[2 ]) *
136+ dequantize_val (rois_scale, rois_zp, offset_rois[2 ]) *
54137 spatial_scale -
55138 offset;
56139 float roi_end_w =
57- at::native:: dequantize_val (rois_scale, rois_zp, offset_rois[3 ]) *
140+ dequantize_val (rois_scale, rois_zp, offset_rois[3 ]) *
58141 spatial_scale -
59142 offset;
60143 float roi_end_h =
61- at::native:: dequantize_val (rois_scale, rois_zp, offset_rois[4 ]) *
144+ dequantize_val (rois_scale, rois_zp, offset_rois[4 ]) *
62145 spatial_scale -
63146 offset;
64147
@@ -134,8 +217,7 @@ void qroi_align_forward_kernel_impl(
134217
135218 output_val /= count; // Average pooling
136219
137- output[index] =
138- at::native::quantize_val<T>(input_scale, input_zp, output_val);
220+ output[index] = quantize_val<T>(input_scale, input_zp, output_val);
139221 } // for pw
140222 } // for ph
141223 } // for c
0 commit comments