@@ -155,44 +155,66 @@ void jit_spmm_vnni_t::repeat_THx4xTW_matmal(dim_t m_start) {
155155 mov (reg_seq_indices, reinterpret_cast <uint64_t >(dense_load_offsets.data () + indptr_lo - indptr_kernel_start));
156156 break ;
157157 case ssd::subfunc_level::load_and_prod:
158+ case ssd::subfunc_level::k_dims:
158159 mov (reg_seq_indices, reinterpret_cast <uint64_t >(dense_load_offsets.data () + indptr_lo - indptr_kernel_start));
159160 mov (reg_wei, reinterpret_cast <uint64_t >(param_.weight + param_.blocksize [0 ] * param_.blocksize [1 ] * indptr_lo));
160161 break ;
161162 default :
162163 break ;
163164 }
164-
165- // kp (k-idx pointer is the idx of nnz blocks of the current row)
166- for (int64_t kp_lo = 0 ; kp_lo < nnz; kp_lo += spns::ADJ) {
167- const int64_t kp_hi = std::min (kp_lo + spns::ADJ, nnz); // end of k-index pointer (noninclusive)
168- dim_t element_offset = param_.blocksize [0 ] * param_.blocksize [1 ] * indptr_lo + kp_lo * TH ();
169-
170- // Step 1: load dense (activation). Note that k_indices length is processed.00-00
171- // Step 2: load sparse (weight) and reorder data for that.
172- // Step 3: tile product. Note that k_indices length is processed.
173- // A tile product can calculate at least 1 row and 16 columns of DST.
174- // Min tile calculation: Tile width/height is 1, compute (1, ADJ) x (ADJ, 16) = (1, 16) matmul.
175- switch (param_.sub_func ) {
176- case ssd::subfunc_level::none:
177- load_dense ({k_indices.begin () + kp_lo, k_indices.begin () + kp_hi});
178- load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
179- tile_product (TH (), TW ());
180- break ;
181- case ssd::subfunc_level::prod:
182- load_dense ({k_indices.begin () + kp_lo, k_indices.begin () + kp_hi});
183- load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
184- call (sfptr_tile_prod_);
185- break ;
186- case ssd::subfunc_level::dense_and_prod:
187- load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
188- call (sfptr_dense_and_prod_);
189- break ;
190- case ssd::subfunc_level::load_and_prod:
191- call (sfptr_load_and_prod_);
192- break ;
193- default :
165+ switch (param_.sub_func ) {
166+ case ssd::subfunc_level::none:
167+ case ssd::subfunc_level::prod:
168+ case ssd::subfunc_level::dense_and_prod:
169+ case ssd::subfunc_level::load_and_prod:
170+ // kp (k-idx pointer is the idx of nnz blocks of the current row)
171+ for (int64_t kp_lo = 0 ; kp_lo < nnz; kp_lo += spns::ADJ) {
172+ const int64_t kp_hi = std::min (kp_lo + spns::ADJ, nnz); // end of k-index pointer (noninclusive)
173+ dim_t element_offset = param_.blocksize [0 ] * param_.blocksize [1 ] * indptr_lo + kp_lo * TH ();
174+
175+ // Step 1: load dense (activation). Note that k_indices length is processed.00-00
176+ // Step 2: load sparse (weight) and reorder data for that.
177+ // Step 3: tile product. Note that k_indices length is processed.
178+ // A tile product can calculate at least 1 row and 16 columns of DST.
179+ // Min tile calculation: Tile width/height is 1, compute (1, ADJ) x (ADJ, 16) = (1, 16) matmul.
180+ switch (param_.sub_func ) {
181+ case ssd::subfunc_level::none:
182+ load_dense ({k_indices.begin () + kp_lo, k_indices.begin () + kp_hi});
183+ load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
184+ tile_product (TH (), TW ());
185+ break ;
186+ case ssd::subfunc_level::prod:
187+ load_dense ({k_indices.begin () + kp_lo, k_indices.begin () + kp_hi});
188+ load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
189+ call (sfptr_tile_prod_);
190+ break ;
191+ case ssd::subfunc_level::dense_and_prod:
192+ load_sparse (reg_wei, element_offset * sizeof (decltype (*param_.weight )));
193+ call (sfptr_dense_and_prod_);
194+ break ;
195+ case ssd::subfunc_level::load_and_prod:
196+ call (sfptr_load_and_prod_);
197+ break ;
198+ default :
199+ break ;
200+ }
201+ }
202+ break ;
203+ case ssd::subfunc_level::k_dims:
204+ if (nnz > 0 ) { // at least one iteration
205+ xor_ (reg_k_ptr, reg_k_ptr);
206+ add (reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
207+
208+ Xbyak::Label L_adj_k_loop;
209+ L (L_adj_k_loop);
210+ load_dense_sparse_prod ();
211+ add (reg_k_ptr, spns::ADJ);
212+ cmp (reg_k_ptr, static_cast <int >(nnz));
213+ jl (L_adj_k_loop); // Loop-N2 end.
214+
215+ sub (reg_dense, reg_n_idx); // reg_dense = reg_n_idx * BYTE1
194216 break ;
195- }
217+ }
196218 }
197219}
198220
@@ -316,10 +338,13 @@ void jit_spmm_vnni_t::gen_subfunc_dense_and_prod() {
316338 ret ();
317339}
318340
319- void jit_spmm_vnni_t::gen_subfunc_load_and_prod () {
320- sfptr_load_and_prod_ = getCurr ();
321- add (reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
322-
341+ /* *
342+ * Required registers:
343+ * reg_dense - the start of the current row of dense matrix
344+ * reg_seq_indices - the start of offset for each TW, it will be updated after read
345+ * reg_wei - the start of weight matrix, it will be updated after read
346+ */
347+ void jit_spmm_vnni_t::load_dense_sparse_prod () {
323348 constexpr size_t idx_size = sizeof (decltype (param_.indices )::value_type);
324349 mov (reg_addr_tmp[0 ], qword[reg_seq_indices + 0 * idx_size]);
325350 mov (reg_addr_tmp[1 ], qword[reg_seq_indices + 1 * idx_size]);
@@ -353,8 +378,16 @@ void jit_spmm_vnni_t::gen_subfunc_load_and_prod() {
353378 // tile prod
354379 vpdpbusd (dst_tile_Vmm (i, j), TW_Vmm (j), TH_Vmm (i));
355380 }
381+ // update reg_wei in the middle
356382 if (j == TW () / 2 ) add (reg_wei, TH () * spns::ADJ * wei_size);
357383 }
384+ }
385+
386+ void jit_spmm_vnni_t::gen_subfunc_load_and_prod () {
387+ sfptr_load_and_prod_ = getCurr ();
388+ add (reg_dense, reg_n_idx); // reg_dense += reg_n_idx * BYTE1
389+
390+ load_dense_sparse_prod ();
358391
359392 sub (reg_dense, reg_n_idx); // reg_dense = reg_n_idx * BYTE1
360393 ret ();
@@ -368,6 +401,7 @@ void jit_spmm_vnni_t::generate() {
368401 gen_subfunc_dst_epilogue ();
369402 switch (param_.sub_func ) {
370403 case ssd::subfunc_level::none:
404+ case ssd::subfunc_level::k_dims:
371405 break ;
372406 case ssd::subfunc_level::prod:
373407 gen_subfunc_tile_prod ();
0 commit comments