
Commit 69cf6b7

NicolasHug authored and facebook-github-bot committed
Port quantize_val and dequantize_val into torchvision to avoid at::native and android xplat incompatibility
Port quantize_val and dequantize_val into torchvision to avoid at::native and android xplat incompatibility

Summary:
This diff ports `quantize_val` and `dequantize_val` from at::native to torchvision, because native kernels are incompatible with android xplat builds (see D30234056). This should only be temporary, until we find a way to move those functions out of at::native, or until the at::native / android incompatibility disappears.

Reviewed By: fmassa

Differential Revision: D30393619

fbshipit-source-id: 18b7b1b349ad9a24088a120e23da7535f7fa7ddc
1 parent 278c6ae commit 69cf6b7

1 file changed: +89 −7 lines


torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp

Lines changed: 89 additions & 7 deletions
@@ -1,5 +1,4 @@
 #include <ATen/ATen.h>
-#include <ATen/native/quantized/affine_quantizer.h>
 #include <torch/library.h>
 
 #include "../../cpu/roi_align_common.h"
@@ -9,6 +8,90 @@ namespace ops {
 
 namespace {
 
+// BEGIN copy-pasted code from pytorch core
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
+// We're vendoring the quantize_val() and dequantize_val() functions here. The
+// reason is that these functions belong in at::native, which is incompatible
+// with android xplat support.
+
+// FIXME: Remove this section once we can use at::native for android xplat
+// builds, or when quantize_val() and dequantize_val() aren't in at::native
+
+#ifdef USE_FBGEMM
+template <typename T>
+T quantize_val(double scale, int64_t zero_point, float value) {
+  // Internally, fbgemm::Quantize uses std::nearbyint.
+  // std::nearbyint results in nearest integer value according to the current
+  // rounding mode and the default rounding mode is rounds to even in half-way
+  // cases in most popular processor architectures like x86 and ARM. This is
+  // typically faster than an alternatives like std::round that rounds half-way
+  // cases away from zero, and can be consistent with SIMD implementations for
+  // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with
+  // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  int32_t qvalue;
+  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
+  qvalue = fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
+      value,
+      static_cast<int32_t>(zero_point),
+      static_cast<float>(scale),
+      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
+  return static_cast<T>(qvalue);
+}
+
+template <typename T>
+inline float dequantize_val(double scale, int64_t zero_point, T value) {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  fbgemm::TensorQuantizationParams qparams;
+  qparams.scale = static_cast<float>(scale);
+  qparams.zero_point = static_cast<int32_t>(zero_point);
+  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
+}
+#else // USE_FBGEMM
+
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
+template <class T>
+inline float Round(const float x) {
+  return ::nearbyintf(x);
+}
+inline double Round(const double x) {
+  return ::nearbyint(x);
+}
+#else
+template <class T>
+inline T Round(const T x) {
+  return std::nearbyint(x);
+}
+#endif
+
+template <typename T>
+T quantize_val(double scale, int64_t zero_point, float value) {
+  // std::nearbyint results in nearest integer value according to the current
+  // rounding mode and the default rounding mode is rounds to even in half-way
+  // cases in most popular processor architectures like x86 and ARM. This is
+  // typically faster than an alternatives like std::round that rounds half-way
+  // cases away from zero, and can be consistent with SIMD implementations for
+  // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with
+  // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode.
+  int64_t qvalue;
+  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
+  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
+  float inv_scale = 1.0f / static_cast<float>(scale);
+  qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
+  qvalue = std::max<int64_t>(qvalue, qmin);
+  qvalue = std::min<int64_t>(qvalue, qmax);
+  return static_cast<T>(qvalue);
+}
+
+template <typename T>
+float dequantize_val(double scale, int64_t zero_point, T value) {
+  // We need to convert the qint8 value to float to ensure the subtraction
+  // subexpression returns a float
+  return (static_cast<float>(value.val_) - zero_point) * scale;
+}
+#endif // USE_FBGEMM
+// END copy-pasted code from pytorch core
+
 template <typename T>
 void qroi_align_forward_kernel_impl(
     int n_rois,
@@ -46,19 +129,19 @@ void qroi_align_forward_kernel_impl(
     // Do not using rounding; this implementation detail is critical
     float offset = aligned ? 0.5 : 0.;
     float roi_start_w =
-        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[1]) *
+        dequantize_val(rois_scale, rois_zp, offset_rois[1]) *
             spatial_scale -
         offset;
     float roi_start_h =
-        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[2]) *
+        dequantize_val(rois_scale, rois_zp, offset_rois[2]) *
             spatial_scale -
         offset;
     float roi_end_w =
-        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[3]) *
+        dequantize_val(rois_scale, rois_zp, offset_rois[3]) *
             spatial_scale -
         offset;
     float roi_end_h =
-        at::native::dequantize_val(rois_scale, rois_zp, offset_rois[4]) *
+        dequantize_val(rois_scale, rois_zp, offset_rois[4]) *
             spatial_scale -
         offset;
 
@@ -134,8 +217,7 @@ void qroi_align_forward_kernel_impl(
 
           output_val /= count; // Average pooling
 
-          output[index] =
-              at::native::quantize_val<T>(input_scale, input_zp, output_val);
+          output[index] = quantize_val<T>(input_scale, input_zp, output_val);
         } // for pw
       } // for ph
     } // for c
