@@ -76,12 +76,28 @@ class Dispatcher {
   void Prepare(const vector<Tensor*>& input, const vector<Tensor*>& output) {
     // TODO: handle the case where different kernels have different output data types
     // Prepare may change state on the kernel, but it should not change the output
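+    // mark every handler as non-SparseLib until its kernel type is checked below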
+    for (int i = 0; i < kernel_handler_.size(); ++i) sparselib_available_.push_back(false);
+    int idx = 0;
+    // let the default kernel prepare first
+    kernel_handler_[type_]->Prepare(input, output);
     for (const auto& k_pair : kernel_handler_) {
+      auto kernel_name = k_pair.first;
       auto kernel = k_pair.second;
       kernel->set_dispatch_from_type(type_);
-      kernel->Prepare(input, output);
+      if (kernel_name != type_) kernel->Prepare(input, output);
+      sparselib_available_[idx++] = kernel->kernel_type() == SparseLib;
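+      // for dense-in-sparse tuning, prepare the dense path of a SparseLib kernel too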
+      if (tune_dense_in_sparse_ && do_tuning_ && kernel->kernel_type() == SparseLib) {
+        kernel->set_kernel_type(Dense);
+        kernel->Prepare(input, output);
+        kernel->set_kernel_type(SparseLib);
+      }
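+      // no tuning space when there is a single (or monopolizing) kernel and the first handler is not SparseLib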
+      if ((kernel_handler_.size() < 2 || kernel->monopolize_dispatcher())
+          && !sparselib_available_[0]) no_tuning_space_ = true;
       if (kernel->monopolize_dispatcher()) {
-        disable_dispatch_ = true;
+        monopoly_kernel_ = kernel_name;
         break;
       }
     }
@@ -110,7 +126,7 @@ class Dispatcher {
       if (kernel_handler_.size() > 1) kernel_handler_[type_]->set_do_shape_infer(true);
       kernel_handler_[type_]->Reshape(input, output);
     }
-    if (!disable_dispatch_ && has_dispatch_table_file) {
+    if (!no_tuning_space_ && has_dispatch_table_file) {
       // generate the hash key and find the best kernel if a dispatch table exists
       // only load once
       if (DispatchTable::Size() == 0) {
@@ -120,9 +136,16 @@ class Dispatcher {
       vector<string> kernel_config = DispatchTable::Find(type_, GetHash(input));
       if (!kernel_config.empty()) {
         string kernel_name = kernel_config[0];
-        if (kernel_handler_.count(kernel_name) > 0) {
-          execute_kernel_ = kernel_name;
-          kernel_handler_[kernel_name]->set_dispatch_config(kernel_config);
+        // a "SparseLib" entry reuses the default handler with the stored dispatch config
+        if (kernel_name == "SparseLib") {
+          execute_kernel_ = type_;
+          kernel_handler_[type_]->set_dispatch_config(kernel_config);
+        } else {
+          // a dense entry dispatches to its own handler
+          if (kernel_handler_.count(kernel_name) > 0) {
+            execute_kernel_ = kernel_name;
+            kernel_handler_[kernel_name]->set_dispatch_config(kernel_config);
+          }
         }
       }
     }
@@ -136,41 +159,63 @@ class Dispatcher {
       size_t input_hash = GetHash(input);
       iter_cnt_ += 1;
       // consider warmup when tuning
-      if (!disable_dispatch_ && kernel_handler_.size() > 1 && (iter_cnt_ <= warmup_iter_ + 1 ||
-          DispatchTable::Find(type_, input_hash).empty())) {
+      if (!no_tuning_space_ && (iter_cnt_ <= warmup_iter_ + 1 || DispatchTable::Find(type_, input_hash).empty())) {
         // keep the kernel with the least time as the first pair
         std::map<float, vector<string>, std::less<float>> timer;
         OpTuning op_tuning(type_);
         // increase input tensors' life when tuning
         // the default kernel does not count towards the extra life
+        int idx = 0;
+        string suffix;
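+        // stage every candidate kernel; a SparseLib kernel is keyed by the "SparseLib" suffix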
         for (const auto& k_pair : kernel_handler_) {
           auto kernel_name = k_pair.first;
           auto kernel = k_pair.second;
-          op_tuning.Start(kernel_name, kernel, input, output, reshape_model);
+          suffix = sparselib_available_[idx++] ? "SparseLib" : kernel_name;
+          if (tune_dense_in_sparse_ && suffix == "SparseLib") {
+            kernel->set_kernel_type(Dense);
+            op_tuning.Start(kernel_name, kernel, input, output, reshape_model);
+            kernel->set_kernel_type(SparseLib);
+          }
+          op_tuning.Start(suffix, kernel, input, output, reshape_model);
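+          // a monopolizing kernel ends the candidate scan; later handlers are skipped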
+          if (monopoly_kernel_ == kernel_name) break;
         }
         for (auto& tensor : input) tensor->disposable_extra_life(op_tuning.extra_tensor_life());
         op_tuning.reset_extra_tensor_life();
         // tune kernel
+        idx = 0;
         for (const auto& k_pair : kernel_handler_) {
           auto kernel_name = k_pair.first;
           auto kernel = k_pair.second;
+          suffix = sparselib_available_[idx++] ? "SparseLib" : kernel_name;
           try {
-            op_tuning.Run(kernel_name, kernel, input, output, reshape_model);
+            if (tune_dense_in_sparse_ && suffix == "SparseLib") {
+              kernel->set_kernel_type(Dense);
+              op_tuning.Run(kernel_name, kernel, input, output, reshape_model);
+              kernel->set_kernel_type(SparseLib);
+            }
+            op_tuning.Run(suffix, kernel, input, output, reshape_model);
             timer[op_tuning.best_execute_time()] = op_tuning.kernel_config();
             // some kernels don't support specific dtypes, fusions, etc.
           } catch (const std::exception& e) {
             LOG(WARNING) << kernel_name << " kernel tuning failure: " << e.what();
           }
+          if (monopoly_kernel_ == kernel_name) break;
         }
         if (timer.size() > 0) {
           execute_kernel_ = timer.begin()->second[0];
           LOG(INFO) << "best kernel is " << execute_kernel_ << " with time " << timer.begin()->first << " ms";
           if (execute_kernel_ != type_) DispatchTable::Insert(type_, input_hash, timer.begin()->second);
         }
       } else {
-        LOG(INFO) << "Skip tuning function due to existing input hash...";
-        if (reshape_model) kernel_handler_[type_]->Reshape(input, output);
-        kernel_handler_[type_]->Forward(input, output);
+        LOG(INFO) << "Skip tuning function due to existing input hash or no tuning space...";
+        vector<string> kernel_config = DispatchTable::Find(type_, input_hash);
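+        // run the recorded kernel when one exists; a "SparseLib" record maps back to the default handler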
+        string kernel_name = (!kernel_config.empty() && kernel_config[0] != "SparseLib") ? kernel_config[0] : type_;
+        kernel_handler_[kernel_name]->set_dispatch_config(kernel_config);
+        if (reshape_model || !kernel_config.empty()) kernel_handler_[kernel_name]->Reshape(input, output);
+        kernel_handler_[kernel_name]->Forward(input, output);
       }
     }
   }
@@ -182,7 +227,7 @@ class Dispatcher {
   inline const string& type() const { return type_; }
   inline const OperatorConfig& operator_conf() const { return operator_conf_; }
   inline const string& execute_kernel() const { return execute_kernel_; }
-  inline const bool& disable_dispatch() const { return disable_dispatch_; }
+  inline const bool& no_tuning_space() const { return no_tuning_space_; }
   inline const void set_warmup_iter(const int& warmup_iter) { warmup_iter_ = warmup_iter; }
   // for profiling
   inline void set_post_op(const string& post_op) { kernel_handler_[execute_kernel_]->set_post_op(post_op); }
@@ -215,6 +260,8 @@ class Dispatcher {
     size_t input_hash = 0;
     for (const auto& tensor : input) combine_hash.push_back(tensor->get_hash());
     input_hash = get_array_hash(input_hash, combine_hash, combine_hash.size());
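+    // fold kernel availability into the hash so dispatch-table entries are specific to the kernel set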
+    input_hash = get_array_hash(input_hash, sparselib_available_, sparselib_available_.size());
     return input_hash;
   }
 
@@ -225,9 +272,13 @@ class Dispatcher {
   KernelHandler kernel_handler_;
   string execute_kernel_;
   bool do_tuning_ = false;
-  bool disable_dispatch_ = false;
+  bool no_tuning_space_ = false;
   int64_t warmup_iter_ = 1;
   int64_t iter_cnt_ = 0;
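+  // one entry per handler, in kernel_handler_ iteration order; true when that kernel is SparseLib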
+  vector<bool> sparselib_available_;
+  bool tune_dense_in_sparse_ = false;
+  string monopoly_kernel_;
 };
 }  // namespace executor