Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ endfunction()
project(
fbgemm
VERSION 1.4.0
LANGUAGES CXX)
LANGUAGES C CXX)

# Add C++ compiler flag detection
include(CheckCXXCompilerFlag)
Expand Down
1 change: 1 addition & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_fbgemm_base_srcs():

def get_fbgemm_generic_srcs(with_base = False, msvc = False, buck = False):
sources = [
"src/CodeStorage.cc",
"src/EmbeddingSpMDM.cc",
"src/EmbeddingSpMDMNBit.cc",
"src/ExecuteKernel.cc",
Expand Down
2 changes: 1 addition & 1 deletion external/asmjit
Submodule asmjit updated 300 files
240 changes: 123 additions & 117 deletions src/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,89 @@ namespace fbgemm {

namespace x86 = asmjit::x86;

class FileLoggerWithClose : public asmjit::FileLogger {
explicit FileLoggerWithClose(FILE* f)
: FileLogger(f) {}

~FileLoggerWithClose() noexcept override {
if (_file) {
fclose(_file);
}
}
};

/**
* @brief Emit an instruction that fills a vector register with
* all ones.
*
* @param dest Destination register to fill
*/
template <inst_set_t instSet>
void emitVecFillWithOnes(x86::Emitter* a, const x86::Vec& dest) {
if constexpr (instSet == inst_set_t::avx2) {
a->vpcmpeqb(dest, dest, dest);
} else {
a->vpternlogd(dest, dest, dest, 0xff);
}
}

/**
* @brief Create instruction sequence to generate 8-bit 1s
*
* @param dest Once the instruction sequence is executed,
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <inst_set_t instSet>
void gen8BitVectorOne(x86::Emitter* a, const x86::Vec& dest) {
emitVecFillWithOnes<instSet>(a, dest);
a->vpabsb(dest, dest);
}

/**
* @brief Create instruction sequence to generate 16-bit 1s
* @tparam T Register type of destination, e.g., Ymm or Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:15] will have 0x0001, dest[16:31]
* will have 0x0001 and so on
*/
template <
inst_set_t instSet,
typename T,
std::enable_if_t<instSet == inst_set_t::avx2, int> = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpsrlw(dest, dest, 15);
}

template <
inst_set_t instSet,
typename T,
std::enable_if_t<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int> = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
template <inst_set_t instSet>
void gen16BitVectorOne(x86::Emitter* a, const x86::Vec& dest) {
emitVecFillWithOnes<instSet>(a, dest);
a->vpsrlw(dest, dest, 15);
}

/**
* @brief Emit instruction do load 32-bit integer. AVX512 has
* different instrunction to load registers with index >= 16
* @tparam T Register type of destination, e.g., Ymm or Zmm
* different instruction to load registers with index >= 16
*
* @param dest Destination vector register
*/
template <
inst_set_t instSet,
typename T,
std::enable_if_t<instSet == inst_set_t::avx2, int> = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
template <inst_set_t instSet, typename Dest, typename Src>
void emitVecMove(x86::Emitter* a, const Dest& dest, const Src& src) {
if constexpr (instSet == inst_set_t::avx2) {
a->vmovdqa(dest, src);
} else {
a->vmovdqa32(dest, src);
}
}

template <
inst_set_t instSet,
typename T,
std::enable_if_t<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int> = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
template <inst_set_t instSet, typename Dest, typename Src1, typename Src2>
void emitVecOr(x86::Emitter* a, const Dest& dest, const Src1& src1, const Src2& src2) {
if constexpr (instSet == inst_set_t::avx2) {
a->vpor(dest, src1, src2);
} else {
a->vpord(dest, src1, src2);
}
}

template <inst_set_t instSet, typename Dest, typename Src1, typename Src2>
void emitVecXor(x86::Emitter* a, const Dest& dest, const Src1& src1, const Src2& src2) {
if constexpr (instSet == inst_set_t::avx2) {
a->vpxor(dest, src1, src2);
} else {
a->vpxord(dest, src1, src2);
}
}

/**
Expand All @@ -81,68 +109,22 @@ void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
* @param vec Source (full) vector register
* @param idx Index of of the half vector 0 or 1
*/
template <
inst_set_t instSet,
typename T,
std::enable_if_t<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int> = 0>
void emitExtractHalfVector(
x86::Emitter* a,
const Ymm& half,
const Zmm& vec,
int idx) {
a->vextracti32x8(half, vec, idx);
}

template <
inst_set_t instSet,
typename T,
std::enable_if_t<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int> = 0>
template <inst_set_t instSet>
void emitExtractHalfVector(
x86::Emitter* a,
const Xmm& half,
const Ymm& vec,
const x86::Vec& half,
const x86::Vec& vec,
int idx) {
a->vextracti32x4(half, vec, idx);
}

template <
inst_set_t instSet,
typename T,
std::enable_if_t<instSet == inst_set_t::avx2, int> = 0>
void emitExtractHalfVector(
x86::Emitter* a,
const Xmm& half,
const Ymm& vec,
int idx) {
a->vextracti128(half, vec, idx);
}

/**
* @brief Create instruction sequence to generate 8-bit 1s
* @tparam T Register type of destination, e.g., Ymm or Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <typename T, std::enable_if_t<std::is_same_v<T, Ymm>, int> = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpabsb(dest, dest);
}

template <typename T, std::enable_if_t<std::is_same_v<T, Zmm>, int> = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
a->vpabsb(dest, dest);
if constexpr (instSet == inst_set_t::avx2) {
a->vextracti128(half, vec, idx);
} else {
if (vec.is_vec512()) {
a->vextracti32x8(half, vec, idx);
} else {
a->vextracti32x4(half, vec, idx);
}
}
}

/**
Expand All @@ -160,11 +142,11 @@ template <
int> = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
const x86::Vec& aReg,
const x86::Vec& bReg,
const x86::Vec& cReg,
const x86::Vec& oneReg16Bit,
const x86::Vec& tmpReg) {
a->vpmaddubsw(tmpReg, aReg, bReg);
a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
a->vpaddd(cReg, tmpReg, cReg);
Expand All @@ -175,11 +157,11 @@ template <
std::enable_if_t<INST_SET == inst_set_t::avx512_vnni, int> = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
typename simd_info<INST_SET>::vec_reg_t /*tmpReg*/) {
const x86::Vec& aReg,
const x86::Vec& bReg,
const x86::Vec& cReg,
const x86::Vec& /*oneReg16Bit*/,
const x86::Vec& /*tmpReg*/) {
a->vpdpbusd(cReg, aReg, bReg);
}

