Skip to content

Commit 96e3514

Browse files
Jira-61/support eltwiseop-fusion (#150)
* eltwiseop kernel with eltwise_injector, support exp, tanh, gelu, relu, quantize(fp32->u8) and dequantize(u8->fp32)
1 parent ae95707 commit 96e3514

19 files changed

+1793
-55
lines changed

nlp_toolkit/backends/neural_engine/SparseLib/include/interface.hpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,15 @@ namespace jd {
3434
* @brief Proxy pattern. The proxy could interface to anything.
3535
* Similar to onednn's "struct handle". oneapi/dnnl/dnnl.hpp:136.
3636
*/
37-
template<typename T, typename arg_t = void>
37+
template <typename T, typename arg_t = void>
3838
class proxy_base {
3939
public:
4040
proxy_base() {}
4141
virtual ~proxy_base() {}
4242

4343
public:
44-
inline void reset_sp(const std::shared_ptr<const T>& sp) {
45-
data_handle_ = sp;
46-
}
47-
inline const std::shared_ptr<const T>& get_sp() const {
48-
return data_handle_;
49-
}
44+
inline void reset_sp(const std::shared_ptr<const T>& sp) { data_handle_ = sp; }
45+
inline const std::shared_ptr<const T>& get_sp() const { return data_handle_; }
5046

5147
protected:
5248
// Internal functions for creating the proxy object.
@@ -86,14 +82,13 @@ class kernel_proxy : public proxy_base<kernel_t, std::shared_ptr<const kernel_de
8682

8783
protected:
8884
bool create_proxy_object(std::shared_ptr<const kernel_t>& result_ref,
89-
const std::shared_ptr<const kernel_desc_t>& kd) override;
85+
const std::shared_ptr<const kernel_desc_t>& kd) override;
9086

9187
public:
9288
inline const jd::kernel_kind& kernel_kind() const { return get_sp()->kd()->kernel_kind(); }
9389
void execute(const std::vector<const void*>& rt_data);
9490
};
9591

96-
9792
//// The following paragraphs are the various derived kernels and its descriptors.
9893
/**
9994
* @brief Derived proxy class, interfacing to the real/cached sparse_matmul_desc_t.
@@ -111,6 +106,14 @@ class postop_desc : public kernel_desc_proxy {
111106
explicit postop_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
112107
virtual ~postop_desc() {}
113108
};
109+
110+
class eltwiseop_desc : public kernel_desc_proxy {
111+
public:
112+
eltwiseop_desc(){};
113+
explicit eltwiseop_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
114+
virtual ~eltwiseop_desc() {}
115+
};
116+
114117
/**
115118
* @brief Derived proxy class, interfacing to the real/cached sparse_matmul_t.
116119
*/
@@ -126,5 +129,13 @@ class postop : public kernel_proxy {
126129
explicit postop(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
127130
virtual ~postop() {}
128131
};
132+
133+
class eltwiseop : public kernel_proxy {
134+
public:
135+
eltwiseop() {}
136+
explicit eltwiseop(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
137+
virtual ~eltwiseop() {}
138+
};
139+
129140
} // namespace jd
130141
#endif // ENGINE_SPARSELIB_INCLUDE_INTERFACE_HPP_
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (c) 2022 Intel Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISE_INJECTOR_HPP_
16+
#define ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISE_INJECTOR_HPP_
17+
18+
#include "jit_generator.hpp"
19+
#include "utils.hpp"
20+
#include "param_types.hpp"
21+
#include <map>
22+
#include <set>
23+
24+
namespace jd {
25+
class jit_eltwise_injector {
26+
using Zmm = Xbyak::Zmm;
27+
using Ymm = Xbyak::Ymm;
28+
using Xmm = Xbyak::Xmm;
29+
30+
public:
31+
explicit jit_eltwise_injector(){};
32+
virtual ~jit_eltwise_injector() {}
33+
34+
void eltwise_injector_init(jit_generator* ptr, const std::vector<postop_attr>& postop_attrs);
35+
void vector_compute(const Xbyak::Zmm& zmm_src, const std::vector<postop_attr>& postop_attrs);
36+
void escape_regs(reg_type type, int reg_idx);
37+
void escape_erase(reg_type type, int reg_idx = -1);
38+
void prepare_table();
39+
40+
private:
41+
void assign_regs();
42+
void exp_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
43+
void tanh_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
44+
void gelu_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
45+
void relu_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
46+
void quantize_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
47+
void dequantize_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
48+
void register_table_entries(const std::vector<postop_attr>& postop_attrs);
49+
void assert_check(const std::vector<postop_attr>& postop_attrs);
50+
void load_table_addr() { h->mov(p_table, l_table); };
51+
52+
private:
53+
postop_attr cur_postop_attr_;
54+
int cur_iter_idx_; // for alpha,beta,scale lookup.
55+
jit_generator* h;
56+
std::unordered_map<reg_type, std::set<int>> used_regs;
57+
58+
/*labels*/
59+
Xbyak::Label l_table;
60+
61+
/*register for fwd*/
62+
Xbyak::Reg64 p_table;
63+
Xbyak::Reg64 reg64_tmp;
64+
65+
Zmm zmm_mask, zmm_aux0, zmm_aux1, zmm_aux2, zmm_aux3, zmm_aux4, zmm_tmp;
66+
Ymm ymm_tmp;
67+
Xmm xmm_tmp;
68+
Xbyak::Opmask k_mask;
69+
static constexpr int n_mantissa_bits = 23;
70+
static constexpr int max_mask_idx = 7;
71+
static constexpr int max_zmm_idx = 31;
72+
static constexpr int max_reg64_idx = 15;
73+
74+
enum {
75+
_cmp_eq_oq = 0u,
76+
_cmp_lt_os = 1u,
77+
_cmp_le_os = 2u,
78+
_cmp_neq_uq = 4u,
79+
_cmp_nlt_us = 5u,
80+
_cmp_nle_us = 6u,
81+
82+
_op_floor = 1u,
83+
_op_mxcsr = 4u,
84+
};
85+
86+
enum key_t {
87+
scale = 0, // scale argument
88+
alpha, // alpha argument
89+
beta, // beta argument
90+
zero, // 0.f
91+
half, // 0.5f
92+
one, // 1.f or mask for exponent bits
93+
two, // 2.f
94+
three, // 3.f
95+
six, // 6.f
96+
minus_one, // -1.f or changes sign to opposite
97+
minus_two, // -2.f
98+
minus_three, // -3.f
99+
ln2f, // 0.69314718f
100+
positive_mask, // changes sign to positive
101+
sign_mask, // gets sign value
102+
exponent_bias, // (127 = 2^7 - 1), gets exponent bits
103+
exp_log2ef, // 1.44269502f - formula-based for approx
104+
exp_ln_flt_max_f, // logf(FLT_MAX) - max normal value
105+
exp_ln_flt_min_f, // logf(FLT_MIN) - min normal value
106+
exp_pol, // see correspondent table for float values
107+
gelu_tanh_fitting_const, // 0.044715f
108+
gelu_tanh_fitting_const_times_three, // 0.134145f
109+
gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f
110+
gelu_tanh_flt_max_x,
111+
gelu_tanh_flt_min_x,
112+
tanh_idx_bias,
113+
tanh_idx_mask,
114+
tanh_linear_ubound,
115+
tanh_saturation_lbound,
116+
tanh_pol_table,
117+
exchange_zmm_low256_high256,
118+
undef_key,
119+
};
120+
121+
size_t table_off(key_t key, size_t key_off_val_shift = 0);
122+
Xbyak::Address table_val(key_t key, size_t key_off_val_shift = 0);
123+
using table_entry_val_t = uint32_t;
124+
using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table
125+
using table_entry_bcast_t = bool;
126+
127+
struct table_entry_t {
128+
table_entry_val_t val;
129+
table_entry_bcast_t bcast;
130+
};
131+
struct mapped_table_entry_t {
132+
table_entry_offset_t off;
133+
table_entry_val_t val;
134+
table_entry_bcast_t bcast;
135+
};
136+
using table_t = std::multimap<key_t, table_entry_t>;
137+
using mapped_table_t = std::multimap<key_t, mapped_table_entry_t>;
138+
mapped_table_t entry_map;
139+
};
140+
} // namespace jd
141+
#endif
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright (c) 2022 Intel Corporation
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISEOP_HPP_
16+
#define ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISEOP_HPP_
17+
18+
#include "../jit_generator.hpp"
19+
#include "utils.hpp"
20+
#include "kernels/eltwiseop_types.hpp"
21+
#include "jit_eltwise_injector.hpp"
22+
#include <map>
23+
24+
#define ELT_GET_OFF(field) offsetof(ssd::eltwiseop_data_t, field)
25+
26+
namespace jd {
27+
class jit_eltwiseop_t : public jit_generator {
28+
using Zmm = Xbyak::Zmm;
29+
using Ymm = Xbyak::Ymm;
30+
using Xmm = Xbyak::Xmm;
31+
32+
public:
33+
explicit jit_eltwiseop_t(const ssd::eltwiseop_param_t& param) : jit_generator(), param_(param) {
34+
eltwise_injector.eltwise_injector_init(this, param_.postop_attrs);
35+
assign_regs();
36+
}
37+
virtual ~jit_eltwiseop_t() {}
38+
39+
private:
40+
void generate() override;
41+
void assign_regs();
42+
void store_dst(Xbyak::Zmm reg_src, Xbyak::Reg64 dst_addr);
43+
void store_tail(Xbyak::Zmm reg_src, Xbyak::Reg64 dst_addr);
44+
void load_src(Xbyak::Zmm reg_src, Xbyak::Reg64 src_addr);
45+
void load_tail(Xbyak::Zmm reg_src, Xbyak::Reg64 src_addr);
46+
void prepare_mask();
47+
48+
private:
49+
ssd::eltwiseop_param_t param_;
50+
jit_eltwise_injector eltwise_injector;
51+
52+
/*labels*/
53+
Xbyak::Label vectorized_loop_start;
54+
Xbyak::Label vectorized_loop_end;
55+
Xbyak::Label reminder_loop_start;
56+
Xbyak::Label reminder_loop_end;
57+
58+
/* registers for fwd*/
59+
Xbyak::Reg64 reg_param = rdi;
60+
Zmm reg_src;
61+
Xbyak::Reg64 addr_src = r15;
62+
Xbyak::Reg64 addr_dst = r14;
63+
Xbyak::Reg64 remain_element_num = rsi;
64+
65+
/* registers for bf16 tasks*/
66+
Xbyak::Opmask remain_task_mask;
67+
Xbyak::Reg64 scratch_;
68+
69+
size_t dtype_size(postop_attr attr) {
70+
// quantize only happend on first postop,we load the data from memory to zmm,in tail case,the offset is one byte;
71+
// dequantize only happend on last postop,we store the data from zmm to memory,in taill case,the offset is also one
72+
// byte.
73+
if (attr.op_alg == postop_alg::quantize || attr.op_alg == postop_alg::dequantize) return 1u;
74+
switch (attr.dt) {
75+
case data_type::fp32:
76+
return 4u;
77+
case data_type::bf16:
78+
return 2u;
79+
}
80+
};
81+
82+
size_t load_offset() {
83+
auto head_dt = param_.postop_attrs.front().dt;
84+
switch (head_dt) {
85+
case data_type::fp32:
86+
case data_type::bf16:
87+
return 64u;
88+
case data_type::u8: // dequantize case
89+
return 16u;
90+
}
91+
}
92+
93+
size_t store_offset() {
94+
if (param_.postop_attrs.back().op_alg == postop_alg::quantize) return 16u; // quantize case.
95+
if (param_.postop_attrs.back().op_alg == postop_alg::dequantize) return 64u; // dequantize case.
96+
auto tail_dt = param_.postop_attrs.back().dt;
97+
switch (tail_dt) {
98+
case data_type::fp32:
99+
case data_type::bf16:
100+
return 64u;
101+
}
102+
}
103+
104+
size_t process_element_num() {
105+
auto front_attr = param_.postop_attrs.front();
106+
switch (front_attr.dt) {
107+
case data_type::fp32:
108+
return 16;
109+
case data_type::bf16:
110+
return 32;
111+
case data_type::u8: // dequantize case
112+
return 16;
113+
}
114+
}
115+
116+
void load_params() {
117+
mov(addr_dst, ptr[reg_param + ELT_GET_OFF(dst)]);
118+
mov(addr_src, ptr[reg_param + ELT_GET_OFF(src)]);
119+
mov(remain_element_num, ptr[reg_param + ELT_GET_OFF(element_num)]);
120+
}
121+
};
122+
} // namespace jd
123+
#endif

nlp_toolkit/backends/neural_engine/SparseLib/include/jit_domain/jit_postop_default.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class jit_postop_default_t : public jit_generator {
4444
void gelu_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
4545
void load_bf16_cvt_to_f32(Xbyak::Zmm reg_src, Xbyak::Reg64 src_addr, bool is_tail = false, size_t offset = 0);
4646
void cvt_f32_to_bf16_store(Xbyak::Zmm reg_src, Xbyak::Reg64 addr_dst, bool is_tail = false, size_t offset = 0);
47-
void init_vcvtneps2bf16();
4847

4948
bool is_bf16() {
5049
if (param_.dt == ssd::data_type::bf16) return true;
@@ -71,10 +70,6 @@ class jit_postop_default_t : public jit_generator {
7170
return 0;
7271
};
7372

74-
public:
75-
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
76-
const int EVEX_max_8b_offt = 0x200;
77-
7873
private:
7974
ssd::postop_param_t param_;
8075

@@ -95,7 +90,6 @@ class jit_postop_default_t : public jit_generator {
9590

9691
/* register for bf16 tasks*/
9792
Xbyak::Opmask remain_task_mask;
98-
Zmm one_, even_, selector_, tr0_;
9993
Xbyak::Reg64 scratch_;
10094

10195
Zmm zmm_mask, zmm_aux0, zmm_aux1, zmm_aux2, zmm_aux3, zmm_aux4, zmm_tmp;

nlp_toolkit/backends/neural_engine/SparseLib/include/kernel_hashing.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ class hash_t {
6565
uint64_t get_attr_hash(const std::unordered_map<std::string, std::string>& attrs, const kernel_kind& ker_kind) const {
6666
auto op_attrs = attrs;
6767
uint64_t seed = 0;
68-
hash_combine(seed, op_attrs["post_op"]);
68+
// If a front op wants to apply postop-fusion, it should add a field named postop_list in op_attr
69+
// for distinguishing.
70+
hash_combine(seed, op_attrs["postop_list"]);
6971
switch (ker_kind) {
7072
case kernel_kind::undef:
7173
break;
@@ -75,8 +77,13 @@ class hash_t {
7577
hash_combine(seed, op_attrs["tile_shape"]);
7678
hash_combine(seed, op_attrs["sparse_scheme"]);
7779
break;
80+
// todo:remove it.
7881
case kernel_kind::postop:
79-
hash_combine(seed, op_attrs["exp"]);
82+
break;
83+
case kernel_kind::eltwiseop:
84+
hash_combine(seed, op_attrs["reg64"]);
85+
hash_combine(seed, op_attrs["zmm"]);
86+
hash_combine(seed, op_attrs["mask"]);
8087
break;
8188
default:
8289
break;

0 commit comments

Comments
 (0)