Skip to content

Commit b4ce5de

Browse files
authored
Sparse VNNI tile fold (#223)
1 parent a6b975a commit b4ce5de

File tree

5 files changed

+97
-57
lines changed

5 files changed

+97
-57
lines changed

nlp_toolkit/backends/neural_engine/SparseLib/include/jit_domain/jit_spmm_vnni.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class jit_spmm_vnni_t : public jit_generator {
7575
void store_intermediate_dst(dim_t m_start);
7676
void gen_subfunc_tile_prod();
7777
void gen_subfunc_dense_and_prod();
78+
void load_dense_sparse_prod();
7879
void gen_subfunc_load_and_prod();
7980
void gen_subfunc_dst_epilogue();
8081
void handle_postop_escape_vmms();
@@ -118,6 +119,7 @@ class jit_spmm_vnni_t : public jit_generator {
118119
const Xbyak::Reg64& reg_scale = rbx; // the scale
119120
const Xbyak::Opmask& reg_k1 = k1;
120121

122+
const Xbyak::Reg64& reg_k_ptr = param1;
121123
const Xbyak::Reg64& reg_tmp = r9;
122124
const Xbyak::Reg64& reg_dst_idx = r8;
123125
const Xbyak::Reg64& reg_m_idx = reg_tmp;

nlp_toolkit/backends/neural_engine/SparseLib/include/kernels/spmm_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ enum class subfunc_level : uint8_t {
5252
prod, // use sub-function for tile product
5353
dense_and_prod, // use fused sub-function for dense loading & tile product
5454
load_and_prod, // use fused sub-function for dense loading & sparse loading & tile product
55-
subfunc_level_MAX = load_and_prod
55+
k_dims, // a whole THxKxTW tile generates a constant size of code
56+
subfunc_level_MAX = k_dims
5657
};
5758

5859
/**

nlp_toolkit/backends/neural_engine/SparseLib/src/jit_domain/jit_spmm_vnni.cpp

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -155,44 +155,66 @@ void jit_spmm_vnni_t::repeat_THx4xTW_matmal(dim_t m_start) {
155155
mov(reg_seq_indices, reinterpret_cast<uint64_t>(dense_load_offsets.data() + indptr_lo - indptr_kernel_start));
156156
break;
157157
case ssd::subfunc_level::load_and_prod:
158+
case ssd::subfunc_level::k_dims:
158159
mov(reg_seq_indices, reinterpret_cast<uint64_t>(dense_load_offsets.data() + indptr_lo - indptr_kernel_start));
159160
mov(reg_wei, reinterpret_cast<uint64_t>(param_.weight + param_.blocksize[0] * param_.blocksize[1] * indptr_lo));
160161
break;
161162
default:
162163
break;
163164
}
164-
165-
// kp (k-idx pointer is the idx of nnz blocks of the current row)
166-
for (int64_t kp_lo = 0; kp_lo < nnz; kp_lo += spns::ADJ) {
167-
const int64_t kp_hi = std::min(kp_lo + spns::ADJ, nnz); // end of k-index pointer (noninclusive)
168-
dim_t element_offset = param_.blocksize[0] * param_.blocksize[1] * indptr_lo + kp_lo * TH();
169-
170-
// Step 1: load dense (activation). Note that k_indices length is processed.
171-
// Step 2: load sparse (weight) and reorder data for that.
172-
// Step 3: tile product. Note that k_indices length is processed.
173-
// A tile product can calculate at least 1 row and 16 columns of DST.
174-
// Min tile calculation: Tile width/height is 1, compute (1, ADJ) x (ADJ, 16) = (1, 16) matmul.
175-
switch (param_.sub_func) {
176-
case ssd::subfunc_level::none:
177-
load_dense({k_indices.begin() + kp_lo, k_indices.begin() + kp_hi});
178-
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
179-
tile_product(TH(), TW());
180-
break;
181-
case ssd::subfunc_level::prod:
182-
load_dense({k_indices.begin() + kp_lo, k_indices.begin() + kp_hi});
183-
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
184-
call(sfptr_tile_prod_);
185-
break;
186-
case ssd::subfunc_level::dense_and_prod:
187-
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
188-
call(sfptr_dense_and_prod_);
189-
break;
190-
case ssd::subfunc_level::load_and_prod:
191-
call(sfptr_load_and_prod_);
192-
break;
193-
default:
165+
switch (param_.sub_func) {
166+
case ssd::subfunc_level::none:
167+
case ssd::subfunc_level::prod:
168+
case ssd::subfunc_level::dense_and_prod:
169+
case ssd::subfunc_level::load_and_prod:
170+
// kp (k-idx pointer is the idx of nnz blocks of the current row)
171+
for (int64_t kp_lo = 0; kp_lo < nnz; kp_lo += spns::ADJ) {
172+
const int64_t kp_hi = std::min(kp_lo + spns::ADJ, nnz); // end of k-index pointer (noninclusive)
173+
dim_t element_offset = param_.blocksize[0] * param_.blocksize[1] * indptr_lo + kp_lo * TH();
174+
175+
// Step 1: load dense (activation). Note that k_indices length is processed.
176+
// Step 2: load sparse (weight) and reorder data for that.
177+
// Step 3: tile product. Note that k_indices length is processed.
178+
// A tile product can calculate at least 1 row and 16 columns of DST.
179+
// Min tile calculation: Tile width/height is 1, compute (1, ADJ) x (ADJ, 16) = (1, 16) matmul.
180+
switch (param_.sub_func) {
181+
case ssd::subfunc_level::none:
182+
load_dense({k_indices.begin() + kp_lo, k_indices.begin() + kp_hi});
183+
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
184+
tile_product(TH(), TW());
185+
break;
186+
case ssd::subfunc_level::prod:
187+
load_dense({k_indices.begin() + kp_lo, k_indices.begin() + kp_hi});
188+
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
189+
call(sfptr_tile_prod_);
190+
break;
191+
case ssd::subfunc_level::dense_and_prod:
192+
load_sparse(reg_wei, element_offset * sizeof(decltype(*param_.weight)));
193+
call(sfptr_dense_and_prod_);
194+
break;
195+
case ssd::subfunc_level::load_and_prod:
196+
call(sfptr_load_and_prod_);
197+
break;
198+
default:
199+
break;
200+
}
201+
}
202+
break;
203+
case ssd::subfunc_level::k_dims:
204+
if (nnz > 0) { // at least one iteration
205+
xor_(reg_k_ptr, reg_k_ptr);
206+
add(reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
207+
208+
Xbyak::Label L_adj_k_loop;
209+
L(L_adj_k_loop);
210+
load_dense_sparse_prod();
211+
add(reg_k_ptr, spns::ADJ);
212+
cmp(reg_k_ptr, static_cast<int>(nnz));
213+
jl(L_adj_k_loop); // Loop-N2 end.
214+
215+
sub(reg_dense, reg_n_idx); // reg_dense -= reg_n_idx * BYTE1 (undo the offset added before the loop)
194216
break;
195-
}
217+
}
196218
}
197219
}
198220

@@ -316,10 +338,13 @@ void jit_spmm_vnni_t::gen_subfunc_dense_and_prod() {
316338
ret();
317339
}
318340

319-
void jit_spmm_vnni_t::gen_subfunc_load_and_prod() {
320-
sfptr_load_and_prod_ = getCurr();
321-
add(reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
322-
341+
/**
342+
* Required registers:
343+
* reg_dense - the start of the current row of dense matrix
344+
* reg_seq_indices - the start of offset for each TW, it will be updated after read
345+
* reg_wei - the start of weight matrix, it will be updated after read
346+
*/
347+
void jit_spmm_vnni_t::load_dense_sparse_prod() {
323348
constexpr size_t idx_size = sizeof(decltype(param_.indices)::value_type);
324349
mov(reg_addr_tmp[0], qword[reg_seq_indices + 0 * idx_size]);
325350
mov(reg_addr_tmp[1], qword[reg_seq_indices + 1 * idx_size]);
@@ -353,8 +378,16 @@ void jit_spmm_vnni_t::gen_subfunc_load_and_prod() {
353378
// tile prod
354379
vpdpbusd(dst_tile_Vmm(i, j), TW_Vmm(j), TH_Vmm(i));
355380
}
381+
// update reg_wei in the middle
356382
if (j == TW() / 2) add(reg_wei, TH() * spns::ADJ * wei_size);
357383
}
384+
}
385+
386+
void jit_spmm_vnni_t::gen_subfunc_load_and_prod() {
387+
sfptr_load_and_prod_ = getCurr();
388+
add(reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
389+
390+
load_dense_sparse_prod();
358391

359392
sub(reg_dense, reg_n_idx); // reg_dense -= reg_n_idx * BYTE1 (undo the offset added on entry)
360393
ret();
@@ -368,6 +401,7 @@ void jit_spmm_vnni_t::generate() {
368401
gen_subfunc_dst_epilogue();
369402
switch (param_.sub_func) {
370403
case ssd::subfunc_level::none:
404+
case ssd::subfunc_level::k_dims:
371405
break;
372406
case ssd::subfunc_level::prod:
373407
gen_subfunc_tile_prod();
Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
11
### ncores_per_inst operator output_channel input_channel bsxseq sparsity micro_bs is_fp32_out has_append_sum micro_oc sub_func_level postop
22

33
# bert mini sp90 bs=1 c4
4-
4 vnni 256 256 128 0.9 -1 0 0 -1 0
5-
4 vnni 256 1024 128 0.9 -1 0 0 -1 0
6-
4 vnni 1024 256 128 0.9 -1 0 0 -1 0
4+
4 vnni 256 256 128 0.9 -1 0 0 -1 -1
5+
4 vnni 256 1024 128 0.9 -1 0 0 -1 -1
6+
4 vnni 1024 256 128 0.9 -1 0 0 -1 -1
77

88
# bert large sp90 bs=1 c4
9-
4 vnni 1024 1024 128 0.9 -1 0 0 -1 0
10-
4 vnni 1024 4096 128 0.9 -1 0 0 -1 0
11-
4 vnni 4096 1024 128 0.9 -1 0 0 -1 0
9+
4 vnni 1024 1024 128 0.9 -1 0 0 -1 -1
10+
4 vnni 1024 4096 128 0.9 -1 0 0 -1 -1
11+
4 vnni 4096 1024 128 0.9 -1 0 0 -1 -1
1212

1313
# distilbert sp80 bs=8 c4/7
14-
4 vnni 768 768 256 0.8 -1 0 0 -1 0
15-
4 vnni 768 3072 256 0.8 -1 0 0 -1 0
16-
4 vnni 3072 768 256 0.8 -1 0 0 -1 0
17-
7 vnni 768 768 256 0.8 -1 0 0 -1 0
18-
7 vnni 768 3072 256 0.8 -1 0 0 -1 0
19-
7 vnni 3072 768 256 0.8 -1 0 0 -1 0
20-
7 vnni 768 768 256 0.8 -1 0 0 -1 0 gelu
21-
7 vnni 768 3072 256 0.8 -1 0 0 -1 0 gelu
22-
7 vnni 3072 768 256 0.8 -1 0 0 -1 0 gelu
14+
4 vnni 768 768 256 0.8 -1 0 0 -1 -1
15+
4 vnni 768 3072 256 0.8 -1 0 0 -1 -1
16+
4 vnni 3072 768 256 0.8 -1 0 0 -1 -1
17+
7 vnni 768 768 256 0.8 -1 0 0 -1 -1
18+
7 vnni 768 3072 256 0.8 -1 0 0 -1 -1
19+
7 vnni 3072 768 256 0.8 -1 0 0 -1 -1
20+
7 vnni 768 768 256 0.8 -1 0 0 -1 -1 gelu
21+
7 vnni 768 3072 256 0.8 -1 0 0 -1 -1 gelu
22+
7 vnni 3072 768 256 0.8 -1 0 0 -1 -1 gelu
2323

2424
# distilbert sp80 bs=8 c7 micro_bs/micro_oc
25-
4 vnni 768 768 256 0.8 128 0 0 -1 0
26-
4 vnni 3072 768 256 0.8 -1 0 0 128 0
25+
4 vnni 768 768 256 0.8 128 0 0 -1 -1
26+
4 vnni 3072 768 256 0.8 -1 0 0 128 -1
2727

2828
# cases for different op_attrs
29+
8 vnni 3072 768 256 0.8 -1 0 0 -1 -1
30+
8 vnni 3072 768 256 0.8 128 0 0 -1 -1
31+
8 vnni 3072 768 256 0.8 -1 1 0 -1 -1
32+
8 vnni 3072 768 256 0.8 -1 1 1 -1 -1
33+
8 vnni 3072 768 256 0.8 -1 0 0 128 -1
2934
8 vnni 3072 768 256 0.8 -1 0 0 -1 0
30-
8 vnni 3072 768 256 0.8 128 0 0 -1 0
31-
8 vnni 3072 768 256 0.8 -1 1 0 -1 0
32-
8 vnni 3072 768 256 0.8 -1 1 1 -1 0
33-
8 vnni 3072 768 256 0.8 -1 0 0 128 0
3435
8 vnni 3072 768 256 0.8 -1 0 0 -1 1
3536
8 vnni 3072 768 256 0.8 -1 0 0 -1 2
3637
8 vnni 3072 768 256 0.8 -1 0 0 -1 3
38+
8 vnni 3072 768 256 0.8 -1 0 0 -1 4

nlp_toolkit/backends/neural_engine/test/gtest/SparseLib/test_spmm_vnni_kernel.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ static auto case_func = []() {
397397
cases.push_back({gen_case(32, 32, 128, .7f, -1, nthr, dt::fp32, {{"sub_func", "1"}})});
398398
cases.push_back({gen_case(32, 32, 128, .7f, -1, nthr, dt::fp32, {{"sub_func", "2"}})});
399399
cases.push_back({gen_case(32, 32, 128, .7f, -1, nthr, dt::fp32, {{"sub_func", "3"}})});
400+
cases.push_back({gen_case(32, 32, 128, .7f, -1, nthr, dt::fp32, {{"sub_func", "4"}})});
400401

401402
// case: sparse: s8xu8+s32=s8, weight(M, K) * activation(K, N) + bias(M, 1) = dst(M, N)
402403
cases.push_back({gen_case(32, 32, 128, .7f, -1, nthr, dt::s8)});

0 commit comments

Comments
 (0)