Expand All @@ -200,11 +182,11 @@ template <
int> = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
const x86::Vec& src,
const x86::Vec& dest,
const x86::Vec& oneReg16Bit,
const x86::Vec& tmpReg) {
gen8BitVectorOne<INST_SET>(a, tmpReg);
a->vpmaddubsw(tmpReg, src, tmpReg);
a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
a->vpaddd(dest, tmpReg, dest);
Expand All @@ -220,11 +202,11 @@ template <
std::enable_if_t<INST_SET == inst_set_t::avx512_vnni, int> = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t /*oneReg16Bit*/,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
const x86::Vec& src,
const x86::Vec& dest,
const x86::Vec& /*oneReg16Bit*/,
const x86::Vec& tmpReg) {
gen8BitVectorOne<INST_SET>(a, tmpReg);
a->vpdpbusd(dest, src, tmpReg);
}

Expand Down Expand Up @@ -263,8 +245,32 @@ template <typename T>
void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
// move src to dest
auto xmm = dest.xmm();
a->movq(xmm, src);
a->vmovq(xmm, src);
a->vpbroadcastb(dest, xmm);
}

template <inst_set_t instSet>
void emitReduceAddF32(x86::Emitter* a, x86::Vec v, const x86::Vec& tmp) {
if constexpr (instSet != inst_set_t::avx2) {
if (v.is_vec512()) {
a->vextractf32x8(tmp.ymm(), v, 1);
v = v.ymm();
a->vaddps(v.ymm(), v, tmp);
}
}

if (v.is_vec256()) {
if constexpr (instSet == inst_set_t::avx2) {
a->vextractf128(tmp.xmm(), v, 1);
} else {
a->vextractf32x4(tmp.xmm(), v, 1);
}
v = v.xmm();
a->vaddps(v, v, tmp.xmm());
}

a->vhaddps(v, v, v);
a->vhaddps(v, v, v);
}

} // namespace fbgemm
23 changes: 23 additions & 0 deletions src/CodeStorage.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "./CodeStorage.h"

namespace fbgemm {
namespace CodeStorage {

asmjit::JitRuntime& getRuntime() {
// JIT Runtime for asmjit, depends on other static variables.
// Required to prevent initialization order fiasco
static asmjit::JitRuntime rt;

return rt;
}

} // namespace CodeStorage
} // namespace fbgemm
19 changes: 19 additions & 0 deletions src/CodeStorage.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <asmjit/core.h> // @manual

namespace fbgemm {
namespace CodeStorage {

asmjit::JitRuntime& getRuntime();

} // namespace CodeStorage
} // namespace fbgemm
11 changes: 0 additions & 11 deletions src/DirectConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class DirectConvCodeGenBase {
int o1Xoc,
int i1);

inline static std::mutex rtMutex_; ///< Control access to runtime;

// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
inline static CodeCache<
std::tuple<bool, int, int, int, int, int, int>,
Expand Down Expand Up @@ -194,15 +192,6 @@ class DirectConvCodeGenBase {
const x86::Gp& C_offset,
int rowRegs,
int colRegs);

private:
static asmjit::JitRuntime& runtime() {
static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
// depents on other static
// variables. Required to prevent
// initialization order fiasco
return rt;
}
};

} // namespace fbgemm
Loading