Skip to content

Commit 573a619

Browse files
authored
Transpose matmul kernel (#278)
* vnni kernel comment update * transposed matmul kernel * fix heap-buffer-overflow * dense_matmul => transpose_matmul * add benchmark * add ci * update ci cases * use preamble/postamble * make gcc happy * chore * use default value of struc def * make cpplint happy * more benchmark cases * avoid memleak of benchmark * make cpplint happy * set default value of raw pointers * typo * fix copyright
1 parent fbbc221 commit 573a619

39 files changed

+1499
-37
lines changed

nlp_toolkit/backends/neural_engine/SparseLib/include/interface.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ class sparse_matmul_desc : public kernel_desc_proxy {
101101
virtual ~sparse_matmul_desc() {}
102102
};
103103

104+
class transpose_matmul_desc : public kernel_desc_proxy {
105+
public:
106+
transpose_matmul_desc() {}
107+
explicit transpose_matmul_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
108+
virtual ~transpose_matmul_desc() {}
109+
};
110+
104111
class postop_desc : public kernel_desc_proxy {
105112
public:
106113
postop_desc() {}
@@ -131,6 +138,14 @@ class sparse_matmul : public kernel_proxy {
131138
explicit sparse_matmul(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
132139
virtual ~sparse_matmul() {}
133140
};
141+
142+
class transpose_matmul : public kernel_proxy {
143+
public:
144+
transpose_matmul() {}
145+
explicit transpose_matmul(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
146+
virtual ~transpose_matmul() {}
147+
};
148+
134149
class postop : public kernel_proxy {
135150
public:
136151
postop() {}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright (c) 2022 Intel Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_MATMUL_AVX512F_P2031_P2013_HPP_
16+
#define ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_MATMUL_AVX512F_P2031_P2013_HPP_
17+
18+
#include "jit_generator.hpp"
19+
#include "kernels/matmul_types.hpp"
20+
#include "utils.hpp"
21+
22+
#define GET_OFF(field) offsetof(ssd::matmul_data_t, field)
23+
24+
namespace jd {
25+
/**
26+
 * @brief jit_matmul_avx512f_p2031_p2013_t calculates this kind of matmul: alpha * src0 x src1 + beta * src2 = dst.
27+
 *        alpha * src0(M, K) x src1(K, N) + beta * src2(M, N) = dst(M, N)
28+
*/
29+
class jit_matmul_avx512f_p2031_p2013_t : public jit_generator {
30+
public:
31+
explicit jit_matmul_avx512f_p2031_p2013_t(const ssd::matmul_param_t& param)
32+
: jit_generator(),
33+
param_(param),
34+
TH_(param.m_tile),
35+
TW_(param.n_tile),
36+
ld_src0(param.M * param.batch * dsize_src0),
37+
ld_src1(param.N * param.batch * dsize_src1),
38+
ld_src2(param.N * dsize_src2),
39+
ld_dst(param.N * dsize_dst),
40+
k_iters(param.K / UNROLL_K) {}
41+
virtual ~jit_matmul_avx512f_p2031_p2013_t() {}
42+
43+
private:
44+
ssd::matmul_param_t param_;
45+
46+
void generate() override;
47+
void calc_THxkxTW();
48+
Xbyak::Zmm TH_Vmm(int i = 0); // Register allocator of load weight. 1D shape=(TH)
49+
Xbyak::Zmm TW_Vmm(int i = 0); // Register allocator of load activation. 1D shape=(TW)
50+
Xbyak::Zmm dst_tile_Vmm(int i = 0, int j = 0); // Reg alloc of DST tile. 2D shape=(TH,TW), stride=(TW,1)
51+
52+
const int TH_; // tile height (along m) in terms of #registers
53+
const int TW_; // tile width (along n) in terms of #registers
54+
static constexpr size_t dsize_src0 = sizeof(decltype(*ssd::matmul_data_t::src0));
55+
static constexpr size_t dsize_src1 = sizeof(decltype(*ssd::matmul_data_t::src1));
56+
static constexpr size_t dsize_src2 = sizeof(decltype(*ssd::matmul_data_t::src2));
57+
static constexpr size_t dsize_dst = sizeof(decltype(*ssd::matmul_data_t::dst));
58+
// leading dimension in #bytes
59+
const int ld_src0, ld_src1, ld_src2, ld_dst;
60+
const int k_iters;
61+
62+
const Xbyak::Zmm& vreg_temp = zmm31;
63+
static constexpr int VREG_NUMS = 32;
64+
static constexpr int USED_VREGS = 1;
65+
static constexpr int UNROLL_K = 8;
66+
static constexpr int BYTES_ZMM = 64;
67+
68+
const Xbyak::Reg64& parambase = rdi;
69+
const Xbyak::Reg64& reg_src0 = rsi;
70+
const Xbyak::Reg64& reg_src1 = rdx;
71+
const Xbyak::Reg64& reg_src2 = rcx;
72+
const Xbyak::Reg64& reg_dst = r8;
73+
const Xbyak::Reg64& reg_src0_end = r9;
74+
const Xbyak::Reg64& reg_src1_end = r10;
75+
const Xbyak::Reg64& reg_iterk = r11;
76+
const Xbyak::Reg64& reg_tmp = rbx;
77+
};
78+
} // namespace jd
79+
#endif  // ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_MATMUL_AVX512F_P2031_P2013_HPP_

nlp_toolkit/backends/neural_engine/SparseLib/include/kernel_hashing.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,16 @@ class hash_t {
8080
hash_combine(seed, op_attrs["append_sum"]);
8181
hash_combine(seed, op_attrs["sub_func"]);
8282
break;
83-
// todo:remove it.
8483
case kernel_kind::postop:
8584
case kernel_kind::eltwiseop:
8685
case kernel_kind::layernorm_ba:
8786
hash_combine(seed, op_attrs["matrix_shape"]);
8887
break;
88+
case kernel_kind::transpose_matmul:
89+
hash_combine(seed, op_attrs["alpha"]);
90+
hash_combine(seed, op_attrs["beta"]);
91+
hash_combine(seed, op_attrs["m_tile"]);
92+
hash_combine(seed, op_attrs["n_tile"]);
8993
default:
9094
break;
9195
}

nlp_toolkit/backends/neural_engine/SparseLib/include/kernels/eltwiseop.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class eltwiseop_k_t : public kernel_t {
8888
bool eltwiseop_kernel_create(jit_eltwiseop_t** ker_pp, const ssd::eltwiseop_param_t& param);
8989

9090
private:
91-
jit_eltwiseop_t* jit_kers_;
91+
jit_eltwiseop_t* jit_kers_ = nullptr;
9292
int64_t nthr_;
9393
std::vector<ssd::eltwiseop_data_t*> td;
9494
};
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright (c) 2022 Intel Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_AVX512F_P2031_P2013_HPP_
16+
#define ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_AVX512F_P2031_P2013_HPP_
17+
18+
#include <glog/logging.h>
19+
#include <memory>
20+
#include <vector>
21+
#include <algorithm>
22+
23+
#include "jit_domain/jit_matmul_avx512f_p2031_p2013.hpp"
24+
#include "kernel.hpp"
25+
#include "kernel_desc.hpp"
26+
27+
namespace jd {
28+
// By convention,
29+
// 1. xxxx_kd_t is the descriptor of a specific derived primitive/kernel.
30+
// 2. xxxx_k_t is a specific derived primitive/kernel.
31+
// 3. jit_xxxx_t is JIT assembly implementation of a specific derived
32+
// primitive/kernel. where, "xxxx" represents an algorithm, such as brgemm,
33+
// GEMM and so on.
34+
class matmul_avx512f_p2031_p2013_k_t;
35+
36+
/**
37+
* @brief a derived kernel descriptor. matmul_param_t is its class member.
38+
*/
39+
class matmul_avx512f_p2031_p2013_kd_t : public kernel_desc_t {
40+
public:
41+
explicit matmul_avx512f_p2031_p2013_kd_t(const jd::operator_desc& op_desc)
42+
: kernel_desc_t(kernel_kind::sparse_matmul), op_desc_(op_desc) {}
43+
virtual ~matmul_avx512f_p2031_p2013_kd_t() {}
44+
45+
bool init() override;
46+
47+
// kernel_desc_t::create_primitive() override.
48+
DECLARE_COMMON_PD_T(matmul_avx512f_p2031_p2013_k_t, matmul_avx512f_p2031_p2013_kd_t);
49+
50+
const jd::operator_desc& operator_desc() const override { return op_desc_; }
51+
const ssd::matmul_param_t& jit_param() const { return jit_param_; }
52+
53+
inline std::vector<dim_t> shape() const {
54+
std::vector<dim_t> result(op_desc_.tensor_descs()[ssd::SRC0].shape());
55+
result.push_back(op_desc_.tensor_descs()[ssd::SRC0].shape().back());
56+
return result;
57+
}
58+
59+
private:
60+
bool matmul_params_init(const jd::operator_desc& op_desc);
61+
62+
jd::operator_desc op_desc_;
63+
ssd::matmul_param_t jit_param_;
64+
};
65+
66+
/**
67+
* @brief a derived kernel. kd_t and jit_domain are its class members.
68+
*/
69+
class matmul_avx512f_p2031_p2013_k_t : public kernel_t {
70+
public:
71+
using kd_t = matmul_avx512f_p2031_p2013_kd_t;
72+
explicit matmul_avx512f_p2031_p2013_k_t(const std::shared_ptr<const kd_t>& kd);
73+
virtual ~matmul_avx512f_p2031_p2013_k_t() {
74+
if (jit_ker_ != nullptr) {
75+
delete jit_ker_;
76+
jit_ker_ = nullptr;
77+
}
78+
}
79+
80+
// Delete move constructor and move operator
81+
matmul_avx512f_p2031_p2013_k_t(matmul_avx512f_p2031_p2013_k_t&& other) = delete;
82+
matmul_avx512f_p2031_p2013_k_t& operator=(matmul_avx512f_p2031_p2013_k_t&& other) = delete;
83+
// Delete copy constructor and copy operator
84+
matmul_avx512f_p2031_p2013_k_t(const matmul_avx512f_p2031_p2013_k_t& other) = delete;
85+
matmul_avx512f_p2031_p2013_k_t& operator=(const matmul_avx512f_p2031_p2013_k_t& other) = delete;
86+
87+
bool init() override;
88+
bool execute(const std::vector<const void*>& rt_data) const override;
89+
const std::shared_ptr<const kd_t> derived_kd() const { return std::static_pointer_cast<const kd_t>(kd_); }
90+
91+
private:
92+
bool matmul_kernel_create(jit_matmul_avx512f_p2031_p2013_t** ker_pp, const ssd::matmul_param_t& param);
93+
94+
private:
95+
jit_matmul_avx512f_p2031_p2013_t* jit_ker_ = nullptr;
96+
const std::vector<std::vector<dim_t>> t_shapes_;
97+
const std::vector<dim_t> src0_perm_shape_; // src0 shape after perm2031
98+
const std::vector<dim_t> src1_perm_shape_; // src1 shape after perm2013
99+
const dim_t M_, K_, N_; // dim of matrix multiplication
100+
const dim_t bs0_; // outer batch size dim
101+
const dim_t bs1_;  // inner batch size dim
102+
};
103+
104+
} // namespace jd
105+
106+
#endif // ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_AVX512F_P2031_P2013_HPP_
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2022 Intel Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_TYPES_HPP_
16+
#define ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_TYPES_HPP_
17+
18+
#include <cstdint>
19+
#include <vector>
20+
21+
#include "param_types.hpp"
22+
#include "../utils.hpp"
23+
namespace jd {
24+
namespace ssd {
25+
/**
26+
* @brief tensors index configuration of this kernel.
27+
* TODO(Yi): potential confliction with indices of other op types
28+
*/
29+
static constexpr int SRC0 = 0;
30+
static constexpr int SRC1 = 1;
31+
static constexpr int DST0 = 2;
32+
static constexpr int SRC2 = 3; // for binary add
33+
34+
struct matmul_param_t {
35+
dim_t M;
36+
dim_t N;
37+
dim_t K;
38+
dim_t batch; // leading dim is `batch` times of its num_cols
39+
float alpha = 1.f, beta = 1.f; // alpha * (src0 * src1) + beta * src_binary_add = dst
40+
dim_t m_tile = 8;
41+
dim_t n_tile = 2;
42+
};
43+
44+
struct matmul_data_t {
45+
const float* src0;
46+
const float* src1;
47+
float* dst;
48+
const float* src2;
49+
};
50+
51+
} // namespace ssd
52+
} // namespace jd
53+
#endif // ENGINE_SPARSELIB_INCLUDE_KERNELS_MATMUL_TYPES_HPP_

nlp_toolkit/backends/neural_engine/SparseLib/include/kernels/sparse_data.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <utility>
1919
#include <vector>
2020

21-
#include "kernels/spmm_types.hpp"
2221
#include "param_types.hpp"
2322
#include "utils.hpp"
2423

nlp_toolkit/backends/neural_engine/SparseLib/include/kernels/spmm_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class bsr_data_t;
3030
namespace ssd {
3131
/**
3232
* @brief tensors index configuration of this kernel.
33+
 * TODO(Yi): potential conflict with indices of other op types
3334
*/
3435
static constexpr int WEI = 0;
3536
static constexpr int SRC = 1;
@@ -52,7 +53,7 @@ enum class subfunc_level : uint8_t {
5253
prod, // use sub-function for tile product
5354
dense_and_prod, // use fused sub-function for dense loading & tile product
5455
load_and_prod, // use fused sub-function for dense loading & sparse loading & tile product
55-
k_dims,  // a whole THxKxTW tile generates a constant size of code
56+
k_dims,            // a whole THxKxTW tile generates a constant size of code
5657
subfunc_level_MAX = k_dims
5758
};
5859

nlp_toolkit/backends/neural_engine/SparseLib/include/param_types.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include <map>
2121
namespace jd {
2222
// The main kinds of kernel.
23-
enum class kernel_kind : uint8_t { undef, sparse_matmul, postop, eltwiseop, layernorm_ba };
23+
enum class kernel_kind : uint8_t { undef, sparse_matmul, postop, eltwiseop, layernorm_ba, transpose_matmul };
2424

2525
enum class postop_alg : uint8_t { undef, exp, tanh, gelu, relu, quantize, dequantize, linear, int8_lut };
2626

nlp_toolkit/backends/neural_engine/SparseLib/src/cpu_engine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ const std::vector<impl_list_item_t>* cpu_engine::get_implementation_list(const o
2525
DECLARE_IMPL_LIST(postop);
2626
DECLARE_IMPL_LIST(eltwiseop);
2727
DECLARE_IMPL_LIST(layernorm_ba);
28+
DECLARE_IMPL_LIST(transpose_matmul);
2829

2930
#undef DECLARE_IMPL_LIST
3031

@@ -38,6 +39,7 @@ const std::vector<impl_list_item_t>* cpu_engine::get_implementation_list(const o
3839
CASE(postop);
3940
CASE(eltwiseop);
4041
CASE(layernorm_ba);
42+
CASE(transpose_matmul);
4143
default:
4244
return &cpu_engine::empty_list;
4345
}

0 commit comments

Comments
 (0)