Skip to content

Commit 834b3d3

Browse files
[SparseLib] softmax int8 using lut (#215)
* benchmark done. * update ci.
1 parent 157670d commit 834b3d3

38 files changed

+1201
-1418
lines changed

nlp_toolkit/backends/neural_engine/SparseLib/include/interface.hpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -108,13 +108,6 @@ class transpose_matmul_desc : public kernel_desc_proxy {
108108
virtual ~transpose_matmul_desc() {}
109109
};
110110

111-
class postop_desc : public kernel_desc_proxy {
112-
public:
113-
postop_desc() {}
114-
explicit postop_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
115-
virtual ~postop_desc() {}
116-
};
117-
118111
class eltwiseop_desc : public kernel_desc_proxy {
119112
public:
120113
eltwiseop_desc() {}
@@ -129,6 +122,13 @@ class layernorm_ba_desc : public kernel_desc_proxy {
129122
virtual ~layernorm_ba_desc() {}
130123
};
131124

125+
class softmax_desc : public kernel_desc_proxy {
126+
public:
127+
softmax_desc() {}
128+
explicit softmax_desc(const operator_desc& op_desc) : kernel_desc_proxy(op_desc) {}
129+
virtual ~softmax_desc() {}
130+
};
131+
132132
/**
133133
* @brief Derived proxy class, interfacing to the real/cached sparse_matmul_t.
134134
*/
@@ -146,13 +146,6 @@ class transpose_matmul : public kernel_proxy {
146146
virtual ~transpose_matmul() {}
147147
};
148148

149-
class postop : public kernel_proxy {
150-
public:
151-
postop() {}
152-
explicit postop(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
153-
virtual ~postop() {}
154-
};
155-
156149
class eltwiseop : public kernel_proxy {
157150
public:
158151
eltwiseop() {}
@@ -167,5 +160,12 @@ class layernorm_ba : public kernel_proxy {
167160
virtual ~layernorm_ba() {}
168161
};
169162

163+
class softmax : public kernel_proxy {
164+
public:
165+
softmax() {}
166+
explicit softmax(const kernel_desc_proxy& kdp) : kernel_proxy(kdp) {}
167+
virtual ~softmax() {}
168+
};
169+
170170
} // namespace jd
171171
#endif // ENGINE_SPARSELIB_INCLUDE_INTERFACE_HPP_

nlp_toolkit/backends/neural_engine/SparseLib/include/jit_domain/jit_eltwise_injector.hpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
#ifndef ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISE_INJECTOR_HPP_
1616
#define ENGINE_SPARSELIB_INCLUDE_JIT_DOMAIN_JIT_ELTWISE_INJECTOR_HPP_
1717

18+
#include <glog/logging.h>
1819
#include <string>
1920
#include <vector>
2021
#include <unordered_map>
@@ -54,10 +55,12 @@ class jit_eltwise_injector {
5455
void quantize_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
5556
void dequantize_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
5657
void linear_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
57-
void int8_lut_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
58+
void bit8_lut_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
59+
void bit16_lut_compute_vector_fwd(const Xbyak::Zmm& zmm_src);
5860
void register_table_entries(const std::vector<postop_attr>& postop_attrs);
5961
void assert_check(const std::vector<postop_attr>& postop_attrs);
60-
uint32_t get_int8_lut_term(int int8, const std::vector<postop_attr>& postop_attrs, data_type input_dt);
62+
uint32_t get_bit8_lut_term(int integer, const std::vector<postop_attr>& postop_attrs, data_type output_dt);
63+
uint32_t get_bit16_lut_term(int integer, const std::vector<postop_attr>& postop_attrs, data_type output_dt);
6164
std::string get_attr_idx_key(const postop_attr& attr); // for get the key of alpha_idx,beta_idx,scale_idx map.
6265

6366
void load_table_addr() { h->mov(p_table, l_table); }
@@ -133,10 +136,12 @@ class jit_eltwise_injector {
133136
tanh_saturation_lbound,
134137
tanh_pol_table,
135138
exchange_zmm_low256_high256,
136-
int8_lut_term,
137-
int8_64,
138-
int8_255,
139-
select_permt_idx,
139+
bit8_lut_term,
140+
bit8_64,
141+
bit8_255,
142+
bit16_lut_term,
143+
bit16_32,
144+
bit16_255,
140145
undef_key,
141146
};
142147

nlp_toolkit/backends/neural_engine/SparseLib/include/jit_domain/jit_eltwiseop.hpp

Lines changed: 14 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -66,23 +66,13 @@ class jit_eltwiseop_t : public jit_generator {
6666
Xbyak::Opmask remain_task_mask;
6767
Xbyak::Reg64 scratch_;
6868

69-
size_t dtype_size(postop_attr attr) {
70-
// quantize only happend on first postop,we load the data from memory to zmm,in tail case,the offset is one byte;
71-
// dequantize only happend on last postop,we store the data from zmm to memory,in taill case,the offset is also one
72-
// byte.
73-
if (attr.op_alg == postop_alg::int8_lut || attr.op_alg == postop_alg::quantize ||
74-
attr.op_alg == postop_alg::dequantize)
75-
return 1u;
76-
switch (attr.dt) {
77-
case data_type::fp32:
78-
return 4u;
79-
case data_type::bf16:
80-
return 2u;
81-
}
82-
}
83-
8469
size_t load_offset() {
85-
if (param_.postop_attrs[0].op_alg == postop_alg::int8_lut) return 64u; // special case:int8_lut
70+
if (param_.postop_attrs[0].op_alg == postop_alg::eltop_int_lut && param_.postop_attrs[0].alpha == 8) {
71+
return 64u; // special case:bit8_lut
72+
}
73+
if (param_.postop_attrs[0].op_alg == postop_alg::eltop_int_lut && param_.postop_attrs[0].alpha == 16) {
74+
return 32u; // special case:bit16_lut
75+
}
8676
auto head_dt = param_.postop_attrs.front().dt;
8777
switch (head_dt) {
8878
case data_type::fp32:
@@ -94,12 +84,11 @@ class jit_eltwiseop_t : public jit_generator {
9484
}
9585

9686
size_t store_offset() {
97-
// todo:the logic is little confuse,we need to optimize.
98-
if (param_.postop_attrs.front().op_alg == postop_alg::int8_lut) return 64u; // int8_lut case;
99-
if (param_.postop_attrs.back().op_alg == postop_alg::quantize) return 16u; // quantize case.
100-
if (param_.postop_attrs.back().op_alg == postop_alg::dequantize) return 64u; // dequantize case.
101-
auto tail_dt = param_.postop_attrs.back().dt;
102-
switch (tail_dt) {
87+
// todo:except dequantize case, our zmm always full of result data and needs to be stored.
88+
if (param_.postop_attrs.front().op_alg == postop_alg::eltop_int_lut) return 64u; // lut case;
89+
if (param_.postop_attrs.back().op_alg == postop_alg::quantize) return 16u; // quantize case.
90+
if (param_.postop_attrs.back().op_alg == postop_alg::dequantize) return 64u; // dequantize case.
91+
switch (param_.out_dt) {
10392
case data_type::fp32:
10493
case data_type::bf16:
10594
return 64u;
@@ -108,8 +97,9 @@ class jit_eltwiseop_t : public jit_generator {
10897

10998
size_t process_element_num() {
11099
auto front_attr = param_.postop_attrs.front();
111-
if (front_attr.op_alg == postop_alg::int8_lut) return 64; // special case:int8_lut
112-
switch (front_attr.dt) {
100+
if (front_attr.op_alg == postop_alg::eltop_int_lut && front_attr.alpha == 8) return 64; // special case:bit8_lut
101+
if (front_attr.op_alg == postop_alg::eltop_int_lut && front_attr.alpha == 16) return 32; // sepcial case:bit16_lut
102+
switch (param_.in_dt) {
113103
case data_type::fp32:
114104
return 16;
115105
case data_type::bf16:

nlp_toolkit/backends/neural_engine/SparseLib/include/jit_domain/jit_postop_default.hpp

Lines changed: 0 additions & 178 deletions
This file was deleted.

0 commit comments

Comments
 (0)