@@ -32,16 +32,15 @@ namespace jd {
3232 */
3333class jit_spmm_vnni_t : public jit_generator {
3434 public:
35- explicit jit_spmm_vnni_t (const ssd::flat_param_t & param)
36- : jit_generator(), param_(param), csrp_(param_.sparse_ptr) {}
35+ explicit jit_spmm_vnni_t (const ssd::flat_param_t & param);
3736 virtual ~jit_spmm_vnni_t () {}
3837
3938 public:
4039 const void * sequence_vals () const { return seq_vals_.data (); }
4140
4241 private:
4342 ssd::flat_param_t param_;
44- csrp_data_t <int8_t >* csrp_ ;
43+ bsr_data_t <int8_t >* bsr_ ;
4544 std::vector<int8_t > seq_vals_;
4645
4746 private:
@@ -54,30 +53,26 @@ class jit_spmm_vnni_t : public jit_generator {
5453 Xbyak::Zmm dst_tile_Vmm (int i, int j); // Reg alloc of DST tile. 2D shape=(TH,TW), stride=(TW,1)
5554 void params_alias (const ssd::flat_param_t & param);
5655 void read_params ();
57- void load_bias (const std::vector< int64_t >& m_indices );
56+ void load_bias (int64_t m_start );
5857 void load_dense (const std::vector<int64_t >& k_indices);
59- void load_sparse ();
58+ void load_sparse (const int8_t * bsr_data, int64_t kp_lo, int64_t kp_hi );
6059 void tile_product (int tile_height, int tile_width);
61- void handle_dst_buffer_init (int kb_idx, const std::vector< int64_t >& m_indices );
62- void handle_dst_buffer_epilogue (int kb_idx, const std::vector< int64_t >& m_indices );
60+ void handle_dst_buffer_init (int kb_idx, int64_t m_start );
61+ void handle_dst_buffer_epilogue (int kb_idx, int64_t m_start );
6362 void mul_scale (int i);
6463 void move_out (int i, int j, int row_idx, int bytes = 1 );
65- std::unordered_map<int64_t , std::vector<int64_t >> get_idx_balanced (const std::vector< int64_t >& m_indices ,
64+ std::unordered_map<int64_t , std::vector<int64_t >> get_idx_balanced (int64_t m_start ,
6665 const std::vector<int64_t >& sparse_indptr,
6766 const std::vector<int64_t >& sparse_indices, int lo,
6867 int hi);
69- std::unordered_map<int64_t , std::vector<int8_t >> get_val_balanced (const std::vector< int64_t >& m_indices ,
68+ std::unordered_map<int64_t , std::vector<int8_t >> get_val_balanced (int64_t m_start ,
7069 const std::vector<int64_t >& sparse_indptr,
7170 const std::vector<int64_t >& sparse_indices, int lo,
7271 int hi, const std::vector<int8_t >& sparse_inddata);
73- void repeat_THx4xTW_matmal (const std::vector<int64_t >& m_indices,
74- const std::unordered_map<int64_t , std::vector<int64_t >>& k_indices_map,
75- const std::unordered_map<int64_t , std::vector<int8_t >>& k_inddata_map);
72+ void repeat_THx4xTW_matmal (int64_t imb);
7673 void clear_dst_tile ();
77- void load_intermediate_dst (const std::vector<int64_t >& m_indices);
78- void store_intermediate_dst (const std::vector<int64_t >& m_indices);
79- void save_sequence_vals (const std::vector<int64_t >& m_indices,
80- const std::unordered_map<int64_t , std::vector<int8_t >>& k_inddata_map, int pos1, int pos2);
74+ void load_intermediate_dst (int64_t m_start);
75+ void store_intermediate_dst (int64_t m_start);
8176 void gen_sub_function ();
8277
8378 private:
0 commit comments