diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2ceabf757f..3dd3d09e30 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,7 +53,7 @@ endfunction()
 project(
   fbgemm
   VERSION 1.4.0
-  LANGUAGES CXX)
+  LANGUAGES C CXX)
 
 # Add C++ compiler flag detection
 include(CheckCXXCompilerFlag)
diff --git a/defs.bzl b/defs.bzl
index 5b3d15d34a..ba92928e0b 100644
--- a/defs.bzl
+++ b/defs.bzl
@@ -33,6 +33,7 @@ def get_fbgemm_base_srcs():
 
 def get_fbgemm_generic_srcs(with_base = False, msvc = False, buck = False):
     sources = [
+        "src/CodeStorage.cc",
         "src/EmbeddingSpMDM.cc",
         "src/EmbeddingSpMDMNBit.cc",
         "src/ExecuteKernel.cc",
diff --git a/external/asmjit b/external/asmjit
index a3199e8857..28295814dd 160000
--- a/external/asmjit
+++ b/external/asmjit
@@ -1 +1 @@
-Subproject commit a3199e8857792cd10b7589ff5d58343d2c9008ea
+Subproject commit 28295814dd126459711de4bc20b5d1502b000d4e
diff --git a/src/CodeGenHelpers.h b/src/CodeGenHelpers.h
index df88a7bc5f..f4cfff36ad 100644
--- a/src/CodeGenHelpers.h
+++ b/src/CodeGenHelpers.h
@@ -15,61 +15,90 @@ namespace fbgemm {
 
 namespace x86 = asmjit::x86;
 
+class FileLoggerWithClose : public asmjit::FileLogger {
+ public:
+  explicit FileLoggerWithClose(FILE* f)
+      : FileLogger(f) {}
+
+  ~FileLoggerWithClose() noexcept override {
+    if (_file) {
+      fclose(_file);
+    }
+  }
+};
+
+/**
+ * @brief Emit an instruction that fills a vector register with
+ *        all ones.
+ *
+ * @param dest Destination register to fill
+ */
+template <inst_set_t instSet>
+void emitVecFillWithOnes(x86::Emitter* a, const x86::Vec& dest) {
+  if constexpr (instSet == inst_set_t::avx2) {
+    a->vpcmpeqb(dest, dest, dest);
+  } else {
+    a->vpternlogd(dest, dest, dest, 0xff);
+  }
+}
+
+/**
+ * @brief Create instruction sequence to generate 8-bit 1s
+ *
+ * @param dest Once the instruction sequence is executed,
+ *             dest[0:7] will have 0x01, dest[8:15]
+ *             will have 0x01 and so on
+ */
+template <inst_set_t instSet>
+void gen8BitVectorOne(x86::Emitter* a, const x86::Vec& dest) {
+  emitVecFillWithOnes<instSet>(a, dest);
+  a->vpabsb(dest, dest);
+}
+
 /**
  * @brief Create instruction sequence to generate 16-bit 1s
- * @tparam T Register type of destination, e.g., Ymm or Zmm
  *
  * @param dest Once the instruction sequence is executed,
  *             dest[0:15] will have 0x0001, dest[16:31]
  *             will have 0x0001 and so on
  */
-template <
-    inst_set_t instSet,
-    typename T,
-    std::enable_if_t<instSet == inst_set_t::avx2, int> = 0>
-void gen16BitVectorOne(x86::Emitter* a, T dest) {
-  a->vpcmpeqw(dest, dest, dest);
-  a->vpsrlw(dest, dest, 15);
-}
-
-template <
-    inst_set_t instSet,
-    typename T,
-    std::enable_if_t<
-        instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
-            instSet == inst_set_t::avx512_vnni ||
-            instSet == inst_set_t::avx512_vnni_ymm,
-        int> = 0>
-void gen16BitVectorOne(x86::Emitter* a, T dest) {
-  a->vpternlogd(dest, dest, dest, 0xff);
+template <inst_set_t instSet>
+void gen16BitVectorOne(x86::Emitter* a, const x86::Vec& dest) {
+  emitVecFillWithOnes<instSet>(a, dest);
   a->vpsrlw(dest, dest, 15);
 }
 
 /**
  * @brief Emit instruction do load 32-bit integer.
AVX512 has - * different instrunction to load registers with index >= 16 - * @tparam T Register type of destination, e.g., Ymm or Zmm + * different instruction to load registers with index >= 16 * * @param dest Destination vector register */ -template < - inst_set_t instSet, - typename T, - std::enable_if_t = 0> -void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) { - a->vmovdqa(dest, ptr); +template +void emitVecMove(x86::Emitter* a, const Dest& dest, const Src& src) { + if constexpr (instSet == inst_set_t::avx2) { + a->vmovdqa(dest, src); + } else { + a->vmovdqa32(dest, src); + } } -template < - inst_set_t instSet, - typename T, - std::enable_if_t< - instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || - instSet == inst_set_t::avx512_vnni || - instSet == inst_set_t::avx512_vnni_ymm, - int> = 0> -void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) { - a->vmovdqa32(dest, ptr); +template +void emitVecOr(x86::Emitter* a, const Dest& dest, const Src1& src1, const Src2& src2) { + if constexpr (instSet == inst_set_t::avx2) { + a->vpor(dest, src1, src2); + } else { + a->vpord(dest, src1, src2); + } +} + +template +void emitVecXor(x86::Emitter* a, const Dest& dest, const Src1& src1, const Src2& src2) { + if constexpr (instSet == inst_set_t::avx2) { + a->vpxor(dest, src1, src2); + } else { + a->vpxord(dest, src1, src2); + } } /** @@ -81,68 +109,22 @@ void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) { * @param vec Source (full) vector register * @param idx Index of of the half vector 0 or 1 */ -template < - inst_set_t instSet, - typename T, - std::enable_if_t< - instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || - instSet == inst_set_t::avx512_vnni || - instSet == inst_set_t::avx512_vnni_ymm, - int> = 0> -void emitExtractHalfVector( - x86::Emitter* a, - const Ymm& half, - const Zmm& vec, - int idx) { - a->vextracti32x8(half, vec, idx); -} -template < - inst_set_t instSet, - typename T, - std::enable_if_t< - instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm || - instSet == inst_set_t::avx512_vnni || - instSet == inst_set_t::avx512_vnni_ymm, - int> = 0> +template void emitExtractHalfVector( x86::Emitter* a, - const Xmm& half, - const Ymm& vec, + const x86::Vec& half, + const x86::Vec& vec, int idx) { - a->vextracti32x4(half, vec, idx); -} - -template < - inst_set_t instSet, - typename T, - std::enable_if_t = 0> -void emitExtractHalfVector( - x86::Emitter* a, - const Xmm& half, - const Ymm& vec, - int idx) { - a->vextracti128(half, vec, idx); -} - -/** - * @brief Create instruction sequence to generate 8-bit 1s - * @tparam T Register type of destination, e.g., Ymm or Zmm - * - * @param dest Once the instruction sequence is executed, - * dest[0:7] will have 0x01, dest[8:15] - * will have 0x01 and so on - */ -template , int> = 0> -void gen8BitVectorOne(x86::Emitter* a, T dest) { - a->vpcmpeqw(dest, dest, dest); - a->vpabsb(dest, dest); -} - -template , int> = 0> -void gen8BitVectorOne(x86::Emitter* a, T dest) { - a->vpternlogd(dest, dest, dest, 0xff); - a->vpabsb(dest, dest); + if constexpr (instSet == inst_set_t::avx2) { + a->vextracti128(half, vec, idx); + } else { + if (vec.is_vec512()) { + a->vextracti32x8(half, vec, idx); + } else { + a->vextracti32x4(half, vec, idx); + } + } } /** @@ -160,11 +142,11 @@ template < int> = 0> void genU8I8S32FMA( x86::Emitter* a, - typename simd_info::vec_reg_t aReg, - typename simd_info::vec_reg_t bReg, - typename simd_info::vec_reg_t cReg, - typename simd_info::vec_reg_t 
oneReg16Bit,
-    typename simd_info<instSet>::vec_reg_t tmpReg) {
+    const x86::Vec& aReg,
+    const x86::Vec& bReg,
+    const x86::Vec& cReg,
+    const x86::Vec& oneReg16Bit,
+    const x86::Vec& tmpReg) {
   a->vpmaddubsw(tmpReg, aReg, bReg);
   a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
   a->vpaddd(cReg, tmpReg, cReg);
@@ -175,11 +157,11 @@ template <
     std::enable_if_t = 0>
 void genU8I8S32FMA(
     x86::Emitter* a,
-    typename simd_info<instSet>::vec_reg_t aReg,
-    typename simd_info<instSet>::vec_reg_t bReg,
-    typename simd_info<instSet>::vec_reg_t cReg,
-    typename simd_info<instSet>::vec_reg_t /*oneReg16Bit*/,
-    typename simd_info<instSet>::vec_reg_t /*tmpReg*/) {
+    const x86::Vec& aReg,
+    const x86::Vec& bReg,
+    const x86::Vec& cReg,
+    const x86::Vec& /*oneReg16Bit*/,
+    const x86::Vec& /*tmpReg*/) {
   a->vpdpbusd(cReg, aReg, bReg);
 }
 
@@ -200,11 +182,11 @@ template <
     int> = 0>
 void genU8Sum4(
     x86::Emitter* a,
-    typename simd_info<instSet>::vec_reg_t src,
-    typename simd_info<instSet>::vec_reg_t dest,
-    typename simd_info<instSet>::vec_reg_t oneReg16Bit,
-    typename simd_info<instSet>::vec_reg_t tmpReg) {
-  gen8BitVectorOne<instSet>(a, tmpReg);
+    const x86::Vec& src,
+    const x86::Vec& dest,
+    const x86::Vec& oneReg16Bit,
+    const x86::Vec& tmpReg) {
+  gen8BitVectorOne<instSet>(a, tmpReg);
   a->vpmaddubsw(tmpReg, src, tmpReg);
   a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
   a->vpaddd(dest, tmpReg, dest);
@@ -220,11 +202,11 @@ template <
     std::enable_if_t = 0>
 void genU8Sum4(
     x86::Emitter* a,
-    typename simd_info<instSet>::vec_reg_t src,
-    typename simd_info<instSet>::vec_reg_t dest,
-    typename simd_info<instSet>::vec_reg_t /*oneReg16Bit*/,
-    typename simd_info<instSet>::vec_reg_t tmpReg) {
-  gen8BitVectorOne<instSet>(a, tmpReg);
+    const x86::Vec& src,
+    const x86::Vec& dest,
+    const x86::Vec& /*oneReg16Bit*/,
+    const x86::Vec& tmpReg) {
+  gen8BitVectorOne<instSet>(a, tmpReg);
   a->vpdpbusd(dest, src, tmpReg);
 }
 
@@ -263,8 +245,32 @@ template <typename T>
 void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
   // move src to dest
   auto xmm = dest.xmm();
-  a->movq(xmm, src);
+  a->vmovq(xmm, src);
   a->vpbroadcastb(dest, xmm);
 }
 
+template <inst_set_t instSet>
+void emitReduceAddF32(x86::Emitter* a, x86::Vec v, const x86::Vec& tmp) {
+  if constexpr (instSet != inst_set_t::avx2) {
+    if (v.is_vec512()) {
+      a->vextractf32x8(tmp.ymm(), v, 1);
+      v = v.ymm();
+      a->vaddps(v.ymm(), v, tmp);
+    }
+  }
+
+  if (v.is_vec256()) {
+    if constexpr (instSet == inst_set_t::avx2) {
+      a->vextractf128(tmp.xmm(), v, 1);
+    } else {
+      a->vextractf32x4(tmp.xmm(), v, 1);
+    }
+    v = v.xmm();
+    a->vaddps(v, v, tmp.xmm());
+  }
+
+  a->vhaddps(v, v, v);
+  a->vhaddps(v, v, v);
+}
+
 } // namespace fbgemm
diff --git a/src/CodeStorage.cc b/src/CodeStorage.cc
new file mode 100644
index 0000000000..5c8b6d92aa
--- /dev/null
+++ b/src/CodeStorage.cc
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "./CodeStorage.h"
+
+namespace fbgemm {
+namespace CodeStorage {
+
+asmjit::JitRuntime& getRuntime() {
+  // JIT Runtime for asmjit, depends on other static variables.
+  // Required to prevent initialization order fiasco
+  static asmjit::JitRuntime rt;
+
+  return rt;
+}
+
+} // namespace CodeStorage
+} // namespace fbgemm
diff --git a/src/CodeStorage.h b/src/CodeStorage.h
new file mode 100644
index 0000000000..f0f7b2757d
--- /dev/null
+++ b/src/CodeStorage.h
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <asmjit/asmjit.h> // @manual
+
+namespace fbgemm {
+namespace CodeStorage {
+
+asmjit::JitRuntime& getRuntime();
+
+} // namespace CodeStorage
+} // namespace fbgemm
diff --git a/src/DirectConv.h b/src/DirectConv.h
index 4b4ac941c0..a0bb100e77 100644
--- a/src/DirectConv.h
+++ b/src/DirectConv.h
@@ -53,8 +53,6 @@ class DirectConvCodeGenBase {
       int o1Xoc,
       int i1);
 
-  inline static std::mutex rtMutex_; ///< Control access to runtime;
-
   // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
   inline static CodeCache<
       std::tuple<bool, int, int, int, int, int, int>,
@@ -194,15 +192,6 @@ class DirectConvCodeGenBase {
       const x86::Gp& C_offset,
       int rowRegs,
       int colRegs);
-
- private:
-  static asmjit::JitRuntime& runtime() {
-    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
-                                  // depents on other static
-                                  // variables. Required to prevent
-                                  // initialization order fiasco
-    return rt;
-  }
 };
 
 } // namespace fbgemm
diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc
index dc86ff221f..ccd2974ae6 100644
--- a/src/EmbeddingSpMDM.cc
+++ b/src/EmbeddingSpMDM.cc
@@ -17,6 +17,7 @@
 #include
 #include
 #include "./CodeCache.h" // @manual
+#include "./CodeStorage.h" // @manual
 #include "./EmbeddingSpMDMAutovec.h" // @manual
 #include "./MaskAvx2.h" // @manual
 #include "./RefImplementations.h" // @manual
@@ -109,16 +110,6 @@ class GenEmbeddingSpMDMLookup {
       bool is_bf16_in);
 
  private:
-  static asmjit::JitRuntime& runtime() {
-    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
-                                  // depents on other static
-                                  // variables. Required to prevent
-                                  // initialization order fiasco
-    return rt;
-  }
-
-  inline static std::mutex rtMutex_; ///< Control access to runtime;
-
   // The hash depends on embedding dimension (block size), weighted sls,
   // positional weights, normalize by lengths, prefetch distance, use_offsets,
   // output_stride, input_stride, and scale_bias_last
@@ -195,12 +186,14 @@ GenEmbeddingSpMDMLookup<
   bool is_fp16_in = is_16bit_in && !is_bf16_in;
   bool is_fp16_out = is_16bit_out && !is_bf16_out;
 
+  asmjit::JitRuntime& runtime = CodeStorage::getRuntime();
+
   // TODO: Make this tunable
   int pref_dist = prefetch;
   constexpr bool areIndices64b = std::is_same_v<indxType, int64_t>;
 
   asmjit::CodeHolder code;
-  code.init(runtime().environment());
+  code.init(runtime.environment());
   x86::Assembler assembler(&code);
   x86::Emitter* a = assembler.as<x86::Emitter>();
 #if defined(FBGEMM_LOG_CODE)
@@ -241,15 +234,15 @@ GenEmbeddingSpMDMLookup<
   }
   filename += ".txt";
   FILE* codeLogFile = fopen(filename.c_str(), "w");
-  asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
-  code.setLogger(codeLogger);
+  auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogFile);
+  code.set_logger(codeLogger.get());
 #endif
   // arguments to the function created
-  x86::Gp output_size = a->zdi();
+  x86::Gp output_size = x86::rdi;
   // index_size will be overwritten to hold the end address of indices
-  x86::Gp index_size = a->zsi();
-  x86::Gp data_size = a->zdx();
-  x86::Gp input = a->zcx();
+  x86::Gp index_size = x86::rsi;
+  x86::Gp data_size = x86::rdx;
+  x86::Gp input = x86::rcx;
   int reg_id = 8;
   x86::Gp indices = a->gpz(reg_id); // 8
   ++reg_id;
@@ -309,28 +302,28 @@ GenEmbeddingSpMDMLookup<
   frame.init(func);
 
   if constexpr (instSet == inst_set_t::avx2) {
-    frame.setDirtyRegs(
+    frame.set_dirty_regs(
         asmjit::RegGroup::kVec,
-        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
-            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); } else { - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); } - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kGp, reg_id == 15 - ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) - : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + ? asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) + : asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); if constexpr (ROWWISE_SPARSE) { - args.assignAll( + args.assign_all( output_size, index_size, data_size, @@ -342,7 +335,7 @@ GenEmbeddingSpMDMLookup< compressed_indices_table, scratchReg1_); } else { - args.assignAll( + args.assign_all( output_size, index_size, data_size, @@ -354,11 +347,11 @@ GenEmbeddingSpMDMLookup< scratchReg1_); } - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; constexpr int NUM_VEC_REG = simd_info::NUM_VEC_REGS; @@ -375,8 +368,8 @@ GenEmbeddingSpMDMLookup< vec_reg_t vlen_inv_vreg; // used for normalize by lengths -- 1/ lengths[i] vec_reg_t src_vreg; // for holding embedding value temporarily - Ymm mask_vreg; // mask for avx2 - Xmm mask_fp16_vreg; // mask for loading fp16 in avx2 + x86::Vec mask_vreg_ymm; // mask for avx2 + x86::Vec mask_fp16_vreg_xmm; // mask for loading fp16 in avx2 vec_reg_t ones_vreg; // 2^15 for bf16_2_fp32_rn if constexpr (is_8bit_in) { @@ -409,10 +402,10 @@ GenEmbeddingSpMDMLookup< if (remainder && instSet == inst_set_t::avx2) { // AVX512 doesn't need to use vector register for masking --unroll_factor; - mask_vreg = x86::ymm(unroll_factor); + mask_vreg_ymm = x86::ymm(unroll_factor); if (remainder > 1 && (is_16bit_in || is_bf16_out || is_fp16_out)) { --unroll_factor; - mask_fp16_vreg = x86::xmm(unroll_factor); + mask_fp16_vreg_xmm = x86::xmm(unroll_factor); } } @@ -424,13 +417,13 @@ GenEmbeddingSpMDMLookup< if (remainder) { if constexpr (instSet == inst_set_t::avx2) { a->vmovups( - mask_vreg, + mask_vreg_ymm, x86::ymmword_ptr( scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); if (is_16bit_in || is_bf16_out || is_fp16_out) { if (remainder > 1) { a->vmovups( - mask_fp16_vreg, + mask_fp16_vreg_xmm, x86::xmmword_ptr( scratchReg1_, (vlen - remainder / 2) * sizeof(int32_t))); @@ -443,7 +436,7 @@ GenEmbeddingSpMDMLookup< } } else { a->mov(scratchReg1_, (1 << remainder) - 1); - a->kmovw(x86::k(1), scratchReg1_); + a->kmovw(x86::k1, scratchReg1_); } } @@ -451,10 +444,10 @@ GenEmbeddingSpMDMLookup< a->lea( index_size, x86::ptr(indices, index_size, areIndices64b ? 
3 : 2)); - asmjit::Label exit = a->newLabel(); - asmjit::Label error = a->newLabel(); - asmjit::Label LoopRangeIndexBegin = a->newLabel(); - asmjit::Label LoopRangeIndexEnd = a->newLabel(); + asmjit::Label exit = a->new_label(); + asmjit::Label error = a->new_label(); + asmjit::Label LoopRangeIndexBegin = a->new_label(); + asmjit::Label LoopRangeIndexEnd = a->new_label(); // rangeIndex loop begins (iterate output_size times) a->bind(LoopRangeIndexBegin); @@ -462,8 +455,8 @@ GenEmbeddingSpMDMLookup< a->jl(LoopRangeIndexEnd); if (normalize_by_lengths) { - asmjit::Label IfLengthsBegin = a->newLabel(); - asmjit::Label IfLengthsEnd = a->newLabel(); + asmjit::Label IfLengthsBegin = a->new_label(); + asmjit::Label IfLengthsEnd = a->new_label(); a->bind(IfLengthsBegin); if (use_offsets) { a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); @@ -480,13 +473,13 @@ GenEmbeddingSpMDMLookup< vec_reg_t temp_vreg(0); if constexpr (instSet == inst_set_t::avx2) { a->mov(scratchReg1_, 1); - a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_); - a->cvtsi2ss(temp_vreg.xmm(), lengths_R_); - a->divss(vlen_inv_vreg.xmm(), temp_vreg.xmm()); + a->vcvtsi2ss(vlen_inv_vreg.xmm(), vlen_inv_vreg.xmm(), scratchReg1_); + a->vcvtsi2ss(temp_vreg.xmm(), temp_vreg.xmm(), lengths_R_); + a->vdivss(vlen_inv_vreg.xmm(), vlen_inv_vreg.xmm(), temp_vreg.xmm()); a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm()); } else { // avx512 a->mov(scratchReg1_, 1); - a->cvtsi2ss(temp_vreg.xmm(), scratchReg1_); + a->vcvtsi2ss(temp_vreg.xmm(), temp_vreg.xmm(), scratchReg1_); a->vpbroadcastd(vlen_inv_vreg, temp_vreg.xmm()); a->vpbroadcastd(temp_vreg, lengths_R_); a->vcvtdq2ps(temp_vreg, temp_vreg); @@ -520,9 +513,9 @@ GenEmbeddingSpMDMLookup< a->cmp(scratchReg1_, index_size); a->jg(error); - asmjit::Label LoopDataIndexBegin = a->newLabel(); - asmjit::Label LoopDataIndexEnd = a->newLabel(); - asmjit::Label ValidIndexLabel = a->newLabel(); + asmjit::Label LoopDataIndexBegin = a->new_label(); + asmjit::Label LoopDataIndexEnd = a->new_label(); + asmjit::Label ValidIndexLabel = a->new_label(); // dataIndex loop begins (iterate lengths_R_ times) a->bind(LoopDataIndexBegin); @@ -539,14 +532,14 @@ GenEmbeddingSpMDMLookup< // When scale_bias_last == false, assume this is for table batched // embedding (TBE) that can get -1 for pruned rows. 
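As an editorial aside, a scalar sketch of what the pruned-row check emitted just below amounts to; the function and parameter names here are hypothetical, not part of this patch, and the JIT variant compares with 64- or 32-bit width depending on areIndices64b:

    // Pruned TBE rows are tagged with index -1; the generated code consumes
    // the index slot (and the weight slot, when weights are present) and
    // re-enters the data-index loop without touching the output row.
    template <typename IndexType>
    bool consume_if_pruned(
        const IndexType*& indices, const float*& weights, bool has_weight) {
      if (*indices != static_cast<IndexType>(-1)) {
        return false; // valid row: fall through to the gather/accumulate path
      }
      ++indices;   // a->add(indices, sizeof(indxType))
      if (has_weight) {
        ++weights; // a->add(weights, sizeof(float))
      }
      return true;
    }
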
if constexpr (areIndices64b) { - a->cmp(scratchReg1_, static_cast(-1)); + a->cmp(scratchReg1_, -1); } else { - a->cmp(scratchReg1_.r32(), static_cast(-1)); + a->cmp(scratchReg1_.r32(), -1); } a->jne(ValidIndexLabel); - a->add(indices, static_cast(sizeof(indxType))); + a->add(indices, sizeof(indxType)); if (has_weight) { - a->add(weights, static_cast(sizeof(float))); + a->add(weights, sizeof(float)); } a->jmp(LoopDataIndexBegin); a->bind(ValidIndexLabel); @@ -569,8 +562,8 @@ GenEmbeddingSpMDMLookup< int fused_block_size = input_stride * sizeof(inType); if (pref_dist) { - asmjit::Label pref_dist_reset_start = a->newLabel(); - asmjit::Label pref_dist_reset_end = a->newLabel(); + asmjit::Label pref_dist_reset_start = a->new_label(); + asmjit::Label pref_dist_reset_end = a->new_label(); // out of bound handling for prefetch a->lea( scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType))); @@ -601,8 +594,8 @@ GenEmbeddingSpMDMLookup< a->bind(pref_dist_reset_end); if constexpr (ROWWISE_SPARSE) { asmjit::Label rowwise_sparse_pref_corner_case_begin = - a->newLabel(); - asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel(); + a->new_label(); + asmjit::Label rowwise_sparse_pref_corner_case_end = a->new_label(); a->cmp(scratchReg2_, data_size); a->jae(rowwise_sparse_pref_corner_case_begin); @@ -621,22 +614,22 @@ GenEmbeddingSpMDMLookup< a->xor_(scratchReg2_.r32(), scratchReg2_.r32()); a->bind(rowwise_sparse_pref_corner_case_end); } - a->imul(scratchReg2_, static_cast(fused_block_size)); + a->imul(scratchReg2_, fused_block_size); } - a->add(indices, static_cast(sizeof(indxType))); + a->add(indices, sizeof(indxType)); if (has_weight) { a->vbroadcastss(w_vreg, x86::dword_ptr(weights)); - a->add(weights, static_cast(sizeof(float))); + a->add(weights, sizeof(float)); } if constexpr (ROWWISE_SPARSE) { - a->cmp(scratchReg1_.r32(), static_cast(-1)); + a->cmp(scratchReg1_.r32(), -1); a->je(LoopDataIndexBegin); } - a->imul(scratchReg1_, static_cast(fused_block_size)); + a->imul(scratchReg1_, fused_block_size); // broadcast the scale constexpr unsigned int CACHE_LINE_LEN = 64; @@ -694,7 +687,7 @@ GenEmbeddingSpMDMLookup< if constexpr (is_8bit_in) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && instSet == inst_set_t::avx512) { - a->k(x86::k(1)).z().vpmovzxbd(src_vreg, src_addr); + a->k(x86::k1).z().vpmovzxbd(src_vreg, src_addr); } else { // We don't use a mask for AVX2 since we can use the extra // "padding" of the 2 floats (= 8 chars) scale and bias @@ -708,7 +701,7 @@ GenEmbeddingSpMDMLookup< if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if constexpr (instSet == inst_set_t::avx2) { if (remainder % 2 == 0) { - a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); + a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg_xmm, src_addr); } else { a->vpbroadcastw( src_vreg.xmm(), @@ -724,10 +717,10 @@ GenEmbeddingSpMDMLookup< // First put broadcasted last 16-bit element a->vmovups(x86::xmmword_ptr(x86::rsp), src_vreg.xmm()); // Mask store the remaining 16-bit elements - a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); + a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg_xmm, src_addr); a->vmaskmovps( x86::xmmword_ptr(x86::rsp), - mask_fp16_vreg, + mask_fp16_vreg_xmm, src_vreg.xmm()); // Load combined 16-bit elements a->vmovups(src_vreg.xmm(), x86::xmmword_ptr(x86::rsp)); @@ -743,11 +736,11 @@ GenEmbeddingSpMDMLookup< } else { // avx512 if (is_fp16_in) { - a->k(x86::k(1)).z().vcvtph2ps(src_vreg, src_addr); + a->k(x86::k1).z().vcvtph2ps(src_vreg, src_addr); } else if 
(is_bf16_in) { // bf16 - a->k(x86::k(1)).z().vpmovzxwd(src_vreg, src_addr); - a->k(x86::k(1)).z().vpslld(src_vreg, src_vreg, 16); + a->k(x86::k1).z().vpmovzxwd(src_vreg, src_addr); + a->k(x86::k1).z().vpslld(src_vreg, src_vreg, 16); } } } else { @@ -769,14 +762,14 @@ GenEmbeddingSpMDMLookup< // This part for FP32 SLS if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && instSet == inst_set_t::avx2) { - a->vmaskmovps(src_vreg.ymm(), mask_vreg.ymm(), src_addr); + a->vmaskmovps(src_vreg.ymm(), mask_vreg_ymm.ymm(), src_addr); } if (has_weight) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if constexpr (instSet == inst_set_t::avx2) { a->vfmadd231ps(out_vreg, w_vreg, src_vreg); } else { - a->k(x86::k(1)).vfmadd231ps(out_vreg, w_vreg, src_addr); + a->k(x86::k1).vfmadd231ps(out_vreg, w_vreg, src_addr); } } else { a->vfmadd231ps(out_vreg, w_vreg, src_addr); @@ -786,7 +779,7 @@ GenEmbeddingSpMDMLookup< if constexpr (instSet == inst_set_t::avx2) { a->vaddps(out_vreg, out_vreg, src_vreg); } else { - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, src_addr); + a->k(x86::k1).vaddps(out_vreg, out_vreg, src_addr); } } else { a->vaddps(out_vreg, out_vreg, src_addr); @@ -820,9 +813,9 @@ GenEmbeddingSpMDMLookup< if constexpr (std::is_same_v) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if constexpr (instSet == inst_set_t::avx2) { - a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm()); + a->vmaskmovps(dst_addr, mask_vreg_ymm, out_vreg.ymm()); } else { - a->k(x86::k(1)).vmovups(dst_addr, out_vreg); + a->k(x86::k1).vmovups(dst_addr, out_vreg); } } else { a->vmovups(dst_addr, out_vreg); @@ -841,7 +834,7 @@ GenEmbeddingSpMDMLookup< } if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (remainder > 1) { - a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); + a->vmaskmovps(dst_addr, mask_fp16_vreg_xmm, out_vreg.xmm()); } if (remainder % 2 != 0) { a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm()); @@ -862,12 +855,12 @@ GenEmbeddingSpMDMLookup< } else { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (is_fp16_out) { - a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); + a->k(x86::k1).vcvtps2ph(dst_addr, out_vreg, 8); } else if (is_bf16_out) { // bf16 - a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); - a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); - a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); + a->k(x86::k1).vpaddd(out_vreg, out_vreg, ones_vreg); + a->k(x86::k1).vpsrld(out_vreg, out_vreg, 16); + a->k(x86::k1).vpmovdw(dst_addr, out_vreg); } } else { if (is_fp16_out) { @@ -895,30 +888,22 @@ GenEmbeddingSpMDMLookup< } if (has_weight) { - a->imul( - scratchReg1_, - lengths_R_, - static_cast(sizeof(float))); + a->imul(scratchReg1_, lengths_R_, sizeof(float)); a->sub(weights, scratchReg1_); if (vec_idx + unroll_factor < num_vec_regs_per_block) { - a->imul( - scratchReg1_, - static_cast(sizeof(indxType) / sizeof(float))); + a->imul(scratchReg1_, sizeof(indxType) / sizeof(float)); a->sub(indices, scratchReg1_); } } else { - a->imul( - scratchReg1_, - lengths_R_, - static_cast(sizeof(indxType))); + a->imul(scratchReg1_, lengths_R_, sizeof(indxType)); a->sub(indices, scratchReg1_); } } } - a->add(lengths, static_cast(sizeof(offsetType))); - a->add(out, static_cast(output_stride * sizeof(outType))); + a->add(lengths, sizeof(offsetType)); + a->add(out, output_stride * sizeof(outType)); a->jmp(LoopRangeIndexBegin); a->bind(LoopRangeIndexEnd); @@ -936,7 +921,7 @@ GenEmbeddingSpMDMLookup< a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * 
sizeof(int32_t))); } - a->emitEpilog(frame); + a->emit_epilog(frame); // jit_fused8bitembedding_kernel fn; typename ReturnFunctionSignature< @@ -945,21 +930,13 @@ GenEmbeddingSpMDMLookup< offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel fn; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogFile); - delete codeLogger; -#endif return fn; }); } diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index a732caa5f3..b19483eb28 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -16,6 +16,7 @@ #include #include #include "./CodeCache.h" // @manual +#include "./CodeStorage.h" // @manual #include "./EmbeddingSpMDMAutovec.h" // @manual #include "./MaskAvx2.h" // @manual #include "./RefImplementations.h" // @manual @@ -101,17 +102,6 @@ class GenEmbeddingSpMDMNBitLookup { bool scale_bias_last, bool is_bf16_out); - private: - static asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, - // depends on other static - // variables. Required to prevent - // initialization order fiasco - return rt; - } - - inline static mutex rtMutex_; ///< Control access to runtime; - // The hash depends on bit_rate, embedding dimension (block size), weighted // sls, positional weights, normalize by lengths, prefetch distance, // use_offsets, output_stride, input_stride, and scale_bias_last @@ -181,8 +171,10 @@ GenEmbeddingSpMDMNBitLookup< int pref_dist = prefetch; constexpr bool areIndices64b = is_same_v; + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); + asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); #if defined(FBGEMM_LOG_CODE) @@ -210,15 +202,15 @@ GenEmbeddingSpMDMNBitLookup< } filename += ".txt"; FILE* codeLogFile = fopen(filename.c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile); - code.setLogger(codeLogger); + auto codeLogger = std::make_unique(codeLogFile); + code.set_logger(codeLogger.get()); #endif // arguments to the function created - x86::Gp output_size = a->zdi(); + x86::Gp output_size = x86::rdi; // index_size will be overwritten to hold the end address of indices - x86::Gp index_size = a->zsi(); - x86::Gp data_size = a->zdx(); - x86::Gp input = a->zcx(); + x86::Gp index_size = x86::rsi; + x86::Gp data_size = x86::rdx; + x86::Gp input = x86::rcx; int reg_id = 8; x86::Gp indices = a->gpz(reg_id); // 8 ++reg_id; @@ -244,7 +236,7 @@ GenEmbeddingSpMDMNBitLookup< x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15 x86::Gp scratchReg3_; if constexpr (instSet == inst_set_t::avx2) { - scratchReg3_ = a->zax(); + scratchReg3_ = x86::rax; } asmjit::FuncDetail func; @@ -283,22 +275,22 @@ GenEmbeddingSpMDMNBitLookup< asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) 
| + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kGp, reg_id == 15 - ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) - : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + ? asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) + : asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); if constexpr (ROWWISE_SPARSE) { - args.assignAll( + args.assign_all( output_size, index_size, data_size, @@ -310,7 +302,7 @@ GenEmbeddingSpMDMNBitLookup< compressed_indices_table, scratchReg1_); } else { - args.assignAll( + args.assign_all( output_size, index_size, data_size, @@ -322,11 +314,11 @@ GenEmbeddingSpMDMNBitLookup< scratchReg1_); } - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; constexpr int NUM_VEC_REG = simd_info::NUM_VEC_REGS; @@ -353,9 +345,9 @@ GenEmbeddingSpMDMNBitLookup< vec_reg_t vlen_inv_vreg; // used for normalize by lengths -- 1/ lengths[i] vec_reg_t src_vreg; // for holding embedding value temporarily - Ymm mask_vreg; // mask for avx2 - Xmm mask2_vreg; - Xmm mask_fp16_vreg; + x86::Vec mask_vreg_ymm; // mask for avx2 + x86::Vec mask2_vreg_xmm; + x86::Vec mask_fp16_vreg_xmm; vec_reg_t ones_vreg; // We need 2 vec registers for 1. scale 2. bias @@ -405,10 +397,10 @@ GenEmbeddingSpMDMNBitLookup< if (remainder && instSet == inst_set_t::avx2) { // AVX512 doesn't need to use vector register for masking --unroll_factor; - mask_vreg = x86::ymm(unroll_factor); + mask_vreg_ymm = x86::ymm(unroll_factor); if (remainder > 1 && std::is_same_v) { --unroll_factor; - mask_fp16_vreg = x86::xmm(unroll_factor); + mask_fp16_vreg_xmm = x86::xmm(unroll_factor); } } @@ -416,7 +408,7 @@ GenEmbeddingSpMDMNBitLookup< if (remainder_32bit_granularity && instSet == inst_set_t::avx2) { // AVX512 doesn't need to use vector register for masking --unroll_factor; - mask2_vreg = x86::xmm(unroll_factor); + mask2_vreg_xmm = x86::xmm(unroll_factor); } if (normalize_by_lengths) { @@ -430,13 +422,13 @@ GenEmbeddingSpMDMNBitLookup< if (remainder) { if constexpr (instSet == inst_set_t::avx2) { a->vmovups( - mask_vreg, + mask_vreg_ymm, x86::ymmword_ptr( scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); if constexpr (std::is_same_v) { if (remainder > 1) { a->vmovups( - mask_fp16_vreg, + mask_fp16_vreg_xmm, x86::xmmword_ptr( scratchReg1_, (vlen - remainder / 2) * sizeof(int32_t))); @@ -449,7 +441,7 @@ GenEmbeddingSpMDMNBitLookup< } } else { a->mov(scratchReg1_, (1 << remainder) - 1); - a->kmovw(x86::k(1), scratchReg1_); + a->kmovw(x86::k1, scratchReg1_); } } @@ -465,14 +457,14 @@ GenEmbeddingSpMDMNBitLookup< for (int i = remainder_32bit_granularity; i < vlen / 2; i++) { a->mov(x86::dword_ptr(x86::rsp, i * sizeof(int32_t)), 0); } - a->vmovups(mask2_vreg, x86::dword_ptr(x86::rsp)); + a->vmovups(mask2_vreg_xmm, x86::xmmword_ptr(x86::rsp)); a->lea( x86::rsp, x86::dword_ptr( x86::rsp, (int32_t)((vlen / 2) * sizeof(int32_t)))); } else { a->mov(scratchReg1_, (1 << remainder_32bit_granularity) - 1); - a->kmovw(x86::k(2), scratchReg1_); + a->kmovw(x86::k2, scratchReg1_); } } @@ -480,10 +472,10 @@ GenEmbeddingSpMDMNBitLookup< a->lea( index_size, x86::ptr(indices, index_size, areIndices64b ? 
3 : 2)); - asmjit::Label exit = a->newLabel(); - asmjit::Label error = a->newLabel(); - asmjit::Label LoopRangeIndexBegin = a->newLabel(); - asmjit::Label LoopRangeIndexEnd = a->newLabel(); + asmjit::Label exit = a->new_label(); + asmjit::Label error = a->new_label(); + asmjit::Label LoopRangeIndexBegin = a->new_label(); + asmjit::Label LoopRangeIndexEnd = a->new_label(); // rangeIndex loop begins (iterate output_size times) a->bind(LoopRangeIndexBegin); @@ -491,8 +483,8 @@ GenEmbeddingSpMDMNBitLookup< a->jl(LoopRangeIndexEnd); if (normalize_by_lengths) { - asmjit::Label IfLengthsBegin = a->newLabel(); - asmjit::Label IfLengthsEnd = a->newLabel(); + asmjit::Label IfLengthsBegin = a->new_label(); + asmjit::Label IfLengthsEnd = a->new_label(); a->bind(IfLengthsBegin); if (use_offsets) { a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); @@ -508,13 +500,13 @@ GenEmbeddingSpMDMNBitLookup< vec_reg_t temp_vreg0(0); if constexpr (instSet == inst_set_t::avx2) { a->mov(scratchReg1_, 1); - a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_); - a->cvtsi2ss(temp_vreg0.xmm(), lengths_R_); - a->divss(vlen_inv_vreg.xmm(), temp_vreg0.xmm()); + a->vcvtsi2ss(vlen_inv_vreg.xmm(), vlen_inv_vreg.xmm(), scratchReg1_); + a->vcvtsi2ss(temp_vreg0.xmm(), temp_vreg0.xmm(), lengths_R_); + a->vdivss(vlen_inv_vreg.xmm(), vlen_inv_vreg.xmm(), temp_vreg0.xmm()); a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm()); } else { a->mov(scratchReg1_, 1); - a->cvtsi2ss(temp_vreg0.xmm(), scratchReg1_); + a->vcvtsi2ss(temp_vreg0.xmm(), temp_vreg0.xmm(), scratchReg1_); a->vpbroadcastd(vlen_inv_vreg, temp_vreg0.xmm()); a->vpbroadcastd(temp_vreg0, lengths_R_); a->vcvtdq2ps(temp_vreg0, temp_vreg0); @@ -548,9 +540,9 @@ GenEmbeddingSpMDMNBitLookup< a->cmp(scratchReg1_, index_size); a->jg(error); - asmjit::Label LoopDataIndexBegin = a->newLabel(); - asmjit::Label LoopDataIndexEnd = a->newLabel(); - asmjit::Label ValidIndexLabel = a->newLabel(); + asmjit::Label LoopDataIndexBegin = a->new_label(); + asmjit::Label LoopDataIndexEnd = a->new_label(); + asmjit::Label ValidIndexLabel = a->new_label(); // dataIndex loop begins (iterate lengths_R_ times) a->bind(LoopDataIndexBegin); @@ -567,14 +559,14 @@ GenEmbeddingSpMDMNBitLookup< // When scale_bias_last == false, assume this is for table batched // embedding (TBE) that can get -1 for pruned rows. 
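The same -1 pruned-row convention applies on this 2-/4-bit path. Note also that the remainder handling configured above keeps the AVX-512 tail masks in the k1 (output lanes) and k2 (32-bit-granularity input tail) opmask registers instead of spending vector registers the way the AVX2 path must. A sketch of the k1 idiom in intrinsics terms (editorial illustration, assuming an AVX-512F target; not code from this patch):

    #include <immintrin.h>

    // `mov scratch, (1 << remainder) - 1; kmovw k1, scratch` followed by a
    // k1-masked vmovups is equivalent to a masked store of `remainder` floats:
    void store_float_tail(float* dst, __m512 acc, int remainder) {
      const __mmask16 tail = static_cast<__mmask16>((1u << remainder) - 1);
      _mm512_mask_storeu_ps(dst, tail, acc);
    }
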
if (areIndices64b) { - a->cmp(scratchReg1_, static_cast(-1)); + a->cmp(scratchReg1_, -1); } else { - a->cmp(scratchReg1_.r32(), static_cast(-1)); + a->cmp(scratchReg1_.r32(), -1); } a->jne(ValidIndexLabel); - a->add(indices, static_cast(sizeof(indxType))); + a->add(indices, sizeof(indxType)); if (has_weight) { - a->add(weights, static_cast(sizeof(float))); + a->add(weights, sizeof(float)); } a->jmp(LoopDataIndexBegin); a->bind(ValidIndexLabel); @@ -597,8 +589,8 @@ GenEmbeddingSpMDMNBitLookup< int num_elem_per_byte = 8 / bit_rate; int fused_block_size = input_stride; if (pref_dist) { - asmjit::Label pref_dist_reset_start = a->newLabel(); - asmjit::Label pref_dist_reset_end = a->newLabel(); + asmjit::Label pref_dist_reset_start = a->new_label(); + asmjit::Label pref_dist_reset_end = a->new_label(); // out of bound handling for prefetch a->lea( scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType))); @@ -629,8 +621,8 @@ GenEmbeddingSpMDMNBitLookup< a->bind(pref_dist_reset_end); if constexpr (ROWWISE_SPARSE) { asmjit::Label rowwise_sparse_pref_corner_case_begin = - a->newLabel(); - asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel(); + a->new_label(); + asmjit::Label rowwise_sparse_pref_corner_case_end = a->new_label(); a->cmp(scratchReg2_, data_size); a->jae(rowwise_sparse_pref_corner_case_begin); @@ -650,22 +642,22 @@ GenEmbeddingSpMDMNBitLookup< a->bind(rowwise_sparse_pref_corner_case_end); } // This has to be fused_block_size - a->imul(scratchReg2_, static_cast(fused_block_size)); + a->imul(scratchReg2_, fused_block_size); } - a->add(indices, static_cast(sizeof(indxType))); + a->add(indices, sizeof(indxType)); if (has_weight) { a->vbroadcastss(w_vreg, x86::dword_ptr(weights)); - a->add(weights, static_cast(sizeof(float))); + a->add(weights, sizeof(float)); } if constexpr (ROWWISE_SPARSE) { - a->cmp(scratchReg1_.r32(), static_cast(-1)); + a->cmp(scratchReg1_.r32(), -1); a->je(LoopDataIndexBegin); } - a->imul(scratchReg1_, static_cast(fused_block_size)); + a->imul(scratchReg1_, fused_block_size); // broadcast the scale x86::Mem scale_src, bias_src; @@ -715,9 +707,9 @@ GenEmbeddingSpMDMNBitLookup< if (num_vec_regs_per_block - (vec_idx + v) < 4 && remainder_32bit_granularity) { if constexpr (instSet == inst_set_t::avx512) { - a->k(x86::k(2)).vmovups(src_vreg.ymm(), src_addr); + a->k(x86::k2).vmovups(src_vreg.ymm(), src_addr); } else { - a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr); + a->vpmaskmovd(src_vreg.xmm(), mask2_vreg_xmm.xmm(), src_addr); } a->vpmovzxbw(src_vreg, src_vreg.half()); } else { @@ -736,10 +728,10 @@ GenEmbeddingSpMDMNBitLookup< if (num_vec_regs_per_block - (vec_idx + v) < 4 && remainder_32bit_granularity) { if constexpr (instSet == inst_set_t::avx512) { - a->k(x86::k(2)).vmovups(src_vreg.xmm(), src_addr); + a->k(x86::k2).vmovups(src_vreg.xmm(), src_addr); a->vpmovzxbd(src_vreg, src_vreg.xmm()); } else { - a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr); + a->vpmaskmovd(src_vreg.xmm(), mask2_vreg_xmm.xmm(), src_addr); a->vpmovzxbd(src_vreg, src_vreg.xmm()); } } else { @@ -827,9 +819,9 @@ GenEmbeddingSpMDMNBitLookup< if constexpr (std::is_same_v) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if constexpr (instSet == inst_set_t::avx512) { - a->k(x86::k(1)).vmovups(dst_addr, out_vreg); + a->k(x86::k1).vmovups(dst_addr, out_vreg); } else { - a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm()); + a->vmaskmovps(dst_addr, mask_vreg_ymm, out_vreg.ymm()); } } else { a->vmovups(dst_addr, out_vreg); @@ -848,7 +840,7 @@ 
GenEmbeddingSpMDMNBitLookup< } if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (remainder > 1) { - a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); + a->vmaskmovps(dst_addr, mask_fp16_vreg_xmm, out_vreg.xmm()); } if (remainder % 2 != 0) { a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm()); @@ -870,11 +862,11 @@ GenEmbeddingSpMDMNBitLookup< if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (is_bf16_out) { // bf16 - a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); - a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); - a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); + a->k(x86::k1).vpaddd(out_vreg, out_vreg, ones_vreg); + a->k(x86::k1).vpsrld(out_vreg, out_vreg, 16); + a->k(x86::k1).vpmovdw(dst_addr, out_vreg); } else { - a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); + a->k(x86::k1).vcvtps2ph(dst_addr, out_vreg, 8); } } else { if (is_bf16_out) { @@ -902,30 +894,22 @@ GenEmbeddingSpMDMNBitLookup< } if (has_weight) { - a->imul( - scratchReg1_, - lengths_R_, - static_cast(sizeof(float))); + a->imul(scratchReg1_, lengths_R_, sizeof(float)); a->sub(weights, scratchReg1_); if (vec_idx + unroll_factor < num_vec_regs_per_block) { - a->imul( - scratchReg1_, - static_cast(sizeof(indxType) / sizeof(float))); + a->imul(scratchReg1_, sizeof(indxType) / sizeof(float)); a->sub(indices, scratchReg1_); } } else { - a->imul( - scratchReg1_, - lengths_R_, - static_cast(sizeof(indxType))); + a->imul(scratchReg1_, lengths_R_, sizeof(indxType)); a->sub(indices, scratchReg1_); } } } - a->add(lengths, static_cast(sizeof(offsetType))); - a->add(out, static_cast(output_stride * sizeof(outType))); + a->add(lengths, sizeof(offsetType)); + a->add(out, output_stride * sizeof(outType)); a->jmp(LoopRangeIndexBegin); a->bind(LoopRangeIndexEnd); @@ -943,7 +927,7 @@ GenEmbeddingSpMDMNBitLookup< a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t))); } - a->emitEpilog(frame); + a->emit_epilog(frame); // jit_fused8bitembedding_kernel fn; typename ReturnFunctionSignature< @@ -951,20 +935,13 @@ GenEmbeddingSpMDMNBitLookup< offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel fn; - asmjit::Error err = 0; - { - unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + + asmjit::Error err = runtime.add(&fn, &code); + if (err != asmjit::Error::kOk) { cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogFile); - delete codeLogger; -#endif return fn; }); } diff --git a/src/FbgemmI64.cc b/src/FbgemmI64.cc index c9d4096e2c..40b0213f37 100644 --- a/src/FbgemmI64.cc +++ b/src/FbgemmI64.cc @@ -17,6 +17,7 @@ #include #include +#include "./CodeStorage.h" // @manual #include "./GenerateKernel.h" // @manual #include "./RefImplementations.h" // @manual #include "fbgemm/PackingTraits-inl.h" @@ -97,11 +98,11 @@ void CodeGenBase::storeCRegs( VecT(i * colRegs + j), VecT(i * colRegs + j), x86::dword_ptr( - a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t))); + x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int64_t))); } a->vmovups( x86::dword_ptr( - a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int64_t)), + x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int64_t)), VecT(i * colRegs + j)); } } @@ -141,8 +142,9 @@ CodeGenBase::getOrCreate( make_tuple(accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + 
code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); #ifdef FBGEMM_LOG_CODE @@ -152,10 +154,8 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique(codeLogFile); + code.set_logger(codeLogger.get()); #endif const int maxMRegs [[maybe_unused]] = mRegBlockSize; @@ -169,12 +169,12 @@ CodeGenBase::getOrCreate( const int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + x86::Gp buffer_A = x86::rdi; + x86::Gp buffer_B = x86::rsi; + x86::Gp B_pf = x86::rdx; + x86::Gp CBase = x86::rcx; + x86::Gp kSize = x86::r8; + x86::Gp ldcReg = x86::r9; asmjit::FuncDetail func; func.init( @@ -185,38 +185,38 @@ CodeGenBase::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); - frame.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->new_label(); + asmjit::Label LoopNBlocks = a->new_label(); + asmjit::Label Loopk = a->new_label(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); + x86::Gp buffer_B_saved = x86::r10; + x86::Gp C_Offset = x86::r11; + x86::Gp B_pf_saved = x86::r12; + x86::Gp iIdx = x86::r13; + x86::Gp jIdx = x86::r14; + x86::Gp kIdx = x86::r15; - a->imul(ldcReg, ldcReg, static_cast(sizeof(int64_t))); - a->imul(kSize, kSize, static_cast(sizeof(int64_t))); + a->imul(ldcReg, ldcReg, sizeof(int64_t)); + a->imul(kSize, kSize, sizeof(int64_t)); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -245,17 +245,17 @@ CodeGenBase::getOrCreate( a->bind(Loopk); // k is incremented by 1 - a->add(kIdx, static_cast(sizeof(int64_t))); + a->add(kIdx, sizeof(int64_t)); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add(buffer_A, static_cast(sizeof(int64_t))); + a->add(buffer_A, sizeof(int64_t)); // update buffer_B address for next k 
iteration - a->add(buffer_B, static_cast(nBlock * sizeof(int64_t))); - a->add(B_pf, static_cast(nBlock * sizeof(int64_t))); + a->add(buffer_B, nBlock * sizeof(int64_t)); + a->add(B_pf, nBlock * sizeof(int64_t)); a->cmp(kIdx, kSize); a->jl(Loopk); @@ -269,16 +269,13 @@ CodeGenBase::getOrCreate( // B for next block a->mov(buffer_B, buffer_B_saved); // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast(nRegBlockSize * sizeof(int64_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * sizeof(int64_t)); a->add(buffer_B, C_Offset); a->mov(B_pf, B_pf_saved); a->add(B_pf, C_Offset); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int64_t))); + a->add(CBase, nRegBlockSize * sizeof(int64_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -287,16 +284,11 @@ CodeGenBase::getOrCreate( a->jl(LoopNBlocks); // increment A for next block - a->add( - buffer_A, - static_cast(rowRegs * kBlock * sizeof(int64_t))); + a->add(buffer_A, rowRegs * kBlock * sizeof(int64_t)); // increment C for next A block - a->sub( - CBase, - static_cast( - jLoopTrips * nRegBlockSize * sizeof(int64_t))); - a->imul(C_Offset, ldcReg, static_cast(rowRegs)); + a->sub(CBase, jLoopTrips * nRegBlockSize * sizeof(int64_t)); + a->imul(C_Offset, ldcReg, rowRegs); a->add(CBase, C_Offset); // reset B @@ -308,8 +300,8 @@ CodeGenBase::getOrCreate( // generate code for remainder if (mRegBlocksRem > 0) { assert(false); - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); + asmjit::Label LoopNRem = a->new_label(); + asmjit::Label LoopkRem = a->new_label(); int rowRegs = mRegBlocksRem; a->xor_(jIdx.r32(), jIdx.r32()); @@ -324,17 +316,17 @@ CodeGenBase::getOrCreate( a->bind(LoopkRem); // k is incremented by 1 - a->add(kIdx, static_cast(sizeof(int64_t))); + a->add(kIdx, sizeof(int64_t)); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add(buffer_A, static_cast(sizeof(int64_t))); + a->add(buffer_A, sizeof(int64_t)); // update buffer_B address for next k iteration - a->add(buffer_B, static_cast(nBlock * sizeof(int64_t))); - a->add(B_pf, static_cast(nBlock * sizeof(int64_t))); + a->add(buffer_B, nBlock * sizeof(int64_t)); + a->add(B_pf, nBlock * sizeof(int64_t)); a->cmp(kIdx, kSize); a->jl(LoopkRem); @@ -344,10 +336,7 @@ CodeGenBase::getOrCreate( // B for next block // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast(nRegBlockSize * sizeof(int64_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * sizeof(int64_t)); a->mov(buffer_B, buffer_B_saved); a->add(buffer_B, C_Offset); a->mov(B_pf, B_pf_saved); @@ -357,7 +346,7 @@ CodeGenBase::getOrCreate( storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int64_t))); + a->add(CBase, nRegBlockSize * sizeof(int64_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -366,24 +355,15 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + if (err != asmjit::Error::kOk) { cout << "Error: in fn add" << '\n'; return nullptr; } -#ifdef FBGEMM_LOG_CODE - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateI8Depthwise.cc 
b/src/GenerateI8Depthwise.cc index c13885aa7d..0a63b59b69 100644 --- a/src/GenerateI8Depthwise.cc +++ b/src/GenerateI8Depthwise.cc @@ -14,21 +14,12 @@ #include "./CodeCache.h" // @manual #include "./CodeGenHelpers.h" // @manual +#include "./CodeStorage.h" // @manual #include "fbgemm/Utils.h" namespace fbgemm { namespace { -asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, - // depents on other static - // variables. Required to prevent - // initialization order fiasco - return rt; -} - -// Controll access to runtime; -std::mutex rtMutex_; // The hash depends on D, K_T, K_H, K_W, oc_per_g, compute_a_sum, // remainder, prev_skip, next_skip, top_skip, bottom_skip, left_skip, and @@ -52,19 +43,23 @@ namespace x86 = asmjit::x86; // c3_v: c[12:16], c[28:32] static void genMaddEpi16xNPacked( x86::Emitter* e, - Ymm a[4], + x86::Vec a[4], const x86::Gp& b, - Ymm c[4], - Ymm* a_sum, + x86::Vec c[4], + x86::Vec* a_sum, int n, int remainder, bool accumulation, - const Ymm& one_epi8, - const Ymm& one_epi16, - const Ymm& zero) { + const x86::Vec& one_epi8, + const x86::Vec& one_epi16, + const x86::Vec& zero) { // Interleave inputs corresponding to 4 filter positions. // Reuse a[1] and a[3] to save registers - Ymm a01_lo(0), a01_hi(1), a23_lo(a[1]), a23_hi(a[3]); + x86::Vec a01_lo = x86::ymm0; + x86::Vec a01_hi = x86::ymm1; + x86::Vec a23_lo(a[1]); + x86::Vec a23_hi(a[3]); + e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]); if (remainder >= 8) { e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]); @@ -203,8 +198,9 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( right_skip); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_kernel_signature { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* e = assembler.as(); #ifdef FBGEMM_LOG_CODE @@ -212,7 +208,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( for (int i = 3 - D; i < 3; ++i) { filename += std::to_string(K[i]); if (i < 2) { - filename += "x" + filename += "x"; } } filename += "_" + std::to_string(oc_per_g); @@ -242,22 +238,22 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } filename += ".txt"; FILE* codeLogFile = fopen(filename.c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile); - code.setLogger(codeLogger); + auto codeLogger = std::make_unique(codeLogFile); + code.set_logger(codeLogger.get()); #endif - x86::Gp a_addr = e->zdi(); - x86::Gp b_addr = e->zsi(); - x86::Gp c_addr = e->zdx(); - x86::Gp a_sum_addr = e->zcx(); - x86::Gp h = e->gpz(8); - x86::Gp w = e->gpz(9); - x86::Gp ic = e->gpz(10); - x86::Gp mask_addr = e->gpz(11); - x86::Gp a_zero_point = e->gpz(12); - x86::Gp b_zero_point_addr = e->gpz(13); - x86::Gp ic_loop_count = e->gpz(14); - x86::Gp a_addr_save = e->gpz(15); + x86::Gp a_addr = x86::rdi; + x86::Gp b_addr = x86::rsi; + x86::Gp c_addr = x86::rdx; + x86::Gp a_sum_addr = x86::rcx; + x86::Gp h = x86::r8; + x86::Gp w = x86::r9; + x86::Gp ic = x86::r10; + x86::Gp mask_addr = x86::r11; + x86::Gp a_zero_point = x86::r12; + x86::Gp b_zero_point_addr = x86::r13; + x86::Gp ic_loop_count = x86::r14; + x86::Gp a_addr_save = x86::r15; asmjit::FuncDetail func; func.init( @@ -278,16 +274,16 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( 
asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll( + args.assign_all( a_addr, b_addr, c_addr, @@ -299,31 +295,31 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( a_zero_point, b_zero_point_addr); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - e->emitProlog(frame); - e->emitArgsAssignment(frame, args); + e->emit_prolog(frame); + e->emit_args_assignment(frame, args); // Assign vector registers - Ymm a[4]; - Ymm c[4]; - Ymm a_sum[2]; + x86::Vec a[4]; + x86::Vec c[4]; + x86::Vec a_sum[2]; - int vreg_id = 2; // reserve 2 for temp vreg + unsigned vreg_id = 2; // reserve 2 for temp vreg for (int i = 0; i < 4; ++i, ++vreg_id) { - a[i] = Ymm(vreg_id); + a[i] = x86::ymm(vreg_id); } for (int i = 0; i < 4; ++i, ++vreg_id) { - c[i] = Ymm(vreg_id); + c[i] = x86::ymm(vreg_id); } if (compute_a_sum) { - a_sum[0] = Ymm(vreg_id); + a_sum[0] = x86::ymm(vreg_id); ++vreg_id; - a_sum[1] = Ymm(vreg_id); + a_sum[1] = x86::ymm(vreg_id); ++vreg_id; } - Ymm mask_vreg(vreg_id); + x86::Vec mask_vreg = x86::ymm(vreg_id); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; if (remainder != simd_info::WIDTH_BYTES) { ++vreg_id; @@ -333,17 +329,17 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( mask_addr, (vlen - remainder / 4 / oc_per_g) % vlen * sizeof(int32_t))); } - Ymm one_epi8(vreg_id); + x86::Vec one_epi8 = x86::ymm(vreg_id); if (compute_a_sum) { ++vreg_id; - gen8BitVectorOne(e, one_epi8); + gen8BitVectorOne(e, one_epi8); } int K = std::accumulate(F.begin(), F.end(), 1, std::multiplies()); - Ymm one_epi16(vreg_id); + x86::Vec one_epi16 = x86::ymm(vreg_id); if (K > 2) { ++vreg_id; - gen16BitVectorOne(e, one_epi16); + gen16BitVectorOne(e, one_epi16); } bool has_pad = prev_skip || next_skip || top_skip || bottom_skip || @@ -352,15 +348,16 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( // When out of registers, zero and A_zero_point_vreg need to share. bool recompute_zero = vreg_id == 15 && need_zero; - Ymm a_zero_point_vreg(vreg_id); + x86::Vec a_zero_point_vreg = x86::ymm(vreg_id); if (!recompute_zero && has_pad) { - e->movq(a_zero_point_vreg.half(), a_zero_point); + e->vmovq(a_zero_point_vreg.half(), a_zero_point); e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half()); } if (vreg_id < 15) { ++vreg_id; } - Ymm zero(vreg_id); + + x86::Vec zero = x86::ymm(vreg_id); if (need_zero && (!recompute_zero || !has_pad)) { e->vpxor(zero.xmm(), zero.xmm(), zero.xmm()); } @@ -378,11 +375,11 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( e->sub(w, a_addr_save); // w * ic - F[2] * ic e->mov(ic_loop_count, ic); - e->add(ic_loop_count, asmjit::Imm(32 / oc_per_g - 1)); - e->sar(ic_loop_count, asmjit::Imm(oc_per_g == 1 ? 5 : 4)); + e->add(ic_loop_count, 32 / oc_per_g - 1); + e->sar(ic_loop_count, oc_per_g == 1 ? 
5 : 4); e->mov(a_addr_save, a_addr); - asmjit::Label ic_loop_begin = e->newLabel(), ic_loop_end = e->newLabel(); + asmjit::Label ic_loop_begin = e->new_label(), ic_loop_end = e->new_label(); // main_loop == false: the last vector iteration across input channels for (bool main_loop : {true, false}) { @@ -393,7 +390,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } if (recompute_zero && has_pad) { - e->movq(a_zero_point_vreg.half(), a_zero_point); + e->vmovq(a_zero_point_vreg.half(), a_zero_point); e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half()); } @@ -504,7 +501,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( e->vmovups(x86::ymmword_ptr(a_sum_addr), a[0]); } else { // Rollback duplication - e->vpsrld(a_sum[0], a_sum[0], asmjit::Imm(16)); + e->vpsrld(a_sum[0], a_sum[0], 16); e->vmovups(x86::xmmword_ptr(a_sum_addr), a_sum[0].half()); } @@ -514,13 +511,13 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( e->vmovups(x86::ymmword_ptr(a_sum_addr, 32), a[1]); } else { // Rollback duplication - e->vpsrld(a_sum[1], a_sum[1], asmjit::Imm(16)); + e->vpsrld(a_sum[1], a_sum[1], 16); e->vmovups(x86::xmmword_ptr(a_sum_addr, 16), a_sum[1].half()); } } if (main_loop || remainder >= 16) { - e->vextracti128(a_sum[0].half(), a_sum[0], asmjit::Imm(1)); + e->vextracti128(a_sum[0].half(), a_sum[0], 1); if (oc_per_g == 1) { e->vpmovsxwd(a_sum[0], a_sum[0].half()); e->vmovups(x86::ymmword_ptr(a_sum_addr, 64), a_sum[0]); @@ -530,7 +527,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } if (main_loop || remainder >= 24) { - e->vextracti128(a_sum[1].half(), a_sum[1], asmjit::Imm(1)); + e->vextracti128(a_sum[1].half(), a_sum[1], 1); if (oc_per_g == 1) { e->vpmovsxwd(a_sum[1], a_sum[1].half()); e->vmovups(x86::ymmword_ptr(a_sum_addr, 96), a_sum[1]); @@ -540,13 +537,13 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } if (main_loop) { - e->add(a_sum_addr, asmjit::Imm(128 / oc_per_g)); + e->add(a_sum_addr, 128 / oc_per_g); } } if (main_loop) { - e->add(c_addr, asmjit::Imm(128)); - e->add(a_addr_save, asmjit::Imm(32 / oc_per_g)); + e->add(c_addr, 128); + e->add(a_addr_save, 32 / oc_per_g); e->mov(a_addr, a_addr_save); e->jmp(ic_loop_begin); @@ -554,24 +551,16 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate( } } - e->emitEpilog(frame); + e->emit_epilog(frame); jit_kernel_signature fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#ifdef FBGEMM_LOG_CODE - fclose(codeLogFile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateKernel.cc b/src/GenerateKernel.cc index 8d4c6ece0f..30644fd88d 100644 --- a/src/GenerateKernel.cc +++ b/src/GenerateKernel.cc @@ -17,15 +17,12 @@ namespace x86 = asmjit::x86; * Accumulation kernel. 
*/ void initCRegs(x86::Emitter* a, int rowRegs, int colRegs) { - using CRegs = Xmm; // Take advantage of implicit zeroing out // i.e., zero out xmm and ymm will be zeroed out too for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { - a->vpxor( - CRegs(i * colRegs + j), - CRegs(i * colRegs + j), - CRegs(i * colRegs + j)); + x86::Vec reg = x86::xmm(unsigned(i * colRegs + j)); + a->vpxor(reg, reg, reg); } } } diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index 975eccc6b0..4f99f973fe 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -130,16 +130,6 @@ class CodeGenBase { } private: - static asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, - // depents on other static - // variables. Required to prevent - // initialization order fiasco - return rt; - } - - inline static std::mutex rtMutex_; ///< Controll access to runtime; - // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr inline static CodeCache< std::tuple<bool, int, int, int, int, int, int>, diff --git a/src/GenerateKernelDirectConvU8S8S32ACC32.cc b/src/GenerateKernelDirectConvU8S8S32ACC32.cc index 92a28b5027..df512e0733 100644 --- a/src/GenerateKernelDirectConvU8S8S32ACC32.cc +++ b/src/GenerateKernelDirectConvU8S8S32ACC32.cc @@ -7,7 +7,8 @@ */ #include +#include "./CodeStorage.h" // @manual #include "./CodeGenHelpers.h" // @manual #include "./DirectConv.h" // @manual namespace fbgemm { @@ -62,10 +64,10 @@ void DirectConvCodeGenBase::storeCRegs( VecT(i * colRegs + j), VecT(i * colRegs + j), x86::dword_ptr( - a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t))); + x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int8_t))); } a->vmovups( - x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)), + x86::dword_ptr(x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int8_t)), VecT(i * colRegs + j)); } } @@ -100,7 +102,7 @@ void DirectConvCodeGenBase:: for (int j = 0; j < colRegs; ++j) { // load B - emitLoadDWord( + emitVecMove( a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); // load A, broadcast and fmas for (int i = 0; i < rowRegs; ++i) { @@ -190,8 +192,9 @@ DirectConvCodeGenBase::getOrCreateDirectConv( accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -201,10 +204,8 @@ DirectConvCodeGenBase::getOrCreateDirectConv( accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif const int maxMRegs [[maybe_unused]] = mRegBlockSize; @@ -217,12 +218,12 @@ DirectConvCodeGenBase::getOrCreateDirectConv( int O1RegBlocksRem = O1 % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp ichXk1 = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + x86::Gp buffer_A = x86::rdi; + x86::Gp buffer_B = x86::rsi; + x86::Gp B_pf = x86::rdx; + x86::Gp CBase = x86::rcx; + x86::Gp ichXk1 = x86::r8; + x86::Gp ldcReg = x86::r9; asmjit::FuncDetail func; 
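// Note: CodeStorage.h is not shown in this part of the patch, so the following
// is an assumption from its call sites. CodeStorage::getRuntime() appears to
// centralize the asmjit::JitRuntime that CodeGenBase and the conv kernel
// generators previously each kept behind their own static runtime() helper,
// along these lines (a hypothetical minimal sketch, not the actual file):
//
//   // CodeStorage.h (sketch)
//   #include <asmjit/core.h>
//   namespace fbgemm::CodeStorage {
//   inline asmjit::JitRuntime& getRuntime() {
//     static asmjit::JitRuntime rt; // one process-wide runtime; constructed
//     return rt;                    // on first use (thread-safe since C++11)
//   }
//   } // namespace fbgemm::CodeStorage
//
// The callers also drop the old rtMutex_ lock around runtime().add(&fn, &code),
// which presumes that JitRuntime::add() in the updated asmjit is safe to call
// concurrently; that is implied by this patch rather than stated in it.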
func.init( @@ -233,43 +234,43 @@ DirectConvCodeGenBase::getOrCreateDirectConv( asmjit::FuncFrame frame; frame.init(func); - auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15); + auto dirtyVecRegs = asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15); if (numRegs >= 16) { - dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31); + dirtyVecRegs |= asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31); } - frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs); - frame.setDirtyRegs( + frame.set_dirty_regs(asmjit::RegGroup::kVec, dirtyVecRegs); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, ichXk1, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, ichXk1, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); - // asmjit::Label LoopOBlocks = a->newLabel(); - // asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label LoopMBlocks = a->new_label(); + // asmjit::Label LoopOBlocks = a->new_label(); + // asmjit::Label LoopNBlocks = a->new_label(); - const x86::Gp& buffer_B_saved = a->gpz(10); - const x86::Gp& C_Offset = a->gpz(11); - // const x86::Gp& B_pf_saved = a->gpz(12); - const x86::Gp& iIdx = a->gpz(13); - // const x86::Gp& jIdx = a->gpz(14); - const x86::Gp& kIdx = a->gpz(15); - // const x86::Gp& B_pf = a->gpz(8); + const x86::Gp& buffer_B_saved = x86::r10; + const x86::Gp& C_Offset = x86::r11; + // const x86::Gp& B_pf_saved = x86::r12; + const x86::Gp& iIdx = x86::r13; + // const x86::Gp& jIdx = x86::r14; + const x86::Gp& kIdx = x86::r15; + // const x86::Gp& B_pf = x86::r8; VecRegT oneReg(numRegs - 3); - gen16BitVectorOne(a, oneReg); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + gen16BitVectorOne(a, oneReg); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); // a->xor_(C_Offset.r32(), C_Offset.r32()); // a->mov(B_pf_saved, B_pf); @@ -279,8 +280,8 @@ DirectConvCodeGenBase::getOrCreateDirectConv( auto issueLoopOverK = [&](int rowRegs) { // loopKLabel: corresponds to loop "r" where r = 0 // loopK0Label: corresponds to loop "r" where r = 1 - asmjit::Label LoopKLabel = a->newLabel(); - asmjit::Label LoopK0Label = a->newLabel(); + asmjit::Label LoopKLabel = a->new_label(); + asmjit::Label LoopK0Label = a->new_label(); // Init C (result) vector registers initCRegs(a, rowRegs, colRegs); @@ -292,7 +293,7 @@ DirectConvCodeGenBase::getOrCreateDirectConv( a->bind(LoopKLabel); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); // this ComputeBlock generates code correspondent to // the above psedu-code since the kernel_height loop (loop "r"). 
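A recurring mechanical change in this patch is dropping the static_cast wrappers (originally casts to asmjit::Imm) around integer operands of add, sub, cmp, and imul. A minimal sketch of why the two spellings are interchangeable (helper name and register choice here are illustrative, not from this patch):

#include <asmjit/x86.h>

// The emitter's integral-immediate overloads wrap the operand in an
// asmjit::Imm internally, so both lines assemble the identical instruction.
void addImmediateBothWays(asmjit::x86::Assembler& a, int row_interleave) {
  a.add(asmjit::x86::r15, asmjit::Imm(row_interleave)); // old spelling
  a.add(asmjit::x86::r15, row_interleave);              // new spelling
}
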
@@ -303,25 +304,24 @@ DirectConvCodeGenBase::getOrCreateDirectConv( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add(buffer_B, static_cast(8 * sizeof(int32_t))); - a->add(B_pf, static_cast(8 * sizeof(int32_t))); + a->add(buffer_B, 8 * sizeof(int32_t)); + a->add(B_pf, 8 * sizeof(int32_t)); a->cmp(kIdx, ichXk1); a->jl(LoopKLabel); a->sub(buffer_A, ichXk1); - a->add(buffer_A, static_cast(i1Xich)); + a->add(buffer_A, i1Xich); a->xor_(kIdx.r32(), kIdx.r32()); a->bind(LoopK0Label); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); // this ComputeBlock generates code that corresponds // to the kernel_height loop (loop "r") in the psedu-code above. @@ -331,12 +331,11 @@ DirectConvCodeGenBase::getOrCreateDirectConv( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add(buffer_B, static_cast(8 * sizeof(int32_t))); - a->add(B_pf, static_cast(8 * sizeof(int32_t))); + a->add(buffer_B, 8 * sizeof(int32_t)); + a->add(B_pf, 8 * sizeof(int32_t)); a->cmp(kIdx, ichXk1); a->jl(LoopK0Label); @@ -364,12 +363,10 @@ DirectConvCodeGenBase::getOrCreateDirectConv( int rowRegs = mRegBlockSize; // reset A - a->sub(buffer_A, static_cast(i1Xich)); + a->sub(buffer_A, i1Xich); // increment A for next block - a->add( - buffer_A, - static_cast(rowRegs * strideXich * sizeof(uint8_t))); + a->add(buffer_A, rowRegs * strideXich * sizeof(uint8_t)); // B for next block a->mov(buffer_B, buffer_B_saved); @@ -377,7 +374,7 @@ DirectConvCodeGenBase::getOrCreateDirectConv( // increment C for next B block // ldcReg already multiplied with 4 (sizeof(int32_t)) a->imul( - C_Offset, ldcReg, static_cast(rowRegs * sizeof(int8_t))); + C_Offset, ldcReg, rowRegs * sizeof(int8_t)); a->add(CBase, C_Offset); // a->add(CBase, static_cast(12*16*4)); @@ -392,24 +389,16 @@ DirectConvCodeGenBase::getOrCreateDirectConv( issueLoopOverK(O1RegBlocksRem); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } @@ -438,9 +427,9 @@ void DirectConvCodeGenBase::storeCRegsTrans( a->vpaddd( VecT(i * colRegs + j), VecT(i * colRegs + j), - x86::dword_ptr(a->zcx(), C_offset)); + x86::dword_ptr(x86::rcx, C_offset)); } - a->vmovups(x86::dword_ptr(a->zcx(), C_offset), VecT(i * colRegs + j)); + a->vmovups(x86::dword_ptr(x86::rcx, C_offset), VecT(i * colRegs + j)); a->add(C_offset, ldcReg); } a->add(C_offset, o1XocReg); @@ -509,7 +498,7 @@ void DirectConvCodeGenBase:: for (int i = 0; i < rowRegs; ++i) { for (int j = 0; j < colRegs; ++j) { // load B, broadcast and fmas - emitLoadDWord( + emitVecMove( a, BReg, x86::dword_ptr(buffer_B, C_offset, 3, 0)); a->vpmaddubsw(res1, AReg, BReg); a->vpmaddwd(res1, oneReg, res1); @@ -611,8 +600,9 @@ DirectConvCodeGenBase:: 
auto kernelSig = std::make_tuple(accum, stride, mRegBlockSize, nRegBlockSize); return codeCacheT_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp_convT { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -621,10 +611,8 @@ DirectConvCodeGenBase:: getCodeLoggingFile(accum, stride, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif const int maxMRegs [[maybe_unused]] = mRegBlockSize; @@ -634,13 +622,13 @@ DirectConvCodeGenBase:: "MRegs x NRegs is above available registers (MAX_REGS - 4)"); // arguments to the function created - const x86::Gp& buffer_A = a->zdi(); - const x86::Gp& buffer_B = a->zsi(); - const x86::Gp& CBase = a->zcx(); - const x86::Gp& ic = a->gpz(8); - const x86::Gp& ldcReg = a->gpz(9); - const x86::Gp& o1Xoc = a->gpz(10); - const x86::Gp& i1 = a->gpz(11); + const x86::Gp& buffer_A = x86::rdi; + const x86::Gp& buffer_B = x86::rsi; + const x86::Gp& CBase = x86::rcx; + const x86::Gp& ic = x86::r8; + const x86::Gp& ldcReg = x86::r9; + const x86::Gp& o1Xoc = x86::r10; + const x86::Gp& i1 = x86::r11; asmjit::FuncDetail func; func.init( @@ -651,43 +639,43 @@ DirectConvCodeGenBase:: asmjit::FuncFrame frame; frame.init(func); - auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15); + auto dirtyVecRegs = asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15); if (numRegs >= 16) { - dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31); + dirtyVecRegs |= asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31); } - frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs); - frame.setDirtyRegs( + frame.set_dirty_regs(asmjit::RegGroup::kVec, dirtyVecRegs); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, CBase, ic, ldcReg, o1Xoc, i1); + args.assign_all(buffer_A, buffer_B, CBase, ic, ldcReg, o1Xoc, i1); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopMBlocks = a->new_label(); - const x86::Gp& C_offset = a->gpz(12); - const x86::Gp& buffer_B_saved = a->gpz(13); - const x86::Gp& iIdx = a->gpz(14); - const x86::Gp& kIdx = a->gpz(15); + const x86::Gp& C_offset = x86::r12; + const x86::Gp& buffer_B_saved = x86::r13; + const x86::Gp& iIdx = x86::r14; + const x86::Gp& kIdx = x86::r15; VecRegT oneReg(numRegs - 3); - gen16BitVectorOne(a, oneReg); + gen16BitVectorOne(a, oneReg); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); int colRegs = maxNRegs; auto issueLoopOverK = [&](int rowRegs) { - asmjit::Label LoopKLabel = a->newLabel(); + asmjit::Label 
LoopKLabel = a->new_label(); // Init C (result) vector registers initCRegs(a, rowRegs, colRegs); @@ -708,11 +696,10 @@ DirectConvCodeGenBase:: mColRegBlockSize); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add(buffer_B, static_cast(8 * sizeof(int32_t))); + a->add(buffer_B, 8 * sizeof(int32_t)); a->cmp(kIdx, ic); a->jl(LoopKLabel); @@ -743,34 +730,23 @@ DirectConvCodeGenBase:: // B for next block a->mov(buffer_B, buffer_B_saved); // increment C for next B block - a->imul( - C_offset, - ldcReg, - static_cast(stride)); // ldcReg already multiplied by 4 + a->imul(C_offset, ldcReg, stride); // ldcReg already multiplied by 4 a->add(CBase, C_offset); a->cmp(iIdx, i1); a->jl(LoopMBlocks); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp_convT fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index f78d5a9a09..2671706d14 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -7,6 +7,7 @@ */ #include +#include "./CodeStorage.h" // @manual #include "./CodeGenHelpers.h" // @manual #include "./GenerateKernel.h" // @manual @@ -29,7 +30,6 @@ void CodeGenBase::genComputeBlock< int rowRegs, int colRegs, int lda) { - using CRegs = Ymm; static constexpr int vectorLen = simd_info::WIDTH_BYTES; // used for matrix A @@ -40,11 +40,12 @@ void CodeGenBase::genComputeBlock< a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { + x86::Vec v = x86::ymm(i * colRegs + j); a->vpmaddubsw( tmpReg, AReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); - a->vpaddsw(CRegs(i * colRegs + j), tmpReg, CRegs(i * colRegs + j)); + a->vpaddsw(v, tmpReg, v); // Prefetching is hurting performance in some cases // because prefetch instructions itself consumes a slot // in pipeline issue thus slowing down the kernel. 
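The vpmaddubsw/vpaddsw pair in genComputeBlock above is the u8 x s8 multiply with 16-bit saturating accumulation that gives the ACC16 kernel its name. A scalar model of one 16-bit lane (an illustration, not code from this patch):

#include <algorithm>
#include <cstdint>

// vpmaddubsw multiplies unsigned bytes of A by signed bytes of B, then sums
// adjacent pairs with saturation to int16; vpaddsw accumulates the result
// with 16-bit signed saturation. The JIT does 16 such lanes per ymm register.
int16_t maddubsLane(uint8_t a0, uint8_t a1, int8_t b0, int8_t b1, int16_t acc) {
  int32_t pair = int32_t(a0) * b0 + int32_t(a1) * b1; // adjacent-pair dot
  pair = std::clamp(pair, -32768, 32767);             // saturate to int16
  int32_t sum = int32_t(acc) + pair;                  // vpaddsw
  return int16_t(std::clamp(sum, -32768, 32767));     // saturating add
}

This saturation is also why storeCRegs in the next hunk widens the 16-bit sums with vpmovsxwd before adding them into the 32-bit C buffer.
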
@@ -75,14 +76,14 @@ void CodeGenBase::storeCRegs( auto extractDestHalf = extractDestFull.half(); for (int i = 0; i < rowRegs; ++i) { - a->imul(C_Offset, ldcReg, static_cast(i * sizeof(int32_t))); + a->imul(C_Offset, ldcReg, i * sizeof(int32_t)); for (int j = 0; j < colRegs; ++j) { for (int idx = 0; idx < 2; ++idx) { - emitExtractHalfVector( + emitExtractHalfVector( a, extractDestHalf, VecT(i * colRegs + j), idx); a->vpmovsxwd(extractDestFull, extractDestHalf); x86::Mem destAddr = - x86::dword_ptr(a->zcx(), C_Offset, 0, (j * 2 + idx) * vectorLen); + x86::dword_ptr(x86::rcx, C_Offset, 0, (j * 2 + idx) * vectorLen); if (accum) { a->vpaddd(extractDestFull, extractDestFull, destAddr); } @@ -135,8 +136,9 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); @@ -147,10 +149,8 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif assert( @@ -171,12 +171,12 @@ CodeGenBase::getOrCreate( //"nc must be equal to the number of register blocks"); // arguments to the function created - const x86::Gp& buffer_A = a->zdi(); - const x86::Gp& buffer_B = a->zsi(); - const x86::Gp& B_pf = a->zdx(); - const x86::Gp& CBase = a->zcx(); - const x86::Gp& kSize = a->gpz(8); - const x86::Gp& ldcReg = a->gpz(9); + const x86::Gp& buffer_A = x86::rdi; + const x86::Gp& buffer_B = x86::rsi; + const x86::Gp& B_pf = x86::rdx; + const x86::Gp& CBase = x86::rcx; + const x86::Gp& kSize = x86::r8; + const x86::Gp& ldcReg = x86::r9; asmjit::FuncDetail func; func.init( @@ -186,31 +186,31 @@ CodeGenBase::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label Loopk = a->new_label(); + asmjit::Label LoopMBlocks = a->new_label(); - const x86::Gp& buffer_B_saved = a->gpz(10); - const x86::Gp& C_Offset = a->gpz(11); - // const x86::Gp& B_pf_saved = a->gpz(12); - const x86::Gp& iIdx = a->gpz(13); - const x86::Gp& kIdx = a->gpz(14); + const x86::Gp& buffer_B_saved = x86::r10; + const x86::Gp& C_Offset = x86::r11; + // const x86::Gp& B_pf_saved = x86::r12; + const x86::Gp& iIdx = x86::r13; + const 
x86::Gp& kIdx = x86::r14; int colRegs = nc * row_interleave / vectorLen; if (mRegBlocks > 0) { @@ -233,19 +233,16 @@ CodeGenBase::getOrCreate( a->xor_(kIdx.r32(), kIdx.r32()); a->bind(Loopk); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); // a->add(B_pf, static_cast(nBlock * row_interleave * // sizeof(int8_t))); @@ -258,14 +255,9 @@ CodeGenBase::getOrCreate( // increment A for next block a->sub(buffer_A, kSize); - a->add( - buffer_A, - static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + a->add(buffer_A, rowRegs * kBlock * sizeof(uint8_t)); // increment C for next block - a->imul( - C_Offset, - ldcReg, - static_cast(rowRegs * sizeof(int32_t))); + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); a->add(CBase, C_Offset); // reset B a->mov(buffer_B, buffer_B_saved); @@ -276,7 +268,7 @@ CodeGenBase::getOrCreate( } // generate code for remainder if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); + asmjit::Label LoopkRem = a->new_label(); int rowRegs = mRegBlocksRem; // init C registers @@ -287,19 +279,16 @@ CodeGenBase::getOrCreate( a->bind(LoopkRem); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); // a->add(B_pf, static_cast(nBlock * row_interleave * // sizeof(int8_t))); @@ -311,24 +300,16 @@ CodeGenBase::getOrCreate( a, rowRegs, colRegs, C_Offset, ldcReg, accum); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index cdc3c03aa0..f52435859b 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -8,6 +8,7 @@ #include #include +#include "./CodeStorage.h" // @manual #include "./GenerateKernel.h" // @manual namespace fbgemm { @@ -101,8 +102,9 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); @@ -113,10 +115,8 @@ 
CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif assert( @@ -133,12 +133,12 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + x86::Gp buffer_A = x86::rdi; + x86::Gp buffer_B = x86::rsi; + x86::Gp B_pf = x86::rdx; + x86::Gp CBase = x86::rcx; + x86::Gp kSize = x86::r8; + x86::Gp ldcReg = x86::r9; asmjit::FuncDetail func; func.init( @@ -149,35 +149,35 @@ CodeGenBase::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); - frame.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->new_label(); + asmjit::Label LoopNBlocks = a->new_label(); + asmjit::Label Loopk = a->new_label(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); + x86::Gp buffer_B_saved = x86::r10; + x86::Gp C_Offset = x86::r11; + // x86::Gp B_pf_saved = x86::r12; + x86::Gp iIdx = x86::r13; + x86::Gp jIdx = x86::r14; + x86::Gp kIdx = x86::r15; // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -205,19 +205,16 @@ CodeGenBase::getOrCreate( a->xor_(kIdx.r32(), kIdx.r32()); a->bind(Loopk); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); // a->add(B_pf, static_cast(nBlock * row_interleave * // sizeof(int8_t))); @@ -233,15 +230,11 @@ CodeGenBase::getOrCreate( // B for next block a->mov(buffer_B, 
buffer_B_saved); // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * row_interleave * sizeof(int8_t)); a->add(buffer_B, C_Offset); // increment C for next block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -250,19 +243,11 @@ CodeGenBase::getOrCreate( a->jl(LoopNBlocks); // increment A for next block - a->add( - buffer_A, - static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); // increment C for next A block - a->sub( - CBase, - static_cast( - jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul( - C_Offset, - ldcReg, - static_cast(rowRegs * sizeof(int32_t))); + a->sub(CBase, jLoopTrips * nRegBlockSize * sizeof(int32_t)); + a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t)); a->add(CBase, C_Offset); // reset B @@ -274,8 +259,8 @@ CodeGenBase::getOrCreate( } // generate code for remainder if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); + asmjit::Label LoopNRem = a->new_label(); + asmjit::Label LoopkRem = a->new_label(); int rowRegs = mRegBlocksRem; a->xor_(jIdx.r32(), jIdx.r32()); @@ -290,19 +275,16 @@ CodeGenBase::getOrCreate( a->bind(LoopkRem); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); // a->add(B_pf, static_cast(nBlock * row_interleave * // sizeof(int8_t))); @@ -315,18 +297,14 @@ CodeGenBase::getOrCreate( // B for next block a->mov(buffer_B, buffer_B_saved); // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * row_interleave * sizeof(int8_t)); a->add(buffer_B, C_Offset); // store C matrix storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment C for next block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -335,24 +313,16 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 6c1960b4b8..72a9175ac0 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -8,6 +8,7 @@ #include #include "./CodeGenHelpers.h" // @manual +#include "./CodeStorage.h" // @manual #include "./GenerateKernel.h" // 
@manual namespace fbgemm { @@ -46,7 +47,7 @@ void CodeGenBase::genComputeBlock( for (int j = 0; j < colRegs; ++j) { // load B - emitLoadDWord( + emitVecMove( a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); // load A, broadcast and fmas for (int i = 0; i < rowRegs; ++i) { @@ -88,10 +89,10 @@ void CodeGenBase::storeCRegs( VecT(i * colRegs + j), VecT(i * colRegs + j), x86::dword_ptr( - a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t))); + x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int8_t))); } a->vmovups( - x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)), + x86::dword_ptr(x86::rcx, C_Offset, 0, j * vectorLen * sizeof(int8_t)), VecT(i * colRegs + j)); } } @@ -140,8 +141,9 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) @@ -151,10 +153,8 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif assert( @@ -170,12 +170,12 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - const x86::Gp& buffer_A = a->zdi(); - const x86::Gp& buffer_B = a->zsi(); - const x86::Gp& B_pf = a->zdx(); - const x86::Gp& CBase = a->zcx(); - const x86::Gp& kSize = a->gpz(8); - const x86::Gp& ldcReg = a->gpz(9); + const x86::Gp buffer_A = x86::rdi; + const x86::Gp buffer_B = x86::rsi; + const x86::Gp B_pf = x86::rdx; + const x86::Gp CBase = x86::rcx; + const x86::Gp kSize = x86::r8; + const x86::Gp ldcReg = x86::r9; asmjit::FuncDetail func; func.init( @@ -186,42 +186,42 @@ CodeGenBase::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15); + auto dirtyVecRegs = asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15); if (numRegs >= 16) { - dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31); + dirtyVecRegs |= asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31); } - frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs); - frame.setDirtyRegs( + frame.set_dirty_regs(asmjit::RegGroup::kVec, dirtyVecRegs); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label LoopMBlocks = 
a->new_label(); + asmjit::Label LoopNBlocks = a->new_label(); - const x86::Gp& buffer_B_saved = a->gpz(10); - const x86::Gp& C_Offset = a->gpz(11); - const x86::Gp& B_pf_saved = a->gpz(12); - const x86::Gp& iIdx = a->gpz(13); - const x86::Gp& jIdx = a->gpz(14); - const x86::Gp& kIdx = a->gpz(15); - // const x86::Gp& B_pf = a->gpz(8); + const x86::Gp& buffer_B_saved = x86::r10; + const x86::Gp& C_Offset = x86::r11; + const x86::Gp& B_pf_saved = x86::r12; + const x86::Gp& iIdx = x86::r13; + const x86::Gp& jIdx = x86::r14; + const x86::Gp& kIdx = x86::r15; + // const x86::Gp& B_pf = x86::r8; VecRegT oneReg(numRegs - 3); - gen16BitVectorOne(a, oneReg); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + gen16BitVectorOne(a, oneReg); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); // a->xor_(C_Offset.r32(), C_Offset.r32()); // save B_buffer address @@ -232,7 +232,7 @@ CodeGenBase::getOrCreate( int colRegs = std::min(currColRegs, maxNRegs); auto issueLoopOverK = [&](int rowRegs) { - asmjit::Label LoopKLabel = a->newLabel(); + asmjit::Label LoopKLabel = a->new_label(); // Init C (result) vector registers initCRegs(a, rowRegs, colRegs); @@ -242,22 +242,17 @@ CodeGenBase::getOrCreate( a->bind(LoopKLabel); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); + a->add(B_pf, nBlock * row_interleave * sizeof(int8_t)); a->cmp(kIdx, kSize); a->jl(LoopKLabel); @@ -288,17 +283,13 @@ CodeGenBase::getOrCreate( // B for next block a->mov(buffer_B, buffer_B_saved); // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * row_interleave * sizeof(int8_t)); a->add(buffer_B, C_Offset); a->mov(B_pf, B_pf_saved); a->add(B_pf, C_Offset); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -307,16 +298,11 @@ CodeGenBase::getOrCreate( a->jl(LoopNBlocks); // increment A for next block - a->add( - buffer_A, - static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); // increment C for next A block - a->sub( - CBase, - static_cast( - jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast(rowRegs)); + a->sub(CBase, jLoopTrips * nRegBlockSize * sizeof(int32_t)); + a->imul(C_Offset, ldcReg, rowRegs); a->add(CBase, C_Offset); // reset B @@ -327,7 +313,7 @@ CodeGenBase::getOrCreate( } // generate code for remainder if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopNRem = a->new_label(); a->xor_(jIdx.r32(), jIdx.r32()); a->bind(LoopNRem); @@ -336,7 +322,7 @@ CodeGenBase::getOrCreate( issueLoopOverK(mRegBlocksRem); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int 
jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -345,24 +331,16 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc index 16b3b5ec77..4b8f696e41 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -7,6 +7,7 @@ */ #include +#include "./CodeStorage.h" // @manual #include "./GenerateKernel.h" // @manual namespace fbgemm { @@ -92,8 +93,9 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize); return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); @@ -104,10 +106,8 @@ CodeGenBase::getOrCreate( accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize) .c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif assert( @@ -124,12 +124,12 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + x86::Gp buffer_A = x86::rdi; + x86::Gp buffer_B = x86::rsi; + x86::Gp B_pf = x86::rdx; + x86::Gp CBase = x86::rcx; + x86::Gp kSize = x86::r8; + x86::Gp ldcReg = x86::r9; asmjit::FuncDetail func; func.init( @@ -140,36 +140,36 @@ CodeGenBase::getOrCreate( asmjit::FuncFrame frame; frame.init(func); - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); - frame.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + args.assign_all(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); + 
asmjit::Label LoopMBlocks = a->new_label(); + asmjit::Label LoopNBlocks = a->new_label(); + asmjit::Label Loopk = a->new_label(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); + x86::Gp buffer_B_saved = x86::r10; + x86::Gp C_Offset = x86::r11; + x86::Gp B_pf_saved = x86::r12; + x86::Gp iIdx = x86::r13; + x86::Gp jIdx = x86::r14; + x86::Gp kIdx = x86::r15; + // x86::Gp B_pf = x86::r8; auto oneReg = x86::zmm29; // create 16-bit 1s @@ -178,7 +178,7 @@ CodeGenBase::getOrCreate( // a->vpcmpeqw(oneReg, oneReg, oneReg); a->vpternlogd(oneReg, oneReg, oneReg, 0xff); a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + a->imul(ldcReg, ldcReg, sizeof(int32_t)); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -207,22 +207,17 @@ CodeGenBase::getOrCreate( a->bind(Loopk); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); + a->add(B_pf, nBlock * row_interleave * sizeof(int8_t)); // a->add(B_pf, static_cast(32*sizeof(float))); @@ -238,17 +233,13 @@ CodeGenBase::getOrCreate( // B for next block a->mov(buffer_B, buffer_B_saved); // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * row_interleave * sizeof(int8_t)); a->add(buffer_B, C_Offset); a->mov(B_pf, B_pf_saved); a->add(B_pf, C_Offset); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -257,16 +248,11 @@ CodeGenBase::getOrCreate( a->jl(LoopNBlocks); // increment A for next block - a->add( - buffer_A, - static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t)); // increment C for next A block - a->sub( - CBase, - static_cast( - jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast(rowRegs)); + a->sub(CBase,jLoopTrips * nRegBlockSize * sizeof(int32_t)); + a->imul(C_Offset, ldcReg, rowRegs); a->add(CBase, C_Offset); // reset B @@ -277,8 +263,8 @@ CodeGenBase::getOrCreate( } // generate code for remainder if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); + asmjit::Label LoopNRem = a->new_label(); + asmjit::Label LoopkRem = a->new_label(); int rowRegs = mRegBlocksRem; a->xor_(jIdx.r32(), jIdx.r32()); @@ -293,22 +279,17 @@ CodeGenBase::getOrCreate( a->bind(LoopkRem); // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); + a->add(kIdx, row_interleave); genComputeBlock( a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave 
* sizeof(uint8_t))); + a->add(buffer_A, row_interleave * sizeof(uint8_t)); // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast(nBlock * row_interleave * sizeof(int8_t))); + a->add(buffer_B, nBlock * row_interleave * sizeof(int8_t)); + a->add(B_pf, nBlock * row_interleave * sizeof(int8_t)); a->cmp(kIdx, kSize); a->jl(LoopkRem); @@ -317,11 +298,7 @@ CodeGenBase::getOrCreate( a->sub(buffer_A, kSize); // B for next block // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); + a->imul(C_Offset, jIdx, nRegBlockSize * row_interleave * sizeof(int8_t)); a->mov(buffer_B, buffer_B_saved); a->add(buffer_B, C_Offset); a->mov(B_pf, B_pf_saved); @@ -331,7 +308,7 @@ CodeGenBase::getOrCreate( storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); + a->add(CBase, nRegBlockSize * sizeof(int32_t)); int jLoopTrips = currColRegs / maxNRegs; // jLoopTrips should be at least 1 @@ -340,24 +317,16 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + a->emit_epilog(frame); jit_micro_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; }); } diff --git a/src/GroupwiseConv.cc b/src/GroupwiseConv.cc index 38ec4910b0..7766c98c9f 100644 --- a/src/GroupwiseConv.cc +++ b/src/GroupwiseConv.cc @@ -15,6 +15,7 @@ #include #include #include "./CodeGenHelpers.h" // @manual +#include "./CodeStorage.h" // @manual #include "fbgemm/Fbgemm.h" #include "fbgemm/QuantUtilsAvx512.h" #include "fbgemm/SimdUtils.h" @@ -175,8 +176,9 @@ static jit_conv_kernel_fp getOrCreateConvKernel( template jit_conv_kernel_fp GenConvKernel::getOrCreate() { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(this->runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as<x86::Emitter>(); @@ -196,25 +198,23 @@ jit_conv_kernel_fp GenConvKernel::getOrCreate() { this->K_per_G_); // log code to a file FILE* codeLogfile = fopen(this->getCodeLoggingFile(kernelSig).c_str(), "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code.setLogger(codeLogger); - } + auto codeLogger = std::make_unique<FileLoggerWithClose>(codeLogfile); + code.set_logger(codeLogger.get()); #endif // arguments to the function created - in_acts_R_ = a->zdi(); - wghts_R_ = a->zsi(); - out_acts_R_ = a->zdx(); - a_zero_pt_R_ = a->zcx(); - H_start_R_ = a->gpz(8); - H_end_R_ = a->gpz(9); - W_R_ = a->gpz(10); - row_offset_R_ = a->gpz(11); + in_acts_R_ = x86::rdi; + wghts_R_ = x86::rsi; + out_acts_R_ = x86::rdx; + a_zero_pt_R_ = x86::rcx; + H_start_R_ = x86::r8; + H_end_R_ = x86::r9; + W_R_ = x86::r10; + row_offset_R_ = x86::r11; // register for temporary use - scratchReg1_ = a->gpz(12); - scratchReg2_ = a->gpz(13); + scratchReg1_ = x86::r12; + scratchReg2_ = x86::r13; func_.init( asmjit::FuncSignature::build< @@ -231,16 +231,16 @@ jit_conv_kernel_fp GenConvKernel::getOrCreate() { frame_.init(func_); - frame_.setDirtyRegs( + frame_.set_dirty_regs( 
asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame_.setDirtyRegs( + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); + frame_.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func_); - args.assignAll( + args.assign_all( in_acts_R_, wghts_R_, out_acts_R_, @@ -250,20 +250,20 @@ jit_conv_kernel_fp GenConvKernel::getOrCreate() { W_R_, row_offset_R_); - args.updateFuncFrame(frame_); + args.update_func_frame(frame_); frame_.finalize(); - a->emitProlog(frame_); - a->emitArgsAssignment(frame_, args); + a->emit_prolog(frame_); + a->emit_args_assignment(frame_, args); // We have run out of register so can't keep // this in a register. It's generated again at // each use. Only used for the case of C_per_G == 2 or 4 // gen8BitVectorOne(a, oneReg8Bit_V_); - gen16BitVectorOne(a, oneReg16Bit_V_); + gen16BitVectorOne(a, oneReg16Bit_V_); - loopR1_ = a->gpz(14); - loopR2_ = a->gpz(15); + loopR1_ = x86::r14; + loopR2_ = x86::r15; if (!this->isAZeroPointZero_) { broadcast8Bit(a, a_zero_pt_R_, zeroPTReg_V_); @@ -278,11 +278,11 @@ jit_conv_kernel_fp GenConvKernel::getOrCreate() { // The following logic calculates the input image width in the same register. // Only works for stride == 2 if (this->STRIDE_ > 1) { - a->imul(W_R_, W_R_, static_cast(this->STRIDE_)); + a->imul(W_R_, W_R_, this->STRIDE_); if (!this->use_right_padding_) { a->inc(W_R_); } - a->sub(W_R_, static_cast(this->STRIDE_ - 1)); + a->sub(W_R_, this->STRIDE_ - 1); } if (this->isTopEdgeIncluded_) { @@ -297,25 +297,16 @@ jit_conv_kernel_fp GenConvKernel::getOrCreate() { a, false /* isTopEdge */, this->use_bottom_padding_ /* isBottomEdge */); } - a->emitEpilog(frame_); + a->emit_epilog(frame_); jit_conv_kernel_fp fn = nullptr; - asmjit::Error err = 0; - { - unique_lock lock(this->rtMutex_); - err = this->runtime().add(&fn, &code); - } + asmjit::Error err = runtime.add(&fn, &code); - if (err) { + if (err != asmjit::Error::kOk) { cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - return fn; } @@ -368,10 +359,7 @@ void GenConvKernel::genForSingleOutput( } if (in_image_H) { // advance input pointer by one row - a->imul( - scratchReg2_, - W_R_, - static_cast(this->C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, W_R_, this->C_ * sizeof(uint8_t)); a->add(in_acts_R_, scratchReg2_); ++num_rows_advanced; } @@ -382,30 +370,25 @@ void GenConvKernel::genForSingleOutput( // row offset if (this->needRowOffset_) { storeOffset(a); - a->add( - row_offset_R_, static_cast(GTogether_ * sizeof(int32_t))); + a->add(row_offset_R_, GTogether_ * sizeof(int32_t)); } // rewind input ptr - a->imul( - scratchReg2_, - W_R_, - static_cast(num_rows_advanced * this->C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, W_R_, num_rows_advanced * this->C_ * sizeof(uint8_t)); a->sub(in_acts_R_, scratchReg2_); // advance output pointer - a->add(out_acts_R_, static_cast(this->K_ * sizeof(int32_t))); + a->add(out_acts_R_, this->K_ * sizeof(int32_t)); // advance input ptr if (!isLeft) { a->add( in_acts_R_, - static_cast(this->STRIDE_ * this->C_ * sizeof(uint8_t))); + this->STRIDE_ * this->C_ * sizeof(uint8_t)); } else if (this->STRIDE_ - this->W_PAD_) { a->add( in_acts_R_, - static_cast( - (this->STRIDE_ - this->W_PAD_) 
* this->C_ * sizeof(uint8_t))); + (this->STRIDE_ - this->W_PAD_) * this->C_ * sizeof(uint8_t)); } } @@ -419,12 +402,12 @@ void GenConvKernel::genForTopOrBottomEdge( a->movsxd( loopR1_, x86::dword_ptr( - x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset())); - asmjit::Label LoopWStart = a->newLabel(); - asmjit::Label LoopWEnd = a->newLabel(); - asmjit::Label skipRightEdge = a->newLabel(); - asmjit::Label skipRightEdgeTemp = a->newLabel(); - a->cmp(loopR1_, static_cast(this->W_PAD_)); + x86::rsp, frame_.sa_offset_from_sp() + func_.arg(6).stack_offset())); + asmjit::Label LoopWStart = a->new_label(); + asmjit::Label LoopWEnd = a->new_label(); + asmjit::Label skipRightEdge = a->new_label(); + asmjit::Label skipRightEdgeTemp = a->new_label(); + a->cmp(loopR1_, this->W_PAD_); a->jle(skipRightEdgeTemp); // left corner code @@ -451,7 +434,7 @@ void GenConvKernel::genForTopOrBottomEdge( // edge excluding corners a->bind(LoopWStart); - a->cmp(loopR1_, static_cast(2 * this->W_PAD_)); + a->cmp(loopR1_, 2 * this->W_PAD_); a->jle(LoopWEnd); genForSingleOutput( @@ -483,15 +466,13 @@ void GenConvKernel::genForTopOrBottomEdge( // STRIDE_ == 2 and odd widths, nothing to do // input ptr is already at the right position if (!this->use_right_padding_) { - a->add(in_acts_R_, static_cast(this->C_ * sizeof(uint8_t))); + a->add(in_acts_R_, this->C_ * sizeof(uint8_t)); } } else { // reset input activation pointer by (W_R_ - W_PAD_) * C_ a->mov(scratchReg2_, W_R_); - a->imul(scratchReg2_, static_cast(this->C_ * sizeof(uint8_t))); - a->sub( - scratchReg2_, - static_cast(this->W_PAD_ * this->C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, this->C_ * sizeof(uint8_t)); + a->sub(scratchReg2_, this->W_PAD_ * this->C_ * sizeof(uint8_t)); a->sub(in_acts_R_, scratchReg2_); } } @@ -507,10 +488,10 @@ void GenConvKernel::genCoreInsts(x86::Emitter* a) { a->dec(H_end_R_); } // main compute - asmjit::Label LoopHStart = a->newLabel(); - asmjit::Label LoopHEnd = a->newLabel(); - asmjit::Label LoopWStart = a->newLabel(); - asmjit::Label LoopWEnd = a->newLabel(); + asmjit::Label LoopHStart = a->new_label(); + asmjit::Label LoopHEnd = a->new_label(); + asmjit::Label LoopWStart = a->new_label(); + asmjit::Label LoopWEnd = a->new_label(); // H loop a->mov(loopR1_, H_start_R_); @@ -521,10 +502,10 @@ void GenConvKernel::genCoreInsts(x86::Emitter* a) { a->movsxd( loopR2_, x86::dword_ptr( - x86::rsp, frame_.saOffsetFromSP() + func_.arg(6).stackOffset())); - asmjit::Label skipRightEdge = a->newLabel(); - asmjit::Label skipRightEdgeTemp = a->newLabel(); - a->cmp(loopR2_, static_cast(this->W_PAD_)); + x86::rsp, frame_.sa_offset_from_sp() + func_.arg(6).stack_offset())); + asmjit::Label skipRightEdge = a->new_label(); + asmjit::Label skipRightEdgeTemp = a->new_label(); + a->cmp(loopR2_, this->W_PAD_); a->jle(skipRightEdgeTemp); genForSingleOutput( @@ -549,7 +530,7 @@ void GenConvKernel::genCoreInsts(x86::Emitter* a) { // W loop a->bind(LoopWStart); - a->cmp(loopR2_, static_cast(2 * this->W_PAD_)); + a->cmp(loopR2_, 2 * this->W_PAD_); a->jle(LoopWEnd); genForSingleOutput( @@ -581,12 +562,12 @@ void GenConvKernel::genCoreInsts(x86::Emitter* a) { assert(this->STRIDE_ == 2 && "Not supported case"); a->mov(scratchReg2_, W_R_); if (!this->use_right_padding_) { - a->add(scratchReg2_, static_cast(1)); + a->add(scratchReg2_, 1); } - a->imul(scratchReg2_, static_cast(this->C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, this->C_ * sizeof(uint8_t)); a->add(in_acts_R_, scratchReg2_); } else { - a->add(in_acts_R_, static_cast(this->C_ * 
sizeof(uint8_t))); + a->add(in_acts_R_, this->C_ * sizeof(uint8_t)); } a->bind(LoopHEnd); @@ -596,14 +577,17 @@ void GenConvKernel::genCoreInsts(x86::Emitter* a) { template void GenConvKernel::initResultRegs(x86::Emitter* a) { + x86::Vec reg = x86::xmm9; + if (kLoopIters_ > 0) { // Take advantage of implicit zeroing out // i.e., zero out xmm and ymm and zmm will be zeroed out too for (int k = 0; k < kLoopIters_; ++k) { - a->vpxor(Xmm(9 - k), Xmm(9 - k), Xmm(9 - k)); + reg.set_id(9u - unsigned(k)); + a->vpxor(reg, reg, reg); } } else { - a->vpxor(Xmm(9), Xmm(9), Xmm(9)); + a->vpxor(reg, reg, reg); } } diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 27d3a96185..90d9f71fa6 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -163,16 +163,6 @@ class GenConvKernelBase { return oss.str(); } - static asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, - // depents on other static - // variables. Required to prevent - // initialization order fiasco - return rt; - } - - inline static std::mutex rtMutex_; ///< Control access to runtime; - inline static CodeCache< kernel_sig_t, jit_conv_kernel_fp> diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index a54233c959..c36487acf4 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -19,34 +19,33 @@ namespace x86 = asmjit::x86; GCONV_INST_DEF_AVX2_HEADER GenConvKernel::genConstForPermutations(x86::Emitter* a) { if (this->C_per_G_ == 4) { - x86::Gp permute_const_reg = a->gpz(12); + x86::Gp permute_const_reg = x86::r12; auto const_reg_xmm = x86::xmm11; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg - a->mov(permute_const_reg, static_cast(0x0705030106040200)); - a->movq(const_reg_xmm, permute_const_reg); + a->mov(permute_const_reg, 0x0705030106040200); + a->vmovq(const_reg_xmm, permute_const_reg); // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm // to 8 packed 32-bit integers in stPermReg_V_ a->vpmovzxbd(stPermReg_V_, const_reg_xmm); } else { // this->C_per_G_ == 2 - x86::Gp permute_const_reg = a->gpz(12); + x86::Gp permute_const_reg = x86::r12; auto const_reg_xmm = x86::xmm11; // We have 1st group in position 0 and 4, 2nd group 1 and 5 and so on. // Permute to put 1st group to lower 64-bit and 2nd group to next // 64-bit and so on. // load 7, 3, 6, 2, 5, 1, 4, 0 in a 64-bit reg - a->mov(permute_const_reg, static_cast(0x0703060205010400)); - a->movq(const_reg_xmm, permute_const_reg); + a->mov(permute_const_reg, 0x0703060205010400); + a->vmovq(const_reg_xmm, permute_const_reg); a->vpmovzxbd(stPermReg_V_, const_reg_xmm); } } GCONV_INST_DEF_AVX2_HEADER GenConvKernel::genForLoadingWeights(x86::Emitter* a) { - using WRegs = Ymm; int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; // load weights for (int r = 0; r < this->R_; ++r) { @@ -55,7 +54,7 @@ GenConvKernel::genForLoadingWeights(x86::Emitter* a) { // and are loaded as they are used. 
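// The Ymm(n) -> x86::ymm(n) rewrites in this file follow the updated asmjit
// register API used throughout this patch: per-width register classes give
// way to the generic x86::Vec, obtained from predefined constants (x86::ymm0,
// x86::xmm11, ...) or from the lower-case factories when the index is computed
// during codegen. A minimal sketch reusing the identifiers above (illustrative
// only, not part of the patch):
//
//   x86::Vec w = x86::ymm(unsigned(r * S + s));      // 256-bit reg by index
//   a->vmovaps(w, x86::dword_ptr(wghts_R_, offset)); // same load as above
//
// x86::Vec also carries the width helpers seen elsewhere in the patch
// (.half(), .xmm(), .set_id()), which is what lets a single array such as
// x86::Vec a[4] replace the old Xmm/Ymm-typed locals.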
if (this->C_per_G_ == 4 || this->C_per_G_ == 2) { a->vmovaps( - WRegs(r * this->S_ + s), + x86::ymm(r * this->S_ + s), x86::dword_ptr( wghts_R_, (r * this->S_ + s) * this->K_per_G_ * GTogether_ * @@ -69,45 +68,45 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel::storeResult( x86::Emitter* a) { if (GTogether_ > 1) { // store with permutation - a->vpermd(Ymm(9), stPermReg_V_, Ymm(9)); + a->vpermd(x86::ymm9, stPermReg_V_, x86::ymm9); if (this->accum_) { - a->vpaddd(Ymm(9), Ymm(9), x86::dword_ptr(out_acts_R_)); + a->vpaddd(x86::ymm9, x86::ymm9, x86::dword_ptr(out_acts_R_)); } - a->vmovups(x86::dword_ptr(out_acts_R_), Ymm(9)); + a->vmovups(x86::dword_ptr(out_acts_R_), x86::ymm9); } else { // horizontal add and store if (this->C_per_G_ == 8) { - a->vphaddd(Ymm(9), Ymm(9), Ymm(8)); - a->vpermq(Ymm(9), Ymm(9), static_cast(0xd8)); + a->vphaddd(x86::ymm9, x86::ymm9, x86::ymm8); + a->vpermq(x86::ymm9, x86::ymm9, 0xd8); if (this->accum_) { - a->vpaddd(Ymm(9), Ymm(9), x86::dword_ptr(out_acts_R_)); + a->vpaddd(x86::ymm9, x86::ymm9, x86::dword_ptr(out_acts_R_)); } - a->vmovups(x86::dword_ptr(out_acts_R_), Ymm(9)); + a->vmovups(x86::dword_ptr(out_acts_R_), x86::ymm9); } else if (this->K_per_G_ == 16) { - a->vphaddd(Ymm(9), Ymm(9), Ymm(8)); - a->vpermq(Ymm(9), Ymm(9), static_cast(0xd8)); + a->vphaddd(x86::ymm9, x86::ymm9, x86::ymm8); + a->vpermq(x86::ymm9, x86::ymm9, 0xd8); - a->vphaddd(Ymm(7), Ymm(7), Ymm(6)); - a->vpermq(Ymm(7), Ymm(7), static_cast(0xd8)); + a->vphaddd(x86::ymm7, x86::ymm7, x86::ymm6); + a->vpermq(x86::ymm7, x86::ymm7, 0xd8); - a->vphaddd(Ymm(5), Ymm(5), Ymm(4)); - a->vpermq(Ymm(5), Ymm(5), static_cast(0xd8)); + a->vphaddd(x86::ymm5, x86::ymm5, x86::ymm4); + a->vpermq(x86::ymm5, x86::ymm5, 0xd8); - a->vphaddd(Ymm(3), Ymm(3), Ymm(2)); - a->vpermq(Ymm(3), Ymm(3), static_cast(0xd8)); + a->vphaddd(x86::ymm3, x86::ymm3, x86::ymm2); + a->vpermq(x86::ymm3, x86::ymm3, 0xd8); - a->vphaddd(Ymm(9), Ymm(9), Ymm(7)); - a->vpermq(Ymm(9), Ymm(9), static_cast(0xd8)); + a->vphaddd(x86::ymm9, x86::ymm9, x86::ymm7); + a->vpermq(x86::ymm9, x86::ymm9, 0xd8); - a->vphaddd(Ymm(5), Ymm(5), Ymm(3)); - a->vpermq(Ymm(5), Ymm(5), static_cast(0xd8)); + a->vphaddd(x86::ymm5, x86::ymm5, x86::ymm3); + a->vpermq(x86::ymm5, x86::ymm5, 0xd8); if (this->accum_) { - a->vpaddd(Ymm(9), Ymm(9), x86::dword_ptr(out_acts_R_)); - a->vpaddd(Ymm(5), Ymm(5), x86::dword_ptr(out_acts_R_, 32)); + a->vpaddd(x86::ymm9, x86::ymm9, x86::dword_ptr(out_acts_R_)); + a->vpaddd(x86::ymm5, x86::ymm5, x86::dword_ptr(out_acts_R_, 32)); } - a->vmovups(x86::dword_ptr(out_acts_R_), Ymm(9)); - a->vmovups(x86::dword_ptr(out_acts_R_, 32), Ymm(5)); + a->vmovups(x86::dword_ptr(out_acts_R_), x86::ymm9); + a->vmovups(x86::dword_ptr(out_acts_R_, 32), x86::ymm5); } } } @@ -161,8 +160,6 @@ GenConvKernel::genForSingleFilterPoint( int s, int act_s, bool use_zero_reg) { - using WRegs = Ymm; - if (GTogether_ > 1) { if (this->C_per_G_ == 2) { // group together = 4 if (use_zero_reg) { @@ -196,8 +193,8 @@ GenConvKernel::genForSingleFilterPoint( genU8I8S32FMA( a, actReg_V_, - WRegs(r * this->S_ + s), - Ymm(9), + x86::ymm(r * this->S_ + s), + x86::ymm9, oneReg16Bit_V_, tmpReg1_V_); } else { @@ -226,7 +223,7 @@ GenConvKernel::genForSingleFilterPoint( int kLoopMultiplier = 32 / this->C_per_G_; for (int k = 0; k < kLoopIters_; ++k) { a->vmovaps( - WRegs(0), + x86::ymm0, x86::dword_ptr( wghts_R_, (((r * this->S_ + s) * this->K_per_G_) + k * kLoopMultiplier) * @@ -235,7 +232,7 @@ GenConvKernel::genForSingleFilterPoint( // which consectutive 2 elements if summedforms one final 
output over // K_Per_G dimension genU8I8S32FMA( - a, actReg_V_, WRegs(0), Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_); + a, actReg_V_, x86::ymm0, x86::ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_); } } } diff --git a/src/GroupwiseConvAcc32Avx512.cc b/src/GroupwiseConvAcc32Avx512.cc index c35524f809..e2be194932 100644 --- a/src/GroupwiseConvAcc32Avx512.cc +++ b/src/GroupwiseConvAcc32Avx512.cc @@ -18,8 +18,8 @@ namespace x86 = asmjit::x86; GCONV_INST_DEF_AVX512_AND_VNNI_HEADER GenConvKernel::genConstForPermutations(x86::Emitter* a) { - x86::Gp permute_const_reg_upper_half = a->gpz(12); - x86::Gp permute_const_reg_lower_half = a->gpz(13); + x86::Gp permute_const_reg_upper_half = x86::r12; + x86::Gp permute_const_reg_lower_half = x86::r13; auto const_reg_xmm = x86::xmm11; if (this->C_per_G_ == 4) { // 4 group together @@ -30,12 +30,8 @@ GenConvKernel::genConstForPermutations(x86::Emitter* a) { // e, a, 6, 2, // d, 9, 5, 1, // c, 8, 4, 0 in a 128-bit Xmm - a->mov( - permute_const_reg_lower_half, - static_cast(0x0d0905010c080400)); - a->mov( - permute_const_reg_upper_half, - static_cast(0x0f0b07030e0a0602)); + a->mov(permute_const_reg_lower_half, 0x0d0905010c080400); + a->mov(permute_const_reg_upper_half, 0x0f0b07030e0a0602); } else { // this->C_per_G_ == 2 // 8 group together @@ -51,16 +47,12 @@ GenConvKernel::genConstForPermutations(x86::Emitter* a) { // a, 2 // 9, 1 // 8, 0 in a 128-bit Xmm - a->mov( - permute_const_reg_lower_half, - static_cast(0x0b030a0209010800)); - a->mov( - permute_const_reg_upper_half, - static_cast(0x0f070e060d050c04)); + a->mov(permute_const_reg_lower_half, 0x0b030a0209010800); + a->mov(permute_const_reg_upper_half, 0x0f070e060d050c04); } - a->movq(const_reg_xmm, permute_const_reg_lower_half); - a->pinsrq(const_reg_xmm, permute_const_reg_upper_half, 1); + a->vmovq(const_reg_xmm, permute_const_reg_lower_half); + a->vpinsrq(const_reg_xmm, const_reg_xmm, permute_const_reg_upper_half, 1); // Zero extend 16 packed 8-bit integers in the low 8 bytes of const_reg_xmm // to 16 packed 32-bit integers in stPermReg_V_ a->vpmovzxbd(stPermReg_V_, const_reg_xmm); @@ -68,7 +60,6 @@ GenConvKernel::genConstForPermutations(x86::Emitter* a) { GCONV_INST_DEF_AVX512_AND_VNNI_HEADER GenConvKernel::genForLoadingWeights(x86::Emitter* a) { - using WRegs = Zmm; int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; // load weights for (int r = 0; r < this->R_; ++r) { @@ -78,7 +69,7 @@ GenConvKernel::genForLoadingWeights(x86::Emitter* a) { if (this->C_per_G_ != 16) { // still use aligned move since the weigh buffer is 64bytes aligned. 
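The genConstForPermutations rewrite above also swaps the legacy-SSE movq/pinsrq pair for the VEX-encoded vmovq/vpinsrq, avoiding SSE/AVX transition penalties inside otherwise VEX-only code. In intrinsics terms, the 128-bit shuffle constant is assembled from two 64-bit GPR halves and then widened byte-by-byte into dword lane indices; a sketch for the AVX-512 path (helper name is illustrative):

#include <cstdint>
#include <immintrin.h>

// Build the 16-byte permutation constant from two 64-bit halves
// (vmovq + vpinsrq), then zero-extend each byte to a 32-bit lane
// index (vpmovzxbd) for use with vpermd.
__m512i make_perm_indices(uint64_t lower, uint64_t upper) {
  __m128i c = _mm_cvtsi64_si128(static_cast<long long>(lower)); // vmovq
  c = _mm_insert_epi64(c, static_cast<long long>(upper), 1);    // vpinsrq
  return _mm512_cvtepu8_epi32(c);                               // vpmovzxbd
}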
a->vmovaps( - WRegs(r * this->S_ + s), + x86::zmm(r * this->S_ + s), // load 512 bits for weights, different grouping for different // workload x86::zmmword_ptr( @@ -94,38 +85,38 @@ GCONV_INST_DEF_AVX512_AND_VNNI_HEADER GenConvKernel::storeResult(x86::Emitter* a) { if (GTogether_ > 1) { // store with permutation - a->vpermd(Zmm(9), stPermReg_V_, Zmm(9)); + a->vpermd(x86::zmm9, stPermReg_V_, x86::zmm9); if (this->accum_) { - a->vpaddd(Zmm(9), Zmm(9), x86::zmmword_ptr(out_acts_R_)); + a->vpaddd(x86::zmm9, x86::zmm9, x86::zmmword_ptr(out_acts_R_)); } - a->vmovups(x86::zmmword_ptr(out_acts_R_), Zmm(9)); + a->vmovups(x86::zmmword_ptr(out_acts_R_), x86::zmm9); } else { // horizontal add and store if (this->C_per_G_ == 8) { - a->vextracti32x8(tmpReg1_V_.ymm(), Zmm(9), 1); - a->vphaddd(Ymm(9), Ymm(9), tmpReg1_V_.ymm()); - a->vpermq(Ymm(9), Ymm(9), static_cast(0xd8)); + a->vextracti32x8(tmpReg1_V_.ymm(), x86::zmm9, 1); + a->vphaddd(x86::ymm9, x86::ymm9, tmpReg1_V_.ymm()); + a->vpermq(x86::ymm9, x86::ymm9, 0xd8); if (this->accum_) { - a->vpaddd(Ymm(9), Ymm(9), x86::ymmword_ptr(out_acts_R_)); + a->vpaddd(x86::ymm9, x86::ymm9, x86::ymmword_ptr(out_acts_R_)); } - a->vmovups(x86::ymmword_ptr(out_acts_R_), Ymm(9)); + a->vmovups(x86::ymmword_ptr(out_acts_R_), x86::ymm9); } else if (this->K_per_G_ == 16) { // we have results in 4 Zmm registers, need to reduce them to 2 Ymm // register 2 * 8 * 32 where 16 is K_per_g // first reduce 4 * 16 * 32bits to 4 * 8 * 32bits for (int k = 0; k < kLoopIters_; ++k) { - auto source_reg = Zmm(9 - k); - auto result_reg = Ymm(9 - k); - a->vextracti32x8(Ymm(0), source_reg, 1); - a->vphaddd(result_reg, result_reg, Ymm(0)); - a->vpermq(result_reg, result_reg, static_cast(0xd8)); + auto source_reg = x86::zmm(9 - k); + auto result_reg = x86::ymm(9 - k); + a->vextracti32x8(x86::ymm0, source_reg, 1); + a->vphaddd(result_reg, result_reg, x86::ymm0); + a->vpermq(result_reg, result_reg, 0xd8); } // secondly reduce 4 * 8 * 32 to 2 * 8 * 32 bits; for (int k = 0, i = 0; k < kLoopIters_; k += 2, i++) { - auto result_reg = Ymm(9 - k); - auto adjacent_result_reg = Ymm(9 - k - 1); + auto result_reg = x86::ymm(9 - k); + auto adjacent_result_reg = x86::ymm(9 - k - 1); a->vphaddd(result_reg, result_reg, adjacent_result_reg); - a->vpermq(result_reg, result_reg, static_cast(0xd8)); + a->vpermq(result_reg, result_reg, 0xd8); if (this->accum_) { a->vpaddd( result_reg, result_reg, x86::ymmword_ptr(out_acts_R_, 32 * i)); @@ -192,8 +183,6 @@ GenConvKernel::genForSingleFilterPoint( int s, int act_s, bool use_zero_reg) { - using WRegs = Zmm; - if (use_zero_reg) { a->vmovapd(actReg_V_, zeroPTReg_V_); // 64 * 8 bit zero points } else { @@ -234,8 +223,8 @@ GenConvKernel::genForSingleFilterPoint( genU8I8S32FMA( a, actReg_V_, - WRegs(r * this->S_ + s), - WRegs(9), + x86::zmm(r * this->S_ + s), + x86::zmm9, oneReg16Bit_V_, tmpReg1_V_); } else { @@ -243,7 +232,7 @@ GenConvKernel::genForSingleFilterPoint( int kLoopMultiplier = 64 / this->C_per_G_; for (int k = 0; k < kLoopIters_; ++k) { a->vmovaps( - WRegs(0), + x86::zmm0, // copy 512 bits of weights into ZMM, 16(C_Per_g) * 4(1/4 of K_Per_g) x86::zmmword_ptr( wghts_R_, @@ -253,7 +242,7 @@ GenConvKernel::genForSingleFilterPoint( // in which consectutive 4 elements if summed forms one final output over // K_Per_G dimension, we need 16 final 32bits outputs. 
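The storeResult reductions in both ISA paths lean on one idiom worth spelling out: vphaddd adds adjacent dword pairs but interleaves its two source registers per 128-bit lane, and the following vpermq with immediate 0xd8 (qword order 0, 2, 1, 3) restores register order. An intrinsics view of a single fold step from the AVX-512 path, reducing a ZMM of 16 partial sums to a YMM of 8 (a sketch; vextracti32x8 implies the AVX512DQ support these kernels already assume):

#include <immintrin.h>

__m256i fold_zmm_to_ymm(__m512i acc) {
  __m256i lo = _mm512_castsi512_si256(acc);        // low 8 dwords, no instruction
  __m256i hi = _mm512_extracti32x8_epi32(acc, 1);  // vextracti32x8
  __m256i sums = _mm256_hadd_epi32(lo, hi);        // vphaddd: in-lane pairwise adds
  return _mm256_permute4x64_epi64(sums, 0xd8);     // 0b11011000 undoes the interleave
}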
genU8I8S32FMA( - a, actReg_V_, WRegs(0), WRegs(9 - k), oneReg16Bit_V_, tmpReg1_V_); + a, actReg_V_, x86::zmm0, x86::zmm(9 - k), oneReg16Bit_V_, tmpReg1_V_); } } } diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc index 52930eecc3..4b55069d6a 100644 --- a/src/RowWiseSparseAdagradFused.cc +++ b/src/RowWiseSparseAdagradFused.cc @@ -13,6 +13,8 @@ #include #include #include "./CodeCache.h" // @manual +#include "./CodeGenHelpers.h" // @manual +#include "./CodeStorage.h" // @manual #include "./MaskAvx2.h" // @manual #include "./RefImplementations.h" // @manual #include "fbgemm/SimdUtils.h" @@ -61,13 +63,6 @@ class GenRowWiseSparseAdagradFused { int grad_stride); private: - static asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; // JIT Runtime for asmjit - return rt; - } - - inline static mutex rtMutex_; /// Control access to runtime; - // The hash depends on: // avx2 mask array, embedding dimension (block size), prefetch distance, // use_offsets and use_stochastic_rouding switch @@ -107,8 +102,9 @@ typename ReturnFunctionSignature:: indxType, offsetType, dataType>::jit_sparse_adagrad_kernel { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); constexpr bool areIndices64b = is_same_v; @@ -125,24 +121,25 @@ typename ReturnFunctionSignature:: } filename += ".txt"; FILE* codeLogFile = fopen(filename.c_str(), "w"); - auto codeLogger = std::make_unique(codeLogFile); - code.setLogger(codeLogger.get()); + + auto codeLogger = std::make_unique(codeLogFile); + code.set_logger(codeLogger.get()); #endif - x86::Gp rand_buffer = a->zax(); - x86::Gp output_size = a->zdi(); - x86::Gp index_size = a->zsi(); - x86::Gp data_size = a->zdx(); - x86::Gp w = a->zcx(); - x86::Gp g = a->gpz(8); - x86::Gp h = a->gpz(9); - x86::Gp indices = a->gpz(10); - x86::Gp lengths = a->gpz(11); - Xmm epsilon(0); - Xmm lr(1); - auto lengths_R = a->gpz(12).r32(); - x86::Gp scratchReg1 = a->gpz(13); - x86::Gp scratchReg2 = a->gpz(14); // for prefetching + x86::Gp rand_buffer = x86::rax; + x86::Gp output_size = x86::rdi; + x86::Gp index_size = x86::rsi; + x86::Gp data_size = x86::rdx; + x86::Gp w = x86::rcx; + x86::Gp g = x86::r8; + x86::Gp h = x86::r9; + x86::Gp indices = x86::r10; + x86::Gp lengths = x86::r11; + x86::Vec epsilon = x86::xmm0; + x86::Vec lr = x86::xmm1; + auto lengths_R = x86::r12d; + x86::Gp scratchReg1 = x86::r13; + x86::Gp scratchReg2 = x86::r14; // for prefetching asmjit::FuncDetail func; func.init( @@ -165,25 +162,25 @@ typename ReturnFunctionSignature:: frame.init(func); if constexpr (instSet == inst_set_t::avx2) { - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); } else { - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + 
asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); } - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); - args.assignAll( + args.assign_all( output_size, index_size, data_size, @@ -196,10 +193,10 @@ typename ReturnFunctionSignature:: lr, rand_buffer); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; constexpr int NUM_VEC_REG = simd_info::NUM_VEC_REGS; @@ -210,13 +207,13 @@ typename ReturnFunctionSignature:: int remainder = block_size % vlen; vec_reg_t src_vreg; // for holding embedding value temporarily - Ymm mask_vreg; + x86::Vec mask_vreg; // Reserve registers with small ids first because some of them need to // be used with an instruction not supported in avx512 for which a big // register id won't work. int first_available_vec_reg_id = 0; - Ymm partial_sum_vreg = Ymm(first_available_vec_reg_id); + x86::Vec partial_sum_vreg = x86::ymm(first_available_vec_reg_id); ++first_available_vec_reg_id; vec_reg_t float_step_vreg = vec_reg_t(first_available_vec_reg_id); ++first_available_vec_reg_id; @@ -281,7 +278,7 @@ typename ReturnFunctionSignature:: src_vreg = vec_reg_t(first_available_vec_reg_id); ++first_available_vec_reg_id; - mask_vreg = Ymm(first_available_vec_reg_id); + mask_vreg = x86::ymm(first_available_vec_reg_id); ++first_available_vec_reg_id; // Use scratchReg1 as temp a->mov(scratchReg1, asmjit::imm(mask_avx2)); @@ -291,7 +288,7 @@ typename ReturnFunctionSignature:: scratchReg1, (vlen - remainder) % vlen * sizeof(int32_t))); } else { a->mov(scratchReg1, (1 << remainder) - 1); - a->kmovw(x86::k(1), scratchReg1); + a->kmovw(x86::k1, scratchReg1); } } // Need an extra mask for computing sum of gradients @@ -299,7 +296,7 @@ typename ReturnFunctionSignature:: block_size % simd_info::WIDTH_32BIT_ELEMS; x86::KReg reduce_mask_avx512; if (remainder_avx2 && instSet == inst_set_t::avx512) { - reduce_mask_avx512 = x86::k(2); + reduce_mask_avx512 = x86::k2; a->mov(scratchReg1, (1 << remainder_avx2) - 1); a->kmovw(reduce_mask_avx512, scratchReg1); } @@ -307,17 +304,14 @@ typename ReturnFunctionSignature:: int unroll_factor = NUM_VEC_REG - first_available_vec_reg_id; // Compute the end address of indices - a->imul( - scratchReg1, - index_size, - static_cast(sizeof(indxType))); + a->imul(scratchReg1, index_size, sizeof(indxType)); a->add(scratchReg1, indices); a->mov(index_size, scratchReg1); - asmjit::Label exit = a->newLabel(); - asmjit::Label error = a->newLabel(); - asmjit::Label LoopRangeIndexBegin = a->newLabel(); - asmjit::Label LoopRangeIndexEnd = a->newLabel(); + asmjit::Label exit = a->new_label(); + asmjit::Label error = a->new_label(); + asmjit::Label LoopRangeIndexBegin = a->new_label(); + asmjit::Label LoopRangeIndexEnd = a->new_label(); // rangeIndex loop begin (iterate output_size times) a->bind(LoopRangeIndexBegin); @@ -342,7 +336,7 @@ typename ReturnFunctionSignature:: int cur_unroll_factor = std::min(unroll_factor, num_vec_regs_per_block_avx2 - vec_idx); for (int v = 0; v < cur_unroll_factor; ++v) { - Ymm out_vreg = Ymm(v + first_available_vec_reg_id); + x86::Vec out_vreg = x86::ymm(v + first_available_vec_reg_id); auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float)); @@ 
-360,25 +354,9 @@ typename ReturnFunctionSignature:: a->vaddps(partial_sum_vreg, partial_sum_vreg, out_vreg); } } - // Reduce sum to 1 value - // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum); - // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2); - // Use YMM/XMMs with smaller ids for AVX2 specific instructions like - // vhaddps - Xmm partial_sum_xmm(partial_sum_vreg.id()); - Xmm float_step_xmm(float_step_vreg.id()); - // a->vmovups(partial_sum_temp0_ymm, partial_sum_vreg); - a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg); - a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg); - - //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) - a->movss(float_step_xmm, partial_sum_xmm); - //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1)) - a->vextractf128(partial_sum_xmm, partial_sum_vreg, 1); - - // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) + - // _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1)); - a->addss(partial_sum_xmm, float_step_xmm); + + // Reduce sum to a single value. + emitReduceAddF32(a, partial_sum_vreg, float_step_vreg); // This fragment moves block size (N) to stack and bcasts it to xmm reg a->lea( @@ -386,13 +364,13 @@ typename ReturnFunctionSignature:: x86::dword_ptr(x86::rsp, -1 * static_cast(sizeof(int32_t)))); a->mov(x86::dword_ptr(x86::rsp), block_size); a->vbroadcastss( - float_step_xmm, + float_step_vreg.xmm(), x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1 - a->vcvtdq2ps(float_step_xmm, float_step_xmm); + a->vcvtdq2ps(float_step_vreg.xmm(), float_step_vreg.xmm()); a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t))); // final_sum /= N - a->divss(partial_sum_xmm, float_step_xmm); + a->vdivss(partial_sum_vreg.xmm(), partial_sum_vreg.xmm(), float_step_vreg.xmm()); if (use_offsets) { a->mov(lengths_R, x86::dword_ptr(lengths, sizeof(offsetType))); @@ -402,15 +380,14 @@ typename ReturnFunctionSignature:: } // Array out of bound check - a->imul( - scratchReg1, lengths_R, static_cast(sizeof(indxType))); + a->imul(scratchReg1, lengths_R, sizeof(indxType)); a->add(scratchReg1, indices); a->cmp(scratchReg1, index_size); a->jg(error); - asmjit::Label LoopDataIndexBegin = a->newLabel(); - asmjit::Label LoopDataIndexEnd = a->newLabel(); + asmjit::Label LoopDataIndexBegin = a->new_label(); + asmjit::Label LoopDataIndexEnd = a->new_label(); // dataIndex loop begins (iterate lengths_R_ times) a->bind(LoopDataIndexBegin); @@ -430,13 +407,11 @@ typename ReturnFunctionSignature:: a->jae(error); if (prefetch) { - asmjit::Label pref_dist_reset_start = a->newLabel(); - asmjit::Label pref_dist_reset_end = a->newLabel(); + asmjit::Label pref_dist_reset_start = a->new_label(); + asmjit::Label pref_dist_reset_end = a->new_label(); // out of bound handling for prefetch a->mov(scratchReg2, indices); - a->add( - scratchReg2, - static_cast(prefetch * sizeof(indxType))); + a->add(scratchReg2, prefetch * sizeof(indxType)); a->cmp(scratchReg2, index_size); a->jge(pref_dist_reset_start); @@ -464,28 +439,28 @@ typename ReturnFunctionSignature:: a->bind(pref_dist_reset_end); } - a->add(indices, static_cast(sizeof(indxType))); + a->add(indices, sizeof(indxType)); if (prefetch) { a->prefetchw(x86::dword_ptr(h, scratchReg2, 2)); } // load h - a->movss(float_step_xmm, x86::dword_ptr(h, scratchReg1, 2)); + a->vmovss(float_step_vreg.xmm(), x86::dword_ptr(h, scratchReg1, 2)); // *h + final_sum - a->addss(float_step_xmm, partial_sum_xmm); + a->vaddss(float_step_vreg.xmm(), float_step_vreg.xmm(), 
partial_sum_vreg.xmm()); // store h - a->movss(x86::dword_ptr(h, scratchReg1, 2), float_step_xmm); + a->vmovss(x86::dword_ptr(h, scratchReg1, 2), float_step_vreg.xmm()); // sqrt(hi) - a->sqrtss(float_step_xmm, float_step_xmm); + a->vsqrtss(float_step_vreg.xmm(), float_step_vreg.xmm(), float_step_vreg.xmm()); // bcast partial to all of ymm/zmm reg - a->vpbroadcastd(float_step_vreg, float_step_xmm); + a->vpbroadcastd(float_step_vreg, float_step_vreg.xmm()); // lr / sqrt(hi) + epsilon a->vaddps(float_step_vreg, float_step_vreg, epsilon_vreg); a->vdivps(float_step_vreg, lr_vreg, float_step_vreg); - a->imul(scratchReg1, static_cast(block_size)); + a->imul(scratchReg1, block_size); if (prefetch) { - a->imul(scratchReg2, static_cast(block_size)); + a->imul(scratchReg2, block_size); } for (int vec_idx = 0; vec_idx < num_vec_regs_per_block; @@ -512,9 +487,9 @@ typename ReturnFunctionSignature:: a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm()); } else { - a->k(x86::k(1)).vmulps(out_vreg, float_step_vreg, g_ptr); - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr); - a->k(x86::k(1)).vmovups(w_ptr, out_vreg); + a->k(x86::k1).vmulps(out_vreg, float_step_vreg, g_ptr); + a->k(x86::k1).vaddps(out_vreg, out_vreg, w_ptr); + a->k(x86::k1).vmovups(w_ptr, out_vreg); } } else { a->vmulps(out_vreg, float_step_vreg, g_ptr); @@ -552,37 +527,20 @@ typename ReturnFunctionSignature:: a->vpaddd(r0_vreg, S0_vreg, S3_vreg); a->vpslld(r1_vreg, r0_vreg, 7); a->vpsrld(r0_vreg, r0_vreg, 25); - if constexpr (instSet == inst_set_t::avx2) { - a->vpor(R_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm()); - } else { - a->vpord(R_vreg, r0_vreg, r1_vreg); - } + emitVecOr(a, R_vreg, r0_vreg, r1_vreg); a->vpaddd(R_vreg, R_vreg, S0_vreg); a->vpslld(r0_vreg, S1_vreg, 9); - if constexpr (instSet == inst_set_t::avx2) { - a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), S0_vreg.ymm()); - a->vpxor(S3_vreg.ymm(), S3_vreg.ymm(), S1_vreg.ymm()); - a->vpxor(S1_vreg.ymm(), S1_vreg.ymm(), S2_vreg.ymm()); - a->vpxor(S0_vreg.ymm(), S0_vreg.ymm(), S3_vreg.ymm()); - - a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), r0_vreg.ymm()); - } else { - a->vpxord(S2_vreg, S2_vreg, S0_vreg); - a->vpxord(S3_vreg, S3_vreg, S1_vreg); - a->vpxord(S1_vreg, S1_vreg, S2_vreg); - a->vpxord(S0_vreg, S0_vreg, S3_vreg); + emitVecXor(a, S2_vreg, S2_vreg, S0_vreg); + emitVecXor(a, S3_vreg, S3_vreg, S1_vreg); + emitVecXor(a, S1_vreg, S1_vreg, S2_vreg); + emitVecXor(a, S0_vreg, S0_vreg, S3_vreg); + emitVecXor(a, S2_vreg, S2_vreg, r0_vreg); - a->vpxord(S2_vreg, S2_vreg, r0_vreg); - } a->vpslld(r0_vreg, S3_vreg, 11); a->vpsrld(r1_vreg, S3_vreg, 21); - if constexpr (instSet == inst_set_t::avx2) { - a->vpor(S3_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm()); - } else { - a->vpord(S3_vreg, r0_vreg, r1_vreg); - } + emitVecOr(a, S3_vreg, r0_vreg, r1_vreg); // Extract byte 0 and shift to bits[5..13] a->vpslld(r0_vreg, R_vreg, 24); @@ -650,13 +608,13 @@ typename ReturnFunctionSignature:: a->mov(h, x86::ptr(x86::rsp)); a->lea(x86::rsp, x86::ptr(x86::rsp, 8)); } else { - a->k(x86::k(1)).vcvtph2ps(out_vreg, w_ptr); - a->k(x86::k(1)).vfmadd231ps(out_vreg, float_step_vreg, g_ptr); + a->k(x86::k1).vcvtph2ps(out_vreg, w_ptr); + a->k(x86::k1).vfmadd231ps(out_vreg, float_step_vreg, g_ptr); if (use_stochastic_rounding) { a->vpaddd(out_vreg, r0_vreg, out_vreg); } // Truncate rounding - a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 11); + a->k(x86::k1).vcvtps2ph(w_ptr, out_vreg, 11); } } else { a->vcvtph2ps(out_vreg, w_ptr); @@ -687,8 +645,8 @@ typename ReturnFunctionSignature:: a->jmp(LoopDataIndexBegin); 
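The stochastic-rounding hunks above vectorize a xoshiro128++-style generator: each rotate-left is emitted as a vpslld/vpsrld pair combined by the new emitVecOr helper, and every SIMD lane carries an independent four-word state (hence the four S0..S3 registers). For reference, the scalar recurrence the emitted code computes per lane, as a sketch rather than code from this change:

#include <cstdint>

static inline uint32_t rotl32(uint32_t x, int k) {
  return (x << k) | (x >> (32 - k)); // vpslld + vpsrld + vpor/vpord
}

// One step of the per-lane generator: returns the output word and
// advances the four-word state in place.
uint32_t next_rand(uint32_t s[4]) {
  uint32_t result = rotl32(s[0] + s[3], 7) + s[0];
  uint32_t t = s[1] << 9;
  s[2] ^= s[0];
  s[3] ^= s[1];
  s[1] ^= s[2];
  s[0] ^= s[3];
  s[2] ^= t;
  s[3] = rotl32(s[3], 11);
  return result;
}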
a->bind(LoopDataIndexEnd); - a->add(lengths, static_cast(sizeof(offsetType))); - a->add(g, static_cast(grad_stride * sizeof(float))); + a->add(lengths, sizeof(offsetType)); + a->add(g, grad_stride * sizeof(float)); a->jmp(LoopRangeIndexBegin); a->bind(LoopRangeIndexEnd); @@ -702,50 +660,25 @@ a->bind(exit); if (areWeightsFp16 && use_stochastic_rounding) { - if constexpr (instSet == inst_set_t::avx2) { - a->vmovdqa(x86::dword_ptr(rand_buffer), S0_vreg.ymm()); - a->vmovdqa( - x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)), - S1_vreg.ymm()); - a->vmovdqa( - x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)), - S2_vreg.ymm()); - a->vmovdqa( - x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)), - S3_vreg.ymm()); - } else { - a->vmovdqa32(x86::dword_ptr(rand_buffer), S0_vreg); - a->vmovdqa32( - x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)), - S1_vreg); - a->vmovdqa32( - x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)), - S2_vreg); - a->vmovdqa32( - x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)), - S3_vreg); - } + emitVecMove(a, x86::ptr(rand_buffer), S0_vreg); + emitVecMove(a, x86::ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)), S1_vreg); + emitVecMove(a, x86::ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)), S2_vreg); + emitVecMove(a, x86::ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)), S3_vreg); } a->mov(x86::eax, scratchReg1.r32()); - a->emitEpilog(frame); + a->emit_epilog(frame); // jit_fused8bitembedding_kernel fn; typename ReturnFunctionSignature:: jit_sparse_adagrad_kernel fn; - asmjit::Error err = 0; - { - unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogFile); -#endif return fn; }); } // getOrCreate diff --git a/src/SparseAdagrad.cc b/src/SparseAdagrad.cc index 30c6e6cf8d..06d8886376 100644 --- a/src/SparseAdagrad.cc +++ b/src/SparseAdagrad.cc @@ -15,6 +15,8 @@ #include #include #include "./CodeCache.h" // @manual +#include "./CodeGenHelpers.h" // @manual +#include "./CodeStorage.h" // @manual #include "./MaskAvx2.h" // @manual #include "./RefImplementations.h" // @manual #include "fbgemm/SimdUtils.h" @@ -55,11 +58,11 @@ class GenSparseAdagrad { int num_vec_regs_per_block, int remainder, int prefetch, - const typename simd_info::vec_reg_t& epsilon_vreg, - const typename simd_info::vec_reg_t& lr_vreg, - const Ymm& mask_vreg, - const typename simd_info::vec_reg_t& temp_vreg, - const typename simd_info::vec_reg_t& weight_decay_vreg, + const x86::Vec& epsilon_vreg, + const x86::Vec& lr_vreg, + const x86::Vec& mask_vreg, + const x86::Vec& temp_vreg, + const x86::Vec& weight_decay_vreg, bool has_weight_decay); void genRowwiseSparseAdagrad( @@ -69,11 +72,11 @@ class GenSparseAdagrad { int num_vec_regs_per_block, int remainder, int prefetch, - const typename simd_info::vec_reg_t& epsilon_vreg, - const typename simd_info::vec_reg_t& lr_vreg, - const Ymm& mask_vreg, - const typename simd_info::vec_reg_t& temp_vreg, - const typename simd_info::vec_reg_t& weight_decay_vreg, + const x86::Vec& epsilon_vreg, + const x86::Vec& lr_vreg, + const x86::Vec& mask_vreg, + const x86::Vec& temp_vreg, + const x86::Vec& weight_decay_vreg, bool has_weight_decay); typename ReturnFunctionSignature::jit_sparse_adagrad_kernel getOrCreate( @@ -84,13 +87,6 @@ class GenSparseAdagrad { bool
has_weight_decay); private: - static asmjit::JitRuntime& runtime() { - static asmjit::JitRuntime rt; // JIT Runtime for asmjit - return rt; - } - - inline static std::mutex rtMutex_; /// Controll access to runtime; - // The hash depends on embedding dimension (block size), prefetch distance, // rowwise, and has_weight_decay inline static CodeCache< @@ -118,11 +114,11 @@ void GenSparseAdagrad::genSparseAdagrad( int num_vec_regs_per_block, int remainder, int prefetch, - const typename simd_info::vec_reg_t& epsilon_vreg, - const typename simd_info::vec_reg_t& lr_vreg, - const Ymm& mask_vreg, - const typename simd_info::vec_reg_t& temp_vreg, - const typename simd_info::vec_reg_t& weight_decay_vreg, + const x86::Vec& epsilon_vreg, + const x86::Vec& lr_vreg, + const x86::Vec& mask_vreg, + const x86::Vec& temp_vreg, + const x86::Vec& weight_decay_vreg, bool has_weight_decay) { // NOTE: temp_vreg is defined only when remainder is true and instSet == avx2 using vec_reg_t = typename simd_info::vec_reg_t; @@ -175,24 +171,24 @@ void GenSparseAdagrad::genSparseAdagrad( a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm()); } else if constexpr (instSet == inst_set_t::avx512) { - a->k(x86::k(1)).vmovups(g_vreg, g_ptr); + a->k(x86::k1).vmovups(g_vreg, g_ptr); if (has_weight_decay) { - a->k(x86::k(1)).vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr); + a->k(x86::k1).vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr); } - a->k(x86::k(1)).vmulps(out_vreg, g_vreg, g_vreg); - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, h_ptr); + a->k(x86::k1).vmulps(out_vreg, g_vreg, g_vreg); + a->k(x86::k1).vaddps(out_vreg, out_vreg, h_ptr); - a->k(x86::k(1)).vmovups(h_ptr, out_vreg); + a->k(x86::k1).vmovups(h_ptr, out_vreg); - a->k(x86::k(1)).vsqrtps(out_vreg, out_vreg); - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, epsilon_vreg); + a->k(x86::k1).vsqrtps(out_vreg, out_vreg); + a->k(x86::k1).vaddps(out_vreg, out_vreg, epsilon_vreg); - a->k(x86::k(1)).vmulps(g_vreg, lr_vreg, g_vreg); - a->k(x86::k(1)).vdivps(out_vreg, g_vreg, out_vreg); + a->k(x86::k1).vmulps(g_vreg, lr_vreg, g_vreg); + a->k(x86::k1).vdivps(out_vreg, g_vreg, out_vreg); - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr); + a->k(x86::k1).vaddps(out_vreg, out_vreg, w_ptr); - a->k(x86::k(1)).vmovups(w_ptr, out_vreg); + a->k(x86::k1).vmovups(w_ptr, out_vreg); } } else { a->vmovups(g_vreg, g_ptr); @@ -226,11 +222,11 @@ void GenSparseAdagrad::genRowwiseSparseAdagrad( int num_vec_regs_per_block, int remainder, int prefetch, - const typename simd_info::vec_reg_t& epsilon_vreg, - const typename simd_info::vec_reg_t& lr_vreg, - const Ymm& mask_vreg, - const typename simd_info::vec_reg_t& temp_vreg, - const typename simd_info::vec_reg_t& weight_decay_vreg, + const x86::Vec& epsilon_vreg, + const x86::Vec& lr_vreg, + const x86::Vec& mask_vreg, + const x86::Vec& temp_vreg, + const x86::Vec& weight_decay_vreg, bool has_weight_decay) { using vec_reg_t = typename simd_info::vec_reg_t; constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; @@ -254,7 +250,7 @@ void GenSparseAdagrad::genRowwiseSparseAdagrad( a->imul( areIndices64b ? 
base_offset : base_offset.r32(), indices_ptr, - static_cast(block_size * sizeof(float))); + block_size * sizeof(float)); } // Even with avx512, we only need to use avx2 registers when computing @@ -264,9 +260,7 @@ int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2; // Use YMM/XMMs with smaller ids for AVX2 specific instructions like vhaddps - Ymm partial_sum_vreg_avx2(0); - Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id()); - + x86::Vec partial_sum_vreg_avx2 = x86::ymm0; a->vxorps( partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2); @@ -313,33 +307,18 @@ a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg); } } - // Reduce sum to 1 value - // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum); - // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2); - a->vhaddps( - partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2); - a->vhaddps( - partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2); - - Xmm partial_sum_xmm1(1); - //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) - a->movss(partial_sum_xmm1, partial_sum_xmm0); - //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1)) - a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1); - - // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) + - // _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1)); - a->addss(partial_sum_xmm0, partial_sum_xmm1); + // Reduce sum to a single value. + x86::Vec tmpXmm = x86::xmm1; + emitReduceAddF32(a, partial_sum_vreg_avx2, tmpXmm); // This fragment moves block size (N) to stack and bcasts it to xmm reg a->lea( x86::rsp, x86::dword_ptr(x86::rsp, -1 * static_cast(sizeof(int32_t)))); a->mov(x86::dword_ptr(x86::rsp), block_size); - a->vbroadcastss( - partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1 - a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1); + a->vbroadcastss(tmpXmm, x86::dword_ptr(x86::rsp)); // tmpXmm now holds N + a->vcvtdq2ps(tmpXmm, tmpXmm); a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t))); if (has_weight_decay) { @@ -347,21 +326,21 @@ a->imul( areIndices64b ? base_offset : base_offset.r32(), indices_ptr, - static_cast(sizeof(float))); + sizeof(float)); } // final_sum /= N - a->divss(partial_sum_xmm0, partial_sum_xmm1); + a->vdivss(partial_sum_vreg_avx2.xmm(), partial_sum_vreg_avx2.xmm(), tmpXmm); // load h - a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset)); + a->vmovss(tmpXmm, x86::dword_ptr(h, base_offset)); // *h + final_sum - a->addss(partial_sum_xmm0, partial_sum_xmm1); + a->vaddss(partial_sum_vreg_avx2.xmm(), partial_sum_vreg_avx2.xmm(), tmpXmm); // store h - a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0); + a->vmovss(x86::dword_ptr(h, base_offset), partial_sum_vreg_avx2.xmm()); // sqrt(hi) - a->sqrtss(partial_sum_xmm0, partial_sum_xmm0); + a->vsqrtss(partial_sum_vreg_avx2.xmm(), partial_sum_vreg_avx2.xmm(), partial_sum_vreg_avx2.xmm()); // bcast partial to all of ymm/zmm reg - a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0); + a->vpbroadcastd(partial_sum_vreg, partial_sum_vreg_avx2.xmm()); // lr / sqrt(hi) + epsilon a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg); a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg); @@ -371,7 +350,7 @@ a->imul( areIndices64b ?
base_offset : base_offset.r32(), indices_ptr, - static_cast(block_size * sizeof(float))); + block_size * sizeof(float)); for (int vec_idx = 0; vec_idx < num_vec_regs_per_block; vec_idx += unroll_factor) { @@ -406,14 +385,14 @@ void GenSparseAdagrad::genRowwiseSparseAdagrad( a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm()); } else { if (has_weight_decay) { - a->k(x86::k(1)).vmovups(out_vreg, g_ptr); - a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr); - a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg); + a->k(x86::k1).vmovups(out_vreg, g_ptr); + a->k(x86::k1).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr); + a->k(x86::k1).vmulps(out_vreg, partial_sum_vreg, out_vreg); } else { - a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr); + a->k(x86::k1).vmulps(out_vreg, partial_sum_vreg, g_ptr); } - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr); - a->k(x86::k(1)).vmovups(w_ptr, out_vreg); + a->k(x86::k1).vaddps(out_vreg, out_vreg, w_ptr); + a->k(x86::k1).vmovups(w_ptr, out_vreg); } } else { if (has_weight_decay) { @@ -444,8 +423,9 @@ GenSparseAdagrad::getOrCreate( kernelSig, [&]() -> typename ReturnFunctionSignature::jit_sparse_adagrad_kernel { + asmjit::JitRuntime& runtime = CodeStorage::getRuntime(); asmjit::CodeHolder code; - code.init(runtime().environment()); + code.init(runtime.environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); constexpr bool areIndices64b = std::is_same_v; @@ -465,28 +445,29 @@ GenSparseAdagrad::getOrCreate( } filename += ".txt"; FILE* codeLogFile = fopen(filename.c_str(), "w"); - auto codeLogger = std::make_unique(codeLogFile); - code.setLogger(codeLogger.get()); + + auto codeLogger = std::make_unique(codeLogFile); + code.set_logger(codeLogger.get()); #endif - auto num_rows = a->zdi().r32(); - x86::Gp param_size = a->zsi(); - w = a->zdx(); - g = a->zcx(); - h = a->gpz(8); - indices = a->gpz(9); - Xmm epsilon(0); - Xmm lr(1); - x86::Gp mask_avx2 = a->gpz(10); - Xmm weight_decay(2); - x86::Gp counter = a->gpz(11); - x86::Gp counter_halflife = a->gpz(12); + auto num_rows = x86::edi; + x86::Gp param_size = x86::rsi; + w = x86::rdx; + g = x86::rcx; + h = x86::r8; + indices = x86::r9; + x86::Vec epsilon = x86::xmm0; + x86::Vec lr = x86::xmm1; + x86::Vec weight_decay = x86::xmm2; + x86::Gp mask_avx2 = x86::r10; + x86::Gp counter = x86::r11; + x86::Gp counter_halflife = x86::r12; // reuse mask_avx2 because mask_avx2 is used only at the beginning - base_offset = a->gpz(10); - temp1_ = a->gpz(13); - temp2_ = a->gpz(14); - temp3_ = a->gpz(15); + base_offset = x86::r10; + temp1_ = x86::r13; + temp2_ = x86::r14; + temp3_ = x86::r15; asmjit::FuncDetail func; func.init( @@ -510,25 +491,25 @@ GenSparseAdagrad::getOrCreate( frame.init(func); if constexpr (instSet == inst_set_t::avx2) { - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); } else { - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | - asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | - asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); + asmjit::Support::bit_mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15) | + 
asmjit::Support::bit_mask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bit_mask(24, 25, 26, 27, 28, 29, 30, 31)); } - frame.setDirtyRegs( + frame.set_dirty_regs( asmjit::RegGroup::kGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::Support::bit_mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsAssignment args(&func); - args.assignAll( + args.assign_all( num_rows, param_size, w, @@ -542,10 +523,10 @@ GenSparseAdagrad::getOrCreate( counter, counter_halflife); - args.updateFuncFrame(frame); + args.update_func_frame(frame); frame.finalize(); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + a->emit_prolog(frame); + a->emit_args_assignment(frame, args); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; constexpr int NUM_VEC_REG = simd_info::NUM_VEC_REGS; @@ -560,7 +541,7 @@ GenSparseAdagrad::getOrCreate( vec_reg_t lr_vreg; vec_reg_t weight_decay_vreg; vec_reg_t adjusted_weight_decay_vreg; - Ymm mask_vreg; // mask for avx2 + x86::Vec mask_vreg; // mask for avx2 vec_reg_t temp_vreg; // temp vreg for avx2 to handle remainder computation @@ -584,18 +565,18 @@ GenSparseAdagrad::getOrCreate( // Creating masks for non multiples of vlen iterations if constexpr (instSet == inst_set_t::avx2) { --unroll_factor; - mask_vreg = Ymm(unroll_factor); + mask_vreg = x86::ymm(unroll_factor); a->vmovups(mask_vreg, x86::dword_ptr(mask_avx2)); } else { a->mov(temp1_, (1 << remainder) - 1); - a->kmovw(x86::k(1), temp1_); + a->kmovw(x86::k1, temp1_); } } // Need an extra mask for computing sum of gradients int remainder_avx2 = block_size % simd_info::WIDTH_32BIT_ELEMS; if (remainder_avx2 && instSet == inst_set_t::avx512 && rowwise) { - reduce_mask_avx512_ = x86::k(2); + reduce_mask_avx512_ = x86::k2; a->mov(temp1_, (1 << remainder_avx2) - 1); a->kmovw(reduce_mask_avx512_, temp1_); } @@ -604,9 +585,9 @@ GenSparseAdagrad::getOrCreate( unroll_factor = unroll_factor / 2; // accont for g_vreg } - asmjit::Label exit = a->newLabel(); - asmjit::Label LoopRangeIndexBegin = a->newLabel(); - asmjit::Label LoopRangeIndexEnd = a->newLabel(); + asmjit::Label exit = a->new_label(); + asmjit::Label LoopRangeIndexBegin = a->new_label(); + asmjit::Label LoopRangeIndexEnd = a->new_label(); a->vpbroadcastd(epsilon_vreg, epsilon); a->vpbroadcastd(lr_vreg, lr); @@ -628,8 +609,7 @@ GenSparseAdagrad::getOrCreate( a->imul( areIndices64b ? base_offset : base_offset.r32(), indices_ptr, - static_cast( - (rowwise ? 1 : block_size) * sizeof(float))); + (rowwise ? 1 : block_size) * sizeof(float)); // Perform this check // if (block_size + offsetIdx > param_size) { @@ -645,7 +625,7 @@ GenSparseAdagrad::getOrCreate( // Check counter != nullptr && counter[idx] > 0 a->vmovaps(adjusted_weight_decay_vreg, weight_decay_vreg); - asmjit::Label skip_adjust_freq = a->newLabel(); + asmjit::Label skip_adjust_freq = a->new_label(); a->cmp(counter, 0); a->je(skip_adjust_freq); @@ -658,10 +638,10 @@ GenSparseAdagrad::getOrCreate( // OK to use Xmm registers with small ids that are reserved for temp // values in the inner most loop. 
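The remainder handling being set up above follows the same two-flavor pattern as in the fused kernel: on AVX2 a YMM lane mask is loaded from the precomputed table behind mask_avx2 (MaskAvx2.h) and used with vmaskmovps, while on AVX-512 the (1 << remainder) - 1 bitmask goes into k1 (and k2 for the gradient-sum reduction) via kmovw. An intrinsics view of both flavors, as a sketch with illustrative names:

#include <immintrin.h>

// AVX2: the sign bit of each mask lane gates vmaskmovps stores.
void masked_store_avx2(float* w, __m256 v, const __m256i* mask_row) {
  __m256i m = _mm256_loadu_si256(mask_row); // row of the precomputed mask table
  _mm256_maskstore_ps(w, m, v);             // vmaskmovps (store form)
}

// AVX-512: a k-register predicate gates the whole load/compute/store sequence.
void masked_update_avx512(float* w, const float* g, __m512 step, int remainder) {
  __mmask16 k = static_cast<__mmask16>((1u << remainder) - 1); // kmovw pattern
  __m512 gv = _mm512_maskz_loadu_ps(k, g);
  __m512 wv = _mm512_maskz_loadu_ps(k, w);
  __m512 out = _mm512_add_ps(_mm512_mul_ps(step, gv), wv);     // w += step * g
  _mm512_mask_storeu_ps(w, k, out);
}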
vec_reg_t counter_halflife_vreg(0); - Xmm counter_vreg(1); - a->cvtsi2sd(counter_halflife_vreg.xmm(), counter_halflife); - a->movq(counter_vreg, temp3_); - a->divpd(counter_halflife_vreg.xmm(), counter_vreg); + x86::Vec counter_vreg = x86::xmm1; + a->vcvtsi2sd(counter_halflife_vreg.xmm(), counter_halflife_vreg.xmm(), counter_halflife); + a->vmovq(counter_vreg, temp3_); + a->vdivpd(counter_halflife_vreg.xmm(), counter_halflife_vreg.xmm(), counter_vreg); a->vcvtpd2ps( counter_halflife_vreg.xmm(), counter_halflife_vreg.ymm()); a->vbroadcastss(counter_halflife_vreg, counter_halflife_vreg.xmm()); @@ -674,15 +654,13 @@ GenSparseAdagrad::getOrCreate( } a->inc(temp2_); - a->imul( - temp2_, - static_cast(block_size)); //(offsetIdx+1)*blocksize + a->imul(temp2_, block_size); //(offsetIdx+1)*blocksize a->cmp(temp2_, param_size); a->jg(exit); if (prefetch) { - asmjit::Label pref_dist_reset_start = a->newLabel(); - asmjit::Label pref_dist_reset_end = a->newLabel(); + asmjit::Label pref_dist_reset_start = a->new_label(); + asmjit::Label pref_dist_reset_end = a->new_label(); a->mov(temp2_, temp1_); a->add(temp2_, prefetch); @@ -696,12 +674,12 @@ GenSparseAdagrad::getOrCreate( a->imul( areIndices64b ? temp3_ : temp3_.r32(), pref_indices_ptr, - static_cast(sizeof(float))); + sizeof(float)); } a->imul( areIndices64b ? temp2_ : temp2_.r32(), pref_indices_ptr, - static_cast(block_size * sizeof(float))); + block_size * sizeof(float)); a->jmp(pref_dist_reset_end); @@ -709,12 +687,12 @@ GenSparseAdagrad::getOrCreate( a->imul( areIndices64b ? temp2_ : temp2_.r32(), indices_ptr, - static_cast(block_size * sizeof(float))); + block_size * sizeof(float)); if (rowwise) { a->imul( areIndices64b ? temp3_ : temp3_.r32(), indices_ptr, - static_cast(sizeof(float))); + sizeof(float)); } a->bind(pref_dist_reset_end); @@ -749,30 +727,24 @@ GenSparseAdagrad::getOrCreate( has_weight_decay); } - a->add(g, static_cast(block_size * sizeof(float))); + a->add(g, block_size * sizeof(float)); a->inc(temp1_); a->jmp(LoopRangeIndexBegin); a->bind(LoopRangeIndexEnd); a->bind(exit); a->mov(x86::eax, temp1_.r32()); - a->emitEpilog(frame); + a->emit_epilog(frame); typename ReturnFunctionSignature::jit_sparse_adagrad_kernel fn; - asmjit::Error err = 0; - { - std::unique_lock lock(rtMutex_); - err = runtime().add(&fn, &code); - } - if (err) { + asmjit::Error err = runtime.add(&fn, &code); + + if (err != asmjit::Error::kOk) { std::cout << "Error: in fn add" << '\n'; return nullptr; } -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogFile); -#endif return fn; }); } // getOrCreate
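A closing note on the theme running through all of these files: every per-generator static JitRuntime and its companion rtMutex_ is deleted in favor of a single runtime obtained from CodeStorage::getRuntime(), defined in the new src/CodeStorage.cc, whose body is not shown in these hunks. A sketch of an accessor consistent with the call sites; treat the details as an assumption, since the real definition lives elsewhere in the patch:

#include <asmjit/core.h>

namespace fbgemm {
namespace CodeStorage {

// One process-wide JitRuntime shared by every kernel generator.
// A function-local static gives thread-safe on-first-use construction
// and sidesteps the static-initialization-order issues the old
// per-class statics worked around.
inline asmjit::JitRuntime& getRuntime() {
  static asmjit::JitRuntime rt;
  return rt;
}

} // namespace CodeStorage
} // namespace fbgemm

The error check around runtime.add(&fn, &code) changes shape accordingly: the result is compared against asmjit::Error::kOk rather than truth-tested, matching asmjit's move to a scoped Error enum, and the explicit lock is gone at every call site.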