Skip to content

Commit cdd446a

Browse files
mcfi authored and facebook-github-bot committed
Fix Arm64 OSS pytorch build with FBGEMM (#4775)
Summary: X-link: facebookresearch/FBGEMM#1796 X-link: pytorch/pytorch#161527 Pull Request resolved: #4775 Without this change, Arm64 OSS pytorch build with FBGEMM failed with the following error. Undefined symbols for architecture arm64: "fbgemm::FindMinMax(float const*, float*, float*, long long)", referenced from: at::native::fbgemm_linear_int8_weight_fp32_activation(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&, at::Tensor const&) in QuantizedLinear.cpp.o at::native::fbgemm_linear_quantize_weight(at::Tensor const&) in QuantizedLinear.cpp.o PackedConvWeight<2>::apply_dynamic(at::Tensor const&, bool) in qconv_dynamic.cpp.o PackedConvWeight<3>::apply_dynamic(at::Tensor const&, bool) in qconv_dynamic.cpp.o at::Tensor PackedLinearWeight::apply_dynamic_impl<false>(at::Tensor, bool) in qlinear_dynamic.cpp.o at::Tensor PackedLinearWeight::apply_dynamic_impl<true>(at::Tensor, bool) in qlinear_dynamic.cpp.o ld: symbol(s) not found for architecture arm64 This change fixed the issue by moving FindMinMax's implementation from QuantUtilsAvx2.cc to QuantUtils.cc. FindMinMax is a platform-agnostic function with AVX2-specific optimizations so conceptually it can be put in QuantUtils.cc. Reviewed By: q10 Differential Revision: D81052327 fbshipit-source-id: c50ac43329d939433fcf6a1610cbbe5726dc6f6e
1 parent a56882d commit cdd446a

File tree

2 files changed

+42
-37
lines changed

2 files changed

+42
-37
lines changed

src/QuantUtils.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
#include "fbgemm/FloatConversion.h"
2020
#include "fbgemm/Types.h"
2121

22+
#if defined(__x86_64__) || defined(__i386__) || \
23+
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
24+
#include <immintrin.h>
25+
#endif
26+
2227
namespace fbgemm {
2328

2429
using namespace std;
@@ -196,6 +201,43 @@ void ChooseRequantizationMultiplier(
196201
////////////////////////////////////////////////////////////////////////////////
197202
// Utility functions
198203

204+
void FindMinMax(const float* m, float* min, float* max, int64_t len) {
205+
if (len <= 0) {
206+
*min = 0.0f;
207+
*max = 0.0f;
208+
return;
209+
}
210+
211+
float temp_min = *m, temp_max = *m;
212+
int64_t i = 0;
213+
214+
#ifdef __AVX__
215+
__m256 min_v = _mm256_set1_ps(*m), max_v = _mm256_set1_ps(*m);
216+
constexpr int VLEN = 8;
217+
if (len >= VLEN) {
218+
for (; i < len / VLEN * VLEN; i += VLEN) {
219+
min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(m + i));
220+
max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(m + i));
221+
}
222+
223+
float min_buf[VLEN], max_buf[VLEN];
224+
_mm256_storeu_ps(min_buf, min_v);
225+
_mm256_storeu_ps(max_buf, max_v);
226+
for (int j = 0; j < VLEN; ++j) {
227+
temp_min = std::min(temp_min, min_buf[j]);
228+
temp_max = std::max(temp_max, max_buf[j]);
229+
}
230+
}
231+
#endif
232+
233+
for (; i < len; i++) {
234+
temp_min = std::min(temp_min, m[i]);
235+
temp_max = std::max(temp_max, m[i]);
236+
}
237+
*min = temp_min;
238+
*max = temp_max;
239+
}
240+
199241
#define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY) \
200242
template <> \
201243
FBGEMM_API void Quantize<T, LEGACY>( \

src/QuantUtilsAvx2.cc

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -278,43 +278,6 @@ SPECIALIZE_FUSEDDQAVX2(int8_t)
278278

279279
#undef SPECIALIZE_FUSEDDQAVX2
280280

281-
void FindMinMax(const float* m, float* min, float* max, int64_t len) {
282-
if (len <= 0) {
283-
*min = 0.0f;
284-
*max = 0.0f;
285-
return;
286-
}
287-
288-
float temp_min = *m, temp_max = *m;
289-
int64_t i = 0;
290-
291-
#ifdef __AVX__
292-
__m256 min_v = _mm256_set1_ps(*m), max_v = _mm256_set1_ps(*m);
293-
constexpr int VLEN = 8;
294-
if (len >= VLEN) {
295-
for (; i < len / VLEN * VLEN; i += VLEN) {
296-
min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(m + i));
297-
max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(m + i));
298-
}
299-
300-
float min_buf[VLEN], max_buf[VLEN];
301-
_mm256_storeu_ps(min_buf, min_v);
302-
_mm256_storeu_ps(max_buf, max_v);
303-
for (int j = 0; j < VLEN; ++j) {
304-
temp_min = std::min(temp_min, min_buf[j]);
305-
temp_max = std::max(temp_max, max_buf[j]);
306-
}
307-
}
308-
#endif
309-
310-
for (; i < len; i++) {
311-
temp_min = std::min(temp_min, m[i]);
312-
temp_max = std::max(temp_max, m[i]);
313-
}
314-
*min = temp_min;
315-
*max = temp_max;
316-
}
317-
318281
////////////////////////////////////////////////////////////////////////////////
319282
// Requantization (with floats)
320283

0 commit comments

Comments
 (0)