diff --git a/neural_compressor/adaptor/torch_utils/hawq_metric.py b/neural_compressor/adaptor/torch_utils/hawq_metric.py
index 758651a8078..8295fc99cc2 100644
--- a/neural_compressor/adaptor/torch_utils/hawq_metric.py
+++ b/neural_compressor/adaptor/torch_utils/hawq_metric.py
@@ -52,9 +52,9 @@ def remove(self):
 class HessianTrace:
     """HessianTrace Class.
 
-    Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian." 
+    Please refer to Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian."
     2020 IEEE international conference on big data (Big data). IEEE, 2020.
-    Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks." 
+    Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks."
     Advances in neural information processing systems 33 (2020): 18518-18529.
     https://github.com/openvinotoolkit/nncf/blob/develop/nncf/torch/quantization/hessian_trace.py
     """
@@ -173,7 +173,7 @@ def act_grad_hook(model, grad_input, grad_output):
     def _get_enable_act_grad_hook(self, name):
         def enable_act_grad_hook(model, inputs, outputs):
             input = inputs[0]
-            if input.requires_grad is False: 
+            if input.requires_grad is False:
                 input.requires_grad = True
             self.layer_acts[name] = input
@@ -251,13 +251,13 @@ def _sample_rademacher(self, params):
             r.masked_fill_(r == 0, -1)
             samples.append(r)
         return samples
-    
+
     def _sample_rademacher_like_params(self):
         def sample(parameter):
             r = torch.randint_like(parameter, high=2, device=self.device)
             return r.masked_fill_(r == 0, -1)
         return [sample(p) for p in self.params]
-    
+
     def _sample_normal_like_params(self):
         return [torch.randn(p.size(), device=self.device) for p in self.params]
@@ -391,7 +391,7 @@ def _insert_hook(self, model, target_module_list):
         for layer, module in model.named_modules():
             for target_module in target_module_list:
                 # print("layer:",layer)
-                # print("target_model:",target_module) 
+                # print("target_model:",target_module)
                 if layer == target_module:
                     logging.debug("Collect: %s" % (module))
                     # print("Collect: %s" % (module))
@@ -408,7 +408,7 @@ def _insert_hook_quantize(self, model, target_module_list):
             # print("layer:",layer)
             length = len("_model.")
             new_key = layer[length:]
-            # print("target_model:",target_module) 
+            # print("target_model:",target_module)
             if new_key == target_module:
                 logging.debug("Collect: %s" % (module))
                 # print("Collect: %s" % (module))
@@ -521,7 +521,7 @@ def compare_weights(
     float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Compare the weights of the float module with its corresponding quantized module.
-    
+
     Returns a dict with key corresponding to module names and each entry being a dictionary
     with two keys 'float' and 'quantized', containing the float and quantized weights.
     This dict can be used to compare and compute the quantization
@@ -608,7 +608,7 @@ def hawq_top(fp32_model, q_model, dataloader, criterion, enable_act):
             op_qnt_tensor = weight_quant_loss[key]['quantized'].dequantize()
             diff_l2 = (torch.norm(op_float_tensor - op_qnt_tensor, p=2) ** 2)
             pertur_lst[key] = diff_l2
-    
+
     if enable_act:
         act_to_traces = traces['activation']
         for trace_i, pertur_i, act_i in zip(op_to_traces.keys(), pertur_lst.keys(), act_to_traces.keys()):
diff --git a/neural_compressor/contrib/strategy/sigopt.py b/neural_compressor/contrib/strategy/sigopt.py
index 227b21edd9b..336c9adf91c 100644
--- a/neural_compressor/contrib/strategy/sigopt.py
+++ b/neural_compressor/contrib/strategy/sigopt.py
@@ -194,7 +194,7 @@ def traverse(self):
 
         This is SigOpt version of traverse -- with additional constraints setting to HPO.
         """
-        self._eval_baseline()
+        self._prepare_tuning()
 
         baseline_msg = '[Accuracy: {:.4f}'.format(self.baseline[0]) + \
             ''.join([', {}: {:.4f}'.format(x,y) for x,y in zip( \
diff --git a/neural_compressor/contrib/strategy/tpe.py b/neural_compressor/contrib/strategy/tpe.py
index 2f80cb7751b..18bf0a76105 100644
--- a/neural_compressor/contrib/strategy/tpe.py
+++ b/neural_compressor/contrib/strategy/tpe.py
@@ -191,6 +191,7 @@ def _configure_hpopt_search_space_and_params(self, search_space):
     def traverse(self):
         """Tpe traverse logic."""
         logger.info("Start to run tpe strategy.")
+        self._prepare_tuning()
         # prepare log file
         trials_file = os.path.join(os.path.dirname(self.history_path), 'tpe_trials.csv')
         best_result_file = os.path.join(os.path.dirname(self.history_path), 'tpe_best_result.csv')
diff --git a/neural_compressor/strategy/auto.py b/neural_compressor/strategy/auto.py
index b79c0a3cfc1..9a77c7645d4 100644
--- a/neural_compressor/strategy/auto.py
+++ b/neural_compressor/strategy/auto.py
@@ -79,13 +79,10 @@ def sequential_traverse(self):
                                         eval_dataloader=self.eval_dataloader,
                                         eval_metric=self.eval_metric,
                                         resume=self._resume,
-                                        q_hooks=self.q_hooks)
+                                        q_hooks=self.q_hooks,
+                                        pre_strategy=pre_strategy
+                                        )
 
-            if pre_strategy:
-                #TODO add tuning history from the previous stage to current stage.
-                strategy.baseline = deepcopy(pre_strategy.baseline)
-                strategy.trials_count = pre_strategy.trials_count
-                strategy.objectives.baseline = deepcopy(pre_strategy.baseline)
             pre_strategy = strategy
             strategy.traverse()
             self.best_qmodel = strategy.best_qmodel
diff --git a/neural_compressor/strategy/auto_mixed_precision.py b/neural_compressor/strategy/auto_mixed_precision.py
index 78e045e72ed..fe3f5663cd8 100644
--- a/neural_compressor/strategy/auto_mixed_precision.py
+++ b/neural_compressor/strategy/auto_mixed_precision.py
@@ -128,8 +128,7 @@ def next_tune_cfg(self):
     def traverse(self):
         """Traverse the tuning space according to auto-mixed precision strategy."""
-        # get fp32 model baseline
-        self._eval_baseline()
+        self._prepare_tuning()
 
         for op_tuning_cfg in self.next_tune_cfg():
             # add tune_cfg here as quantize use tune_cfg
diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py
index 0944cb99293..f675ac27183 100644
--- a/neural_compressor/strategy/strategy.py
+++ b/neural_compressor/strategy/strategy.py
@@ -61,13 +61,37 @@ def strategy_registry(cls):
     assert cls.__name__.endswith(
         'TuneStrategy'
     ), "The name of subclass of TuneStrategy should end with \'TuneStrategy\' substring."
-    if cls.__name__[:-len('TuneStrategy')].lower() in STRATEGIES:
+    if cls.__name__[:-len('TuneStrategy')].lower() in STRATEGIES: # pragma: no cover
         raise ValueError('Cannot have two strategies with the same name')
     STRATEGIES[cls.__name__[:-len('TuneStrategy')].lower()] = cls
     return cls
 
+class TuneStrategyMeta(type):
+    """Tuning strategy metaclass."""
+
+    def __call__(cls, *args, pre_strategy=None, **kwargs):
+        """Create a new strategy instance, inheriting state from the previous strategy if one is given.
+
+        Args:
+            pre_strategy: The previous strategy instance. Defaults to None.
+
+        Returns:
+            The newly created strategy instance.
+        """
+        new_strategy = super().__call__(*args, **kwargs)
+        if pre_strategy:
+            new_strategy.adaptor = pre_strategy.adaptor
+            new_strategy.framework = pre_strategy.framework
+            new_strategy.baseline = deepcopy(pre_strategy.baseline)
+            new_strategy.trials_count = pre_strategy.trials_count
+            new_strategy.objectives.baseline = deepcopy(pre_strategy.baseline)
+            new_strategy.capability = pre_strategy.capability
+            new_strategy.tuning_space = pre_strategy.tuning_space
+            new_strategy.algo_scheduler = pre_strategy.algo_scheduler
+        return new_strategy
+
 @strategy_registry
-class TuneStrategy(object):
+class TuneStrategy(metaclass=TuneStrategyMeta):
     """Basic class for tuning strategy."""
 
     def __init__(self,
@@ -111,18 +135,19 @@ def __init__(self,
         self.q_func = q_func
         self.q_hooks = q_hooks
         GLOBAL_STATE.STATE = MODE.QUANTIZATION
-        framework, framework_specific_info = self._set_framework_info(q_dataloader, q_func)
-        self.adaptor = FRAMEWORKS[framework](framework_specific_info)
-        self.framework = framework
-        self.set_q_func()
-        self._set_objectives()
+        # The following attributes may be set by the previous strategy:
+        # adaptor, framework, baseline, trials_count, capability, tuning_space, algo_scheduler
+        self._adaptor = None
+        self._framework = None
+        self.check_q_func()
+        self.objectives = self._set_objectives()
         self.tune_data = {}
         self.tune_result_record = []
         self.tuning_history = []
         self.tuning_result_data = []
-        self.baseline = None
+        self._baseline = None
         self.last_tune_result = None
         self.last_qmodel = None
         self.last_tune_cfg = None
@@ -131,23 +156,15 @@ def __init__(self,
         # track the best tuning config correspondence to the best quantized model
         self.best_tuning_cfg = None
         # track the current best accuracy
-        self.cur_best_acc = self.initial_best_acc()
+        self.cur_best_acc = None
         # track tuning cfg with the current best accuracy
         self.cur_best_tuning_cfg = {}
         self.re_quant = False
-        self.trials_count = 0
-
-        # query capability and build tuning space
-        self.capability = self.adaptor.query_fw_capability(model)
-        logger.debug(self.capability)
-        self.set_tuning_space(self.config)
-        # set algo scheduler
-        self.algo_scheduler = AlgorithmScheduler(self.config.recipes)
-        # reuse the calibration iteration
-        self.algo_scheduler.dataloader = self.calib_dataloader
-        self.algo_scheduler.origin_model = self.model
-        self.algo_scheduler.adaptor = self.adaptor
+        self._trials_count = 0
+        self._capability = None
+        self._tuning_space = None
+        self._algo_scheduler = None
 
         self._optype_statistics = None
         self.fallback_stats_baseline = None
@@ -169,6 +186,125 @@ def __init__(self,
         self._resume = resume
         if self._resume is not None:
             self.setup_resume(resume)
 
+    @property
+    def adaptor(self):
+        """Gets the adaptor."""
+        return self._adaptor
+
+    @adaptor.setter
+    def adaptor(self, value):
+        """Sets the adaptor.
+
+        Args:
+            value: The new value for the adaptor.
+ """ + self._adaptor = value + + @property + def framework(self): + """Gets the framework.""" + return self._framework + + @framework.setter + def framework(self, value): + """Sets the framework. + + Args: + value: The new value for the framework. + """ + self._framework = value + + @property + def baseline(self): + """Gets the baseline.""" + return self._baseline + + @baseline.setter + def baseline(self, value): + """Sets the baseline. + + Args: + value (float): The new value for the baseline. + """ + self._baseline = value + + @property + def trials_count(self): + """Gets the trials_count.""" + return self._trials_count + + @trials_count.setter + def trials_count(self, value): + """Sets the trials_count. + + Args: + value (int): The new value for the trials_count. + """ + self._trials_count = value + + @property + def capability(self): + """Gets the capability.""" + return self._capability + + @capability.setter + def capability(self, value): + """Sets the capability. + + Args: + value: The new value for the capability. + """ + self._capability = value + + @property + def tuning_space(self): + """Gets the tuning_space.""" + return self._tuning_space + + @tuning_space.setter + def tuning_space(self, value): + """Sets the tuning_space. + + Args: + value (list): The new value for the tuning_space. + """ + self._tuning_space = value + + @property + def algo_scheduler(self): + """Gets the algo_scheduler.""" + return self._algo_scheduler + + @algo_scheduler.setter + def algo_scheduler(self, value): + """Sets the algo_scheduler. + + Args: + value: The new value for the algo_scheduler. + """ + self._algo_scheduler = value + + def _initialize_algo_scheduler(self): + algo_scheduler = AlgorithmScheduler(self.config.recipes) + # reuse the calibration iteration + algo_scheduler.dataloader = self.calib_dataloader + algo_scheduler.origin_model = self.model + algo_scheduler.adaptor = self.adaptor + return algo_scheduler + + def _prepare_tuning(self): + """Prepare to tune and avoid repeated initialization of the adaptor and tuning space.""" + framework, framework_specific_info = self._set_framework_info(self.calib_dataloader, self.q_func) + self.adaptor = self.adaptor or FRAMEWORKS[framework](framework_specific_info) + self.framework = self.framework or framework + self.cur_best_acc = self.cur_best_acc or self.initial_best_acc() + # query capability and build tuning space + self.capability = self.capability or self.adaptor.query_fw_capability(self.model) + logger.debug(self.capability) + self.tuning_space = self.tuning_space or self.build_tuning_space(self.config) + self.algo_scheduler = self.algo_scheduler or self._initialize_algo_scheduler() + self._eval_baseline() + def _check_tuning_status(self): # got eval func if self.eval_func: @@ -187,11 +323,11 @@ def _check_tuning_status(self): return else: # got eval dataloader but not eval metric - if self.eval_dataloader: + if self.eval_dataloader: # pragma: no cover assert self.eval_metric, "Detected evaluation dataloader but no evaluation metric, " \ "Please provide both to perform tuning process or neither for the default quantization." # got eval metric but not eval dataloader - if self.eval_metric: + if self.eval_metric: # pragma: no cover assert self.eval_dataloader, "Detected evaluation metric but no evaluation dataloader, "\ "Please provide both to perform tuning process or neither for the default quantization." 
             # not tuning
@@ -231,7 +367,7 @@ def traverse(self):
 
         The main traverse logic which could be override by some concrete strategy which needs more hooks.
         """
-        self._eval_baseline()
+        self._prepare_tuning()
         if self.config.use_distributed_tuning:
             logger.info("use distributed traverse: {}".format(self.config.use_distributed_tuning))
             return self.distributed_traverse()
@@ -241,7 +377,7 @@ def traverse(self):
             tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
             self.trials_count += 1
             tuning_history = self._find_tuning_history(tune_cfg)
-            if tuning_history and self.trials_count < self.config.tuning_criterion.max_trials:
+            if tuning_history and self.trials_count < self.config.tuning_criterion.max_trials: # pragma: no cover
                 self.last_tune_result = tuning_history['last_tune_result']
                 self.best_tune_result = tuning_history['best_tune_result']
                 logger.warn("Find evaluated tuning config, skip.")
@@ -252,14 +388,14 @@ def traverse(self):
             self.tuning_times += 1
             # set the parameter for pre quantization algos and run
             self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model)
-            self.model = self.algo_scheduler('pre_quantization')
+            self.model = self.algo_scheduler('pre_quantization') # pylint: disable=E1102
             # quantize
             q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func)
             assert self.adaptor.pre_optimized_model
             # set the parameter for post quantization algos and run
             self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg,\
                 self.adaptor.pre_optimized_model, q_model)
-            self.last_qmodel = self.algo_scheduler('post_quantization')
+            self.last_qmodel = self.algo_scheduler('post_quantization') # pylint: disable=E1102
             self.last_tune_cfg = copy.deepcopy(tune_cfg)
             # remove the reference to model
             self.algo_scheduler.reset_exec_algorithms()
@@ -303,7 +439,7 @@ def traverse(self):
                 logger.debug(f'*** Start to do diagnosis (inspect tensor).')
                 self._diagnosis()
             if self.use_multi_objective and len(self.tune_result_record) > 1 and \
-                self.best_tune_result is not None:
+                self.best_tune_result is not None: # pragma: no cover
                 best_trail, best_result = self.objectives.best_result(self.tune_result_record,
                                                                       copy.deepcopy(self.baseline))
                 if best_result != self.best_tune_result:
@@ -429,6 +565,9 @@ def master_worker_handle(self, comm):
         # record eval_results for context coordination of stage 3
         self.last_tune_result = eval_res
+        self.objectives.val = eval_res
+        self.trials_count = self.overall_trials + 1
+        self.stop(self.config.tuning_criterion.timeout, None)
         self.eval_results[tag] = eval_res
         self.overall_trials += 1
@@ -543,14 +682,14 @@ def slave_worker_handle(self, comm):
             # set the parameter for pre quantization algos and run
             self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model)
-            self.model = self.algo_scheduler('pre_quantization')
+            self.model = self.algo_scheduler('pre_quantization') # pylint: disable=E1102
             # quantize
             q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func)
             assert self.adaptor.pre_optimized_model
             # set the parameter for post quantization algos and run
             self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg,
                                                        self.adaptor.pre_optimized_model, q_model)
-            self.last_qmodel = self.algo_scheduler('post_quantization')
+            self.last_qmodel = self.algo_scheduler('post_quantization') # pylint: disable=E1102
             self.last_tune_cfg = copy.deepcopy(tune_cfg)
             # Remove the reference to model
             self.algo_scheduler.reset_exec_algorithms()
@@ -571,6 +710,7 @@ def distributed_traverse(self):
 
         The main traverse logic which could be override by some concrete strategy which needs more hooks.
         """
+        self._prepare_tuning()
         MPI = LazyImport("mpi4py.MPI")
         comm = MPI.COMM_WORLD
         rank = comm.Get_rank()
@@ -634,7 +774,7 @@ def apply_recipe_one_by_one(self, tune_cfg):
                 new_tune_cfg = self._fallback_ops(copy.deepcopy(tune_cfg), \
                     self.capability['recipes_ops'][recipe_name], self.tuning_space)
                 yield new_tune_cfg
-            if recipe_name == "smooth_quant":
+            if recipe_name == "smooth_quant": # pragma: no cover
                 sq_args = {'smooth_quant': True}
                 if 'recipe_cfgs' not in new_tune_cfg:
                     new_tune_cfg['recipe_cfgs'] = sq_args
@@ -690,13 +830,13 @@ def set_param_for_post_quantization_algos(self, algo_scheduler, tune_cfg, pre_op
         algo_scheduler.reset_exec_algorithms()
         recipe_cfgs = tune_cfg.get('recipe_cfgs', None)
         # for fast_bias_correction
-        if recipe_cfgs and recipe_cfgs.get('fast_bias_correction', False):
+        if recipe_cfgs and recipe_cfgs.get('fast_bias_correction', False): # pragma: no cover
             fbc_algo = ALGORITHMS()['fast_bias_correction']
             fbc_algo.quantization_cfg = deepcopy(tune_cfg)
             algo_scheduler.append_algorithm('post_quantization', fbc_algo)
             logger.debug(f"Add fast bias correction as the post quantization algo.")
         # for weight correction
-        if recipe_cfgs and recipe_cfgs.get('weight_correction', False):
+        if recipe_cfgs and recipe_cfgs.get('weight_correction', False): # pragma: no cover
             w_algo = ALGORITHMS()['weight_correction']
             w_algo.quantization_cfg = deepcopy(tune_cfg)
             algo_scheduler.append_algorithm('post_quantization', w_algo)
@@ -790,13 +930,13 @@ def _compare_optype_statistics(self, fields=None, optypes=None,
         adaptor_statistics = self.adaptor.optype_statistics
 
         def _field_skipped(field):
-            if fields != None:
+            if fields != None: # pragma: no cover
                 return field not in fields
             elif skip_fields != None:
                 return field in skip_fields
 
         def _optype_skipped(optype):
-            if optypes != None:
+            if optypes != None: # pragma: no cover
                 return optype not in optypes
             elif skip_optypes != None:
                 return optype in skip_optypes
@@ -955,7 +1095,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
                 tune_cfg['recipe_cfgs'][recipe_name] = recipe_val
         return tune_cfg
 
-    def set_tuning_space(self, config):
+    def build_tuning_space(self, config):
         """Create the tuning space.
 
         Create the tuning space based on the framework capability and user configuration.
@@ -975,7 +1115,8 @@ def set_tuning_space(self, config):
             'calib': {'calib_sampling_size': calib_sampling_size_lst},
             'op': self.capability['opwise']
         }
-        self.tuning_space = TuningSpace(adaptor_cap, conf=config, framework=self.framework)
+        tuning_space = TuningSpace(adaptor_cap, conf=config, framework=self.framework)
+        return tuning_space
 
     def setup_resume(self, resume):
         """Resume the best quantized model from tuning history.
@@ -985,7 +1126,7 @@ def setup_resume(self, resume):
         """
         self.__dict__.update(resume)
         for history in self.tuning_history:
-            if self._same_conf(history['cfg'], self.conf):
+            if self._same_conf(history['cfg'], self.conf): # pragma: no cover
                 self.__dict__.update({k: v for k, v in history.items() \
                                       if k not in ['version', 'history']})
                 logger.info("Start to resume tuning process.")
@@ -1002,8 +1143,8 @@ def setup_resume(self, resume):
                     break
 
-    def set_q_func(self):
-        """Set the training function for quantization aware training."""
+    def check_q_func(self):
+        """Check the training function for quantization aware training."""
         if self.config.approach == 'quant_aware_training':
             assert self.q_func != None, "Please set train func for quantization aware training"
@@ -1121,12 +1262,13 @@ def _set_objectives(self):
         accuracy_criterion_conf = self.config.accuracy_criterion
         accuracy_criterion[accuracy_criterion_conf.criterion] = accuracy_criterion_conf.tolerable_loss
         accuracy_criterion['higher_is_better'] = accuracy_criterion_conf.higher_is_better
-        self.objectives = MultiObjective(objectives=objectives,
-                                         accuracy_criterion=accuracy_criterion,
-                                         metric_criterion=self.metric_criterion,
-                                         metric_weight=self.metric_weight,
-                                         obj_criterion=obj_higher_is_better,
-                                         obj_weight=obj_weight)
+        objectives = MultiObjective(objectives=objectives,
+                                    accuracy_criterion=accuracy_criterion,
+                                    metric_criterion=self.metric_criterion,
+                                    metric_weight=self.metric_weight,
+                                    obj_criterion=obj_higher_is_better,
+                                    obj_weight=obj_weight)
+        return objectives
 
     def _same_conf(self, src_conf, dst_conf):
         """Check if the two configs are the same."""
diff --git a/test/strategy/test_distributed_tuning.py b/test/strategy/test_distributed_tuning.py
index 24340826f0f..0a00d2e616a 100644
--- a/test/strategy/test_distributed_tuning.py
+++ b/test/strategy/test_distributed_tuning.py
@@ -91,10 +91,10 @@ def test_pt_stage_1_met(self):
         # fake evaluation function
         num_baseline = num_processes # TODO, replace num_baseline with 1 when evaluating baseline only once.
         acc_lst = [2.0] * num_baseline + [1.0, 2.1, 2.2, 2.3, 2.0] #the tuning result (2.1)
-        perf_lst = [2.0] * num_baseline + [2.5, 2.0, 1.5, 1.1, 5.0] 
+        perf_lst = [2.0] * num_baseline + [2.5, 2.0, 1.5, 1.1, 5.0]
 
         # make sure this path can be accessed by all nodes
-        acc_perf_data_file_path = 'test_pt_stage_1_met.json' 
+        acc_perf_data_file_path = 'test_pt_stage_1_met.json'
         save_acc_perf_to_local(acc_lst, perf_lst, acc_perf_data_file_path)
 
         def _fake_eval(model):
@@ -108,7 +108,7 @@ def _fake_eval(model):
         dataloader = DATALOADERS["pytorch"](dataset)
 
         # tuning and accuracy criterion
-        conf = PostTrainingQuantConfig(use_distributed_tuning=True)
+        conf = PostTrainingQuantConfig(quant_level=1, use_distributed_tuning=True)
 
         # fit
         q_model = fit(model=resnet18,
                       conf=conf,
@@ -133,7 +133,7 @@ def test_pt_stage_3_fp32_met(self):
         perf_lst = [2.0] * num_baseline + [1.0] * 16 + [1.0, 1.0, 1.0]
 
         # make sure this path can be accessed by all nodes
-        acc_perf_data_file_path = 'test_pt_stage_3_fp32_met.json' 
+        acc_perf_data_file_path = 'test_pt_stage_3_fp32_met.json'
         save_acc_perf_to_local(acc_lst, perf_lst, acc_perf_data_file_path)
 
         def _fake_eval(model):
@@ -147,7 +147,7 @@ def _fake_eval(model):
         dataloader = DATALOADERS["pytorch"](dataset)
 
         # tuning and accuracy criterion
-        conf = PostTrainingQuantConfig(use_distributed_tuning=True)
+        conf = PostTrainingQuantConfig(quant_level=1, use_distributed_tuning=True)
 
         # fit
         q_model = fit(model=resnet18,
                       conf=conf,
@@ -172,7 +172,7 @@ def test_pt_stage_4_fp32_met(self):
         perf_lst = [2.0] * num_baseline + [1.0] * 37 + [1.0, 1.0, 1.0]
 
         # make sure this path can be accessed by all nodes
-        acc_perf_data_file_path = 'test_pt_stage_stage_4_fp32_met.json' 
+        acc_perf_data_file_path = 'test_pt_stage_stage_4_fp32_met.json'
         save_acc_perf_to_local(acc_lst, perf_lst, acc_perf_data_file_path)
 
         def _fake_eval(model):
@@ -186,7 +186,7 @@ def _fake_eval(model):
         dataloader = DATALOADERS["pytorch"](dataset)
 
         # tuning and accuracy criterion
-        conf = PostTrainingQuantConfig(use_distributed_tuning=True)
+        conf = PostTrainingQuantConfig(quant_level=1, use_distributed_tuning=True)
 
         # fit
         q_model = fit(model=resnet18,
                       conf=conf,
@@ -210,7 +210,7 @@ def test_pt_stage_not_met(self):
         perf_lst = [2.0] * num_baseline + [1.0] * 57
 
         # make sure this path can be accessed by all nodes
-        acc_perf_data_file_path = 'test_pt_stage_not_met.json' 
+        acc_perf_data_file_path = 'test_pt_stage_not_met.json'
         save_acc_perf_to_local(acc_lst, perf_lst, acc_perf_data_file_path)
 
         def _fake_eval(model):
@@ -224,7 +224,7 @@ def _fake_eval(model):
         dataloader = DATALOADERS["pytorch"](dataset)
 
         # tuning and accuracy criterion
-        conf = PostTrainingQuantConfig(use_distributed_tuning=True)
+        conf = PostTrainingQuantConfig(quant_level=1, use_distributed_tuning=True)
 
         # fit
         q_model = fit(model=resnet18,
                       conf=conf,
@@ -263,7 +263,7 @@ def _fake_eval(model):
         dataloader = DATALOADERS["pytorch"](dataset)
 
         # tuning and accuracy criterion
-        conf = PostTrainingQuantConfig(use_distributed_tuning=True)
+        conf = PostTrainingQuantConfig(quant_level=1, use_distributed_tuning=True)
 
         # fit
         q_model = fit(model=resnet18,
                       conf=conf,
diff --git a/test/strategy/test_new_datatype.py b/test/strategy/test_new_datatype.py
index 599dfa2a947..02e2d48d1f2 100644
--- a/test/strategy/test_new_datatype.py
+++ b/test/strategy/test_new_datatype.py
@@ -40,14 +40,14 @@ def add_cap(filename):
             },
         }
     }
-    
+
     with open(filename) as f:
         con = yaml.safe_load(f)
     con[0]['int4'] = int4_cap
     with open(filename, 'w') as out:
         yaml.dump(con, out)
 
-class TestBasicTuningStrategy(unittest.TestCase):
+class TestAddNewDataType(unittest.TestCase):
 
     @classmethod
     def setUpClass(self):
@@ -56,7 +56,7 @@ def setUpClass(self):
     @classmethod
     def tearDownClass(self):
         shutil.rmtree('saved', ignore_errors=True)
-    
+
     def test_add_int4(self):
         import shutil
         import importlib
@@ -68,12 +68,12 @@ def test_add_int4(self):
         from neural_compressor.quantization import fit
         from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
         from neural_compressor.data import Datasets, DATALOADERS
-        
+
         # dataset and dataloader
         dataset = Datasets("pytorch")["dummy"](((100, 3, 224, 224)))
         dataloader = DATALOADERS["pytorch"](dataset)
         model = build_model()
-        
+
         def fake_eval(model):
             return 1
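
# ---------------------------------------------------------------------------
# Reviewer sketch (illustration only, not part of the patch): how the
# TuneStrategyMeta metaclass introduced above chains state between sequential
# strategy instances. MiniStrategy and its fields are hypothetical stand-ins
# for TuneStrategy; only the `pre_strategy` forwarding mirrors the real code.
# ---------------------------------------------------------------------------
from copy import deepcopy


class ChainMeta(type):
    """Metaclass that forwards expensive state from a previous instance."""

    def __call__(cls, *args, pre_strategy=None, **kwargs):
        new_strategy = super().__call__(*args, **kwargs)
        if pre_strategy:
            # Reuse what is costly to recompute (baseline evaluation,
            # capability query, ...) instead of rebuilding it per stage.
            new_strategy.baseline = deepcopy(pre_strategy.baseline)
            new_strategy.trials_count = pre_strategy.trials_count
        return new_strategy


class MiniStrategy(metaclass=ChainMeta):
    def __init__(self, name):
        self.name = name
        self.baseline = None  # normally filled in by _prepare_tuning()
        self.trials_count = 0


first = MiniStrategy("conservative")
first.baseline = [0.76]
first.trials_count = 5
# The metaclass hands over baseline/trials_count, so a second stage can skip
# re-evaluating the FP32 baseline, which is what auto.py's sequential_traverse
# relies on after dropping its manual copy block.
second = MiniStrategy("basic", pre_strategy=first)
assert second.baseline == [0.76] and second.trials_count == 5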
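
# ---------------------------------------------------------------------------
# Reviewer sketch (illustration only, not part of the patch): the
# `attr = attr or compute()` idiom used by _prepare_tuning(). Attributes
# seeded from a previous strategy are kept; missing ones are computed once.
# Usual caveat of `or`: a legitimately falsy value (0, {}, []) would also
# trigger recomputation, so the idiom only suits attributes whose "unset"
# state is None and whose set state is truthy.
# ---------------------------------------------------------------------------
class LazyPrepare:
    def __init__(self, capability=None):
        self.capability = capability  # may be handed over by a pre strategy

    def query_capability(self):
        print("expensive framework capability query")
        return {"opwise": {"conv": ["int8", "fp32"]}}

    def prepare(self):
        # Runs the expensive query only when capability was not handed over.
        self.capability = self.capability or self.query_capability()


fresh = LazyPrepare()
fresh.prepare()    # triggers the query
chained = LazyPrepare(capability=fresh.capability)
chained.prepare()  # reuses the handed-over result; no second query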