2020from pathlib import Path
2121from functools import partial
2222import numpy as np
23- import hyperopt as hpo
24- from hyperopt import fmin , hp , STATUS_OK , Trials
2523from neural_compressor .utils import logger
24+ from neural_compressor .utils .utility import LazyImport
2625from neural_compressor .strategy .strategy import strategy_registry , TuneStrategy
2726from collections import OrderedDict
2827from neural_compressor .strategy .st_utils .tuning_sampler import OpWiseTuningSampler
2928from neural_compressor .strategy .st_utils .tuning_structs import OpTuningConfig
3029
30+ hyperopt = LazyImport ('hyperopt' )
3131
3232try :
3333 import pandas as pd
@@ -85,10 +85,19 @@ def __init__(self, model, conf, q_dataloader, q_func=None,
8585 eval_dataloader = None , eval_func = None , dicts = None , q_hooks = None ):
8686 assert conf .usr_cfg .quantization .approach == 'post_training_static_quant' , \
8787 "TPE strategy is only for post training static quantization!"
88+ # Initialize the tpe tuning strategy if the user specified to use it.
89+ strategy_name = conf .usr_cfg .tuning .strategy .name
90+ if strategy_name .lower () == "tpe" :
91+ try :
92+ import hyperopt
93+ except ImportError :
94+ raise ImportError (f"Please install hyperopt for using { strategy_name } strategy." )
95+ else :
96+ pass
8897 self .hpopt_search_space = None
8998 self .warm_start = False
9099 self .cfg_evaluated = False
91- self .hpopt_trials = Trials ()
100+ self .hpopt_trials = hyperopt . Trials ()
92101 self .max_trials = conf .usr_cfg .tuning .exit_policy .get ('max_trials' , 200 )
93102 self .loss_function_config = {
94103 'acc_th' : conf .usr_cfg .tuning .accuracy_criterion .relative if \
@@ -140,7 +149,7 @@ def __getstate__(self):
140149 def _configure_hpopt_search_space_and_params (self , search_space ):
141150 self .hpopt_search_space = {}
142151 for param , configs in search_space .items ():
143- self .hpopt_search_space [(param )] = hp .choice ((param [0 ]), configs )
152+ self .hpopt_search_space [(param )] = hyperopt . hp .choice ((param [0 ]), configs )
144153 # Find minimum number of choices for params with more than one choice
145154 multichoice_params = [len (configs ) for param , configs in search_space .items ()
146155 if len (configs ) > 1 ]
@@ -149,7 +158,7 @@ def _configure_hpopt_search_space_and_params(self, search_space):
149158 min_param_size = min (multichoice_params ) if len (multichoice_params ) > 0 else 1
150159 self .tpe_params ['n_EI_candidates' ] = min_param_size
151160 self .tpe_params ['prior_weight' ] = 1 / min_param_size
152- self ._algo = partial (hpo .tpe .suggest ,
161+ self ._algo = partial (hyperopt .tpe .suggest ,
153162 n_startup_jobs = self .tpe_params ['n_initial_point' ],
154163 gamma = self .tpe_params ['gamma' ],
155164 n_EI_candidates = self .tpe_params ['n_EI_candidates' ],
@@ -225,12 +234,12 @@ def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict):
225234 self ._configure_hpopt_search_space_and_params (first_run_cfg )
226235 # Run first iteration with best result from history
227236 trials_count = len (self .hpopt_trials .trials ) + 1
228- fmin (partial (self .object_evaluation , model = self .model ),
229- space = self .hpopt_search_space ,
230- algo = self ._algo ,
231- max_evals = trials_count ,
232- trials = self .hpopt_trials ,
233- show_progressbar = False )
237+ hyperopt . fmin (partial (self .object_evaluation , model = self .model ),
238+ space = self .hpopt_search_space ,
239+ algo = self ._algo ,
240+ max_evals = trials_count ,
241+ trials = self .hpopt_trials ,
242+ show_progressbar = False )
234243 if pd is not None :
235244 self ._save_trials (trials_file )
236245 self ._update_best_result (best_result_file )
@@ -266,12 +275,12 @@ def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict):
266275 self .cfg_evaluated = False
267276 logger .debug ("Trial iteration start: {} / {}." .format (
268277 trials_count , self .max_trials ))
269- fmin (partial (self .object_evaluation , model = self .model ),
270- space = self .hpopt_search_space ,
271- algo = self ._algo ,
272- max_evals = trials_count ,
273- trials = self .hpopt_trials ,
274- show_progressbar = False )
278+ hyperopt . fmin (partial (self .object_evaluation , model = self .model ),
279+ space = self .hpopt_search_space ,
280+ algo = self ._algo ,
281+ max_evals = trials_count ,
282+ trials = self .hpopt_trials ,
283+ show_progressbar = False )
275284 trials_count += 1
276285 if pd is not None :
277286 self ._save_trials (trials_file )
@@ -349,7 +358,7 @@ def _compute_metrics(self, tune_cfg, acc, lat):
349358 'acc_loss' : acc_diff ,
350359 'lat_diff' : lat_diff ,
351360 'quantization_ratio' : quantization_ratio ,
352- 'status' : STATUS_OK }
361+ 'status' : hyperopt . STATUS_OK }
353362
354363 def _calculate_acc_lat_diff (self , acc , lat ):
355364 int8_acc = acc
0 commit comments