diff --git a/neural_compressor/__init__.py b/neural_compressor/__init__.py index 35c8cf0c357..bee0012eb81 100644 --- a/neural_compressor/__init__.py +++ b/neural_compressor/__init__.py @@ -22,7 +22,7 @@ # we need to set a global 'NA' backend, or Model can't be used from .utils.utility import set_random_seed, set_tensorboard, set_workspace from .utils import options -from .config import conf +# from .config import conf from .config import DistillationConfig, PostTrainingQuantConfig, \ WeightPruningConfig, QuantizationAwareTrainingConfig, \ MixedPrecisionConfig diff --git a/neural_compressor/adaptor/mxnet.py b/neural_compressor/adaptor/mxnet.py index 6a6e9087148..bf368651353 100644 --- a/neural_compressor/adaptor/mxnet.py +++ b/neural_compressor/adaptor/mxnet.py @@ -25,7 +25,7 @@ dump_elapsed_time, singleton) from neural_compressor.adaptor.mxnet_utils.util import * from collections import OrderedDict -from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader +from neural_compressor.data.dataloaders.base_dataloader import BaseDataLoader from copy import deepcopy import math diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 0ab92ada06d..5a1adb95175 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -31,7 +31,7 @@ from neural_compressor.utils.utility import LazyImport, dump_elapsed_time, \ GLOBAL_STATE, MODE from neural_compressor.utils.utility import Statistics -from neural_compressor.experimental.data.dataloaders.base_dataloader import BaseDataLoader +from neural_compressor.data.dataloaders.base_dataloader import BaseDataLoader from neural_compressor.conf.dotdict import deep_get from neural_compressor.utils.utility import CpuInfo import math diff --git a/neural_compressor/contrib/strategy/tpe.py b/neural_compressor/contrib/strategy/tpe.py index 2082fa8021b..8e8d1f653ac 100644 --- a/neural_compressor/contrib/strategy/tpe.py +++ b/neural_compressor/contrib/strategy/tpe.py @@ -505,7 +505,7 @@ def stop(self, timeout, trials_count): if timeout == 0 and self.best_tune_result: need_stop = True - elif trials_count >= self.cfg.tuning.exit_policy.max_trials: + elif trials_count >= self.conf.quantization.tuning_criterion.max_trials: need_stop = True else: need_stop = False diff --git a/neural_compressor/experimental/graph_optimization.py b/neural_compressor/experimental/graph_optimization.py index d1351a5b4d5..a47a221c02d 100644 --- a/neural_compressor/experimental/graph_optimization.py +++ b/neural_compressor/experimental/graph_optimization.py @@ -25,7 +25,7 @@ import yaml from ..conf.config import Graph_Optimization_Conf from ..conf.dotdict import deep_get, deep_set, DotDict -from ..strategy import STRATEGIES +from .strategy import EXP_STRATEGIES from ..utils import logger from ..utils.create_obj_from_config import create_dataloader from ..utils.utility import CpuInfo, time_limit @@ -139,7 +139,7 @@ def __call__(self): strategy = cfg.tuning.strategy.name.lower() - assert strategy in STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) + assert strategy in EXP_STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) _resume = None # check if interrupted tuning procedure exists. 
if yes, it will resume the @@ -152,7 +152,7 @@ def __call__(self): with open(self.resume_file, 'rb') as f: _resume = pickle.load(f).__dict__ - self.strategy = STRATEGIES[strategy]( + self.strategy = EXP_STRATEGIES[strategy]( self._model, self.conf, None, diff --git a/neural_compressor/experimental/mixed_precision.py b/neural_compressor/experimental/mixed_precision.py index 438f2e749bb..448e3bab6a8 100644 --- a/neural_compressor/experimental/mixed_precision.py +++ b/neural_compressor/experimental/mixed_precision.py @@ -24,7 +24,7 @@ from ..conf.config import MixedPrecision_Conf from ..conf.pythonic_config import Config from ..conf.dotdict import deep_get -from ..strategy import STRATEGIES +from .strategy import EXP_STRATEGIES from ..utils import logger from ..utils.create_obj_from_config import create_dataloader from ..utils.utility import CpuInfo, time_limit @@ -149,7 +149,7 @@ def __call__(self): strategy = cfg.tuning.strategy.name.lower() - assert strategy in STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) + assert strategy in EXP_STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) _resume = None # check if interrupted tuning procedure exists. if yes, it will resume the @@ -162,7 +162,7 @@ def __call__(self): with open(self.resume_file, 'rb') as f: _resume = pickle.load(f).__dict__ - self.strategy = STRATEGIES[strategy]( + self.strategy = EXP_STRATEGIES[strategy]( self._model, self.conf, None, diff --git a/neural_compressor/experimental/quantization.py b/neural_compressor/experimental/quantization.py index 3701f4e3def..8aa059242df 100644 --- a/neural_compressor/experimental/quantization.py +++ b/neural_compressor/experimental/quantization.py @@ -23,7 +23,7 @@ import numpy as np from .component import Component from ..conf.dotdict import deep_get, deep_set, DotDict -from ..strategy import STRATEGIES +from .strategy import EXP_STRATEGIES from ..utils import logger from ..utils.utility import time_limit from ..utils.create_obj_from_config import create_dataloader @@ -144,7 +144,7 @@ def pre_process(self): strategy = "basic" logger.warning(f"MSE_v2 does not support {self.framework} now, use basic instead.") logger.warning("Only tensorflow, pytorch_fx is supported by MSE_v2 currently.") - assert strategy in STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) + assert strategy in EXP_STRATEGIES, "Tuning strategy {} is NOT supported".format(strategy) _resume = None # check if interrupted tuning procedure exists. if yes, it will resume the @@ -157,7 +157,7 @@ def pre_process(self): with open(self.resume_file, 'rb') as f: _resume = pickle.load(f).__dict__ - self.strategy = STRATEGIES[strategy]( + self.strategy = EXP_STRATEGIES[strategy]( self._model, self.conf, self._calib_dataloader, diff --git a/neural_compressor/experimental/strategy/__init__.py b/neural_compressor/experimental/strategy/__init__.py new file mode 100644 index 00000000000..f4a137cb792 --- /dev/null +++ b/neural_compressor/experimental/strategy/__init__.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Intel Neural Compressor Strategy."""
+
+from .strategy import EXP_STRATEGIES
+from os.path import dirname, basename, isfile, join
+import glob
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+
+for f in modules:
+    # glob returns full paths, so check the basename when skipping dunder files
+    if isfile(f) and not basename(f).startswith('__') and not f.endswith('__init__.py'):
+        __import__(basename(f)[:-3], globals(), locals(), level=1)
+
+__all__ = ["EXP_STRATEGIES"]
diff --git a/neural_compressor/experimental/strategy/auto_mixed_precision.py b/neural_compressor/experimental/strategy/auto_mixed_precision.py
new file mode 100644
index 00000000000..76ef7c8bb7e
--- /dev/null
+++ b/neural_compressor/experimental/strategy/auto_mixed_precision.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The auto-mixed precision strategy."""
+
+import copy
+import numpy as np
+from collections import OrderedDict
+from .strategy import strategy_registry, TuneStrategy
+from ...utils import logger
+
+from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_structs import OpTuningConfig
+
+
+@strategy_registry
+class AutoMixedPrecisionTuneStrategy(TuneStrategy):
+    """Tuning strategy for auto mixed precision."""
+
+    def next_tune_cfg(self):
+        """Generate the next tuning config.
+
+        Tuning configurations are generated according to the following rules:
+        1. First, it tries to convert as many ops as possible into the target data type.
+        2. If the accuracy does not meet the requirements, it starts the fallback stage,
+           which converts ops back into higher precision.
+
+        Yields:
+            tune_config (dict): A dict containing the tuning configuration.
+        """
+        from copy import deepcopy
+
+        # filter quantization dtype
+        # TODO: align with the old mixed-precision
+        target_dtypes = self.cfg.graph_optimization.precisions if self.cfg.graph_optimization \
+            else self.cfg.mixed_precision.precisions
+        target_dtypes = list(set(target_dtypes) - set(['fp32']))
+        tuning_space = self.tuning_space
+        initial_op_tuning_cfg = {}
+        for item in tuning_space.root_item.options:
+            if item.item_type == 'op':
+                op_name, op_type = item.name
+                initial_op_tuning_cfg[item.name] = OpTuningConfig(op_name, op_type, 'fp32', tuning_space)
+
+        if not target_dtypes:
+            target_dtypes = ['bf16']
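+        # For illustration only (hypothetical op names): at this point
+        # initial_op_tuning_cfg maps every op to fp32, e.g.
+        #     {('conv1', 'Conv2D'): OpTuningConfig('conv1', 'Conv2D', 'fp32', tuning_space), ...}
+        # Step 1 below flips the supported ops to the target dtype, and step 2
+        # falls individual ops back to fp32 when accuracy regresses.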
+        # Step 1: convert ops to the target dtype as much as possible and
+        # collect the ops that support the target dtype.
+        bf16_items_name = []
+        op_tuning_cfg = {}
+        for idx, target_dtype in enumerate(target_dtypes):
+            bf16_items = tuning_space.query_items_by_quant_mode(target_dtype)
+            if len(bf16_items) == 0 and \
+                not (idx == len(target_dtypes) - 1 and len(bf16_items_name) == 0):
+                continue
+            bf16_items_name = [item.name for item in bf16_items]
+            op_tuning_cfg = deepcopy(initial_op_tuning_cfg)
+            for op_name_type in bf16_items_name:
+                op_tuning_cfg[op_name_type] = \
+                    OpTuningConfig(op_name_type[0], op_name_type[1], target_dtype, tuning_space)
+            calib_sampling_size = 1
+            op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+            yield op_tuning_cfg
+
+        # Step 2: fallback
+        target_dtype = 'fp32'
+        fallback_items_name_lst = bf16_items_name[::-1]
+        if fallback_items_name_lst:
+            logger.info(f"Start to fallback op to {target_dtype} one by one.")
+            self._fallback_started()
+        op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
+        initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
+        fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
+                                                 initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                 op_dtypes=op_dtypes, accumulate=False)
+        op_fallback_acc_impact = OrderedDict()
+        for op_index, op_tuning_cfg in enumerate(fallback_sampler):
+            op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+            yield op_tuning_cfg
+            acc, _ = self.last_tune_result
+            op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc
+
+        # do accumulated fallback according to the order in the previous stage
+        if len(op_fallback_acc_impact) > 0:
+            ordered_ops = sorted(op_fallback_acc_impact.keys(), key=lambda key: op_fallback_acc_impact[key],
+                                 reverse=self.higher_is_better)
+            op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst)))
+            logger.info(f"Start to accumulate fallback to {target_dtype}.")
+            initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
+            fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
+                                                     initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                     op_dtypes=op_dtypes, accumulate=True)
+            for op_tuning_cfg in fallback_sampler:
+                op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+                yield op_tuning_cfg
+
+    def traverse(self):
+        """Traverse the tuning space according to the auto-mixed precision strategy."""
+        # get fp32 model baseline
+        self._eval_baseline()
+
+        trials_count = 0
+        for op_tuning_cfg in self.next_tune_cfg():
+            # add tune_cfg here as quantize uses tune_cfg
+            tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
+            trials_count += 1
+            tuning_history = self._find_tuning_history(tune_cfg)
+            if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials:
+                self.last_tune_result = tuning_history['last_tune_result']
+                self.best_tune_result = tuning_history['best_tune_result']
+                logger.warn("Found evaluated tuning config, skip.")
+                continue
+
+            logger.debug("Dump current mixed precision configuration:")
+            logger.debug(tune_cfg)
+            self.last_qmodel = self.adaptor.quantize(
+                tune_cfg, self.model, self.calib_dataloader, self.q_func)
+            assert self.last_qmodel
+            # Return the last quantized model directly when in performance-only mode.
+            if self.cfg.tuning.exit_policy.performance_only:
+                self.best_qmodel = self.last_qmodel
+                self._add_tuning_history(copy.deepcopy(tune_cfg), (-1, [0]), q_config=self.last_qmodel.q_config)
+                return
+            self.last_tune_cfg = copy.deepcopy(tune_cfg)
+            if self.eval_dataloader or self.eval_func:
+                q_config = copy.deepcopy(self.last_qmodel.q_config)
+                self.last_tune_result = self._evaluate(self.last_qmodel)
+                self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(op_tuning_cfg)
+                need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, trials_count)
+                # record the tuning history
+                saved_tune_cfg = copy.deepcopy(tune_cfg)
+                saved_last_tune_result = copy.deepcopy(self.last_tune_result)
+                self._add_tuning_history(saved_tune_cfg, saved_last_tune_result, q_config=q_config)
+            else:
+                # If the eval_dataloader was not specified in the config yaml file,
+                # we only convert the model with the customized precisions.
+                self.best_qmodel = self.last_qmodel
+                need_stop = True
+
+            if need_stop:
+                break
diff --git a/neural_compressor/experimental/strategy/basic.py b/neural_compressor/experimental/strategy/basic.py
new file mode 100644
index 00000000000..33ea8c7d675
--- /dev/null
+++ b/neural_compressor/experimental/strategy/basic.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The basic tuning strategy."""
+import copy
+import numpy as np
+from collections import OrderedDict
+from .strategy import strategy_registry, TuneStrategy
+from ...utils import logger
+
+from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
+from .utils.tuning_structs import OpTuningConfig
+from .utils.constant import TUNING_ITEMS_LST
+
+@strategy_registry
+class BasicTuneStrategy(TuneStrategy):
+    """The basic tuning strategy.
+
+    The basic strategy executes three stages sequentially, and the tuning
+    process ends once the exit-policy condition is met.
+    """
+
+    def distributed_next_tune_cfg_lst(self, comm):
+        """Generate and yield the next tuning config list in the following order.
+
+        1. OP Type Wise Tuning
+        2. Fallback OP One by One
+        3. Fallback Multiple OPs Accumulated
+
+        Yields:
+            tuning_config_list (list): A list containing dicts of the tuning configuration for quantization.
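+
+        Example:
+            A minimal sketch of a hypothetical driver (assumes an ``mpi4py``
+            communicator; ``quantize_and_eval`` is a placeholder)::
+
+                from mpi4py import MPI
+
+                comm = MPI.COMM_WORLD
+                for stage_cfg_lst in strategy.distributed_next_tune_cfg_lst(comm):
+                    # each rank evaluates its own slice of the stage's configs
+                    for cfg in stage_cfg_lst[comm.Get_rank()::comm.Get_size()]:
+                        quantize_and_eval(cfg)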
+ """ + from copy import deepcopy + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + rank = comm.Get_rank() + for calib_sampling_size in calib_sampling_size_lst: + # Initialize the tuning config for each op according to the quantization approach + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + # Optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight) + early_stop_tuning = False + stage1_cnt = 0 + quant_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else [] + quant_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else [] + stage1_max = 1e9 # TODO set a more appropriate value + op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + ############ stage 1: yield op_tune_cfg_lst + op_tuning_cfg_lst_stage_1 = [] + for op_tuning_cfg in op_wise_tuning_sampler: + stage1_cnt += 1 + if early_stop_tuning and stage1_cnt > stage1_max: + logger.info("Early stopping the stage 1.") + break + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + op_tuning_cfg_lst_stage_1.append(deepcopy(op_tuning_cfg)) + logger.info("yield op_tuning_cfg_lst_stage_1 with length {}".format(len(op_tuning_cfg_lst_stage_1))) + yield op_tuning_cfg_lst_stage_1 + + #### Coordinate: only master knows cur best tune cfg + cur_best_tuning_cfg = self.cur_best_tuning_cfg if rank == 0 else None + if rank == 0: + comm.bcast(cur_best_tuning_cfg, root=0) + else: + self.cur_best_tuning_cfg = comm.bcast(cur_best_tuning_cfg, root=0) + + ############ stage 2: yield new_op_tuning_cfg_lst (length of 1) + # Fallback the ops supported both static and dynamic from static to dynamic + # Tuning items: None + if self.cfg.quantization.approach == 'post_training_auto_quant': + static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if + item in tuning_space.query_items_by_quant_mode('dynamic')] + if static_dynamic_items: + logger.info("Fallback all ops that support both dynamic and static to dynamic.") + else: + logger.info("Non ops that support both dynamic") + + new_op_tuning_cfg = deepcopy(self.cur_best_tuning_cfg) + for item in static_dynamic_items: + new_op_tuning_cfg[item.name] = self._initial_dynamic_cfg_based_on_static_cfg( + new_op_tuning_cfg[item.name]) + new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + op_tuning_cfg_lst_stage_2 = [deepcopy(new_op_tuning_cfg)] + logger.info("yield op_tuning_cfg_lst_stage_2 with length {}".format(len(op_tuning_cfg_lst_stage_2))) + yield op_tuning_cfg_lst_stage_2 + + #### Coordinate: only master knows cur best tune cfg + cur_best_tuning_cfg = self.cur_best_tuning_cfg if rank == 0 else None + if rank == 0: + comm.bcast(cur_best_tuning_cfg, root=0) + else: + self.cur_best_tuning_cfg = comm.bcast(cur_best_tuning_cfg, root=0) + + best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg) + + # Fallback + ############ stage 3, 4: yield op_tuning_cfg_lst + op_tuning_cfg_lst_stage_3 = [] + op_tuning_cfg_lst_stage_4 = [] + for target_dtype in ['bf16', 'fp32']: + target_type_lst = set(tuning_space.query_items_by_quant_mode(target_dtype)) + fallback_items_lst = [item for item in quant_ops if item in target_type_lst] + if fallback_items_lst: + logger.info(f"Start to fallback op to {target_dtype} one by one.") + self._fallback_started() + fallback_items_name_lst = [item.name for item 
in fallback_items_lst][::-1] # from bottom to up + op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst))) + initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=False) + op_fallback_acc_impact = OrderedDict() + for op_index, op_tuning_cfg in enumerate(fallback_sampler): + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + # yield op_tuning_cfg + op_tuning_cfg_lst_stage_3.append(deepcopy(op_tuning_cfg)) + logger.info("yield op_tuning_cfg_lst_stage_3 with length {}".format(len(op_tuning_cfg_lst_stage_3))) + yield op_tuning_cfg_lst_stage_3 + + # Only master updates op_fallback_acc_impact + if rank == 0: + for op_index, op_tuning_cfg in enumerate(fallback_sampler): + acc, _ = self.eval_results[op_index] + op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc + + #### Coordinate: only master knows op_fallback_acc_impact + op_fallback_acc_impact = op_fallback_acc_impact if rank == 0 else None + if rank == 0: + comm.bcast(op_fallback_acc_impact, root=0) + else: + op_fallback_acc_impact = comm.bcast(op_fallback_acc_impact, root=0) + + # Fallback OPs accumulated according to the order in the previous stage + if len(op_fallback_acc_impact) > 0: + ordered_ops = sorted(op_fallback_acc_impact.keys(), + key=lambda key: op_fallback_acc_impact[key], + reverse=self.higher_is_better) + op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst))) + logger.info(f"Start to accumulate fallback to {target_dtype}.") + initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=True) + for op_tuning_cfg in fallback_sampler: + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + # yield op_tuning_cfg + op_tuning_cfg_lst_stage_4.append(deepcopy(op_tuning_cfg)) + logger.info("yield op_tuning_cfg_lst_stage_4 with length {}".format(len(op_tuning_cfg_lst_stage_4))) + yield op_tuning_cfg_lst_stage_4 + + def next_tune_cfg(self): + """Generate and yield the next tuning config with below order. + + 1. OP Type Wise Tuning: tries to quantize the OPs as many as possible + and traverse all OP type wise tuning configs + 2. Fallback OP One by One: it performs high-precision OP (FP32, BF16 ...) + fallbacks one by one based on the tuning config with the best result + in the previous stage, and records the impact of each OP. + 3. Fallback Multiple OPs Accumulated: first sorted the OPs list + according to the impact score in stage II, and tries to incrementally + fallback multiple OPs to high precision according to the sorted OP list. + + Returns: + tune_config (dict): A dict containing the tuning configuration for quantization. + """ + from copy import deepcopy + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for calib_sampling_size in calib_sampling_size_lst: + # Initialize the tuning config for each op according to the quantization approach. 
+            op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg()
+            # Op-type-wise tuning items: the algorithm/scheme/granularity of activation(weight)
+            early_stop_tuning = False
+            stage1_cnt = 0
+            quant_ops = quant_mode_wise_items.get('static', [])
+            quant_ops += quant_mode_wise_items.get('dynamic', [])
+            stage1_max = 1e9  # TODO: set a more appropriate value
+            op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
+                                                             op_item_dtype_dict, initial_op_tuning_cfg)
+            for index, op_tuning_cfg in enumerate(op_wise_tuning_sampler):
+                op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+                # Apply all recipes; if the resulting qmodel does not meet the requirements, discard it.
+                if index == 1 and not self.applied_all_recipes_flag:
+                    logger.info("Apply all recipes.")
+                    self.applied_all_recipes_flag = True
+                    yield self.apply_all_tuning_recipes(deepcopy(self.cur_best_tuning_cfg))
+                stage1_cnt += 1
+                if early_stop_tuning and stage1_cnt > stage1_max:
+                    logger.info("Early stopping the stage 1.")
+                    break
+                yield op_tuning_cfg
+
+            # Apply all recipes; if the resulting qmodel does not meet the requirements, discard it.
+            if stage1_cnt == 1 and not self.applied_all_recipes_flag:
+                logger.info("Apply all recipes.")
+                self.applied_all_recipes_flag = True
+                yield self.apply_all_tuning_recipes(deepcopy(self.cur_best_tuning_cfg))
+
+            # Fallback the ops that support both static and dynamic from static to dynamic
+            # Tuning items: None
+            if self.cfg.quantization.approach == 'post_training_auto_quant':
+                static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if
+                                        item in tuning_space.query_items_by_quant_mode('dynamic')]
+                if static_dynamic_items:
+                    logger.info("Fallback all ops that support both dynamic and static to dynamic.")
+                else:
+                    logger.info("No op supports both dynamic and static.")
+
+                new_op_tuning_cfg = deepcopy(self.cur_best_tuning_cfg)
+                for item in static_dynamic_items:
+                    new_op_tuning_cfg[item.name] = self._initial_dynamic_cfg_based_on_static_cfg(
+                        new_op_tuning_cfg[item.name])
+                new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+                yield new_op_tuning_cfg
+
+            logger.info("Apply recipes one by one.")
+            for tune_cfg in self.apply_recipe_one_by_one(deepcopy(self.cur_best_tuning_cfg)):
+                yield tune_cfg
+            best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg)
+
+            # Fallback
+            for target_dtype in ['bf16', 'fp32']:
+                target_type_lst = set(tuning_space.query_items_by_quant_mode(target_dtype))
+                fallback_items_lst = [item for item in quant_ops if item in target_type_lst]
+                if fallback_items_lst:
+                    logger.info(f"Start to fallback op to {target_dtype} one by one.")
+                    self._fallback_started()
+                fallback_items_name_lst = [item.name for item in fallback_items_lst][::-1]  # from bottom to up
+                op_dtypes = OrderedDict(zip(fallback_items_name_lst, [target_dtype] * len(fallback_items_name_lst)))
+                initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1)
+                fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
+                                                         initial_op_tuning_cfg=initial_op_tuning_cfg,
+                                                         op_dtypes=op_dtypes, accumulate=False)
+                op_fallback_acc_impact = OrderedDict()
+                for op_index, op_tuning_cfg in enumerate(fallback_sampler):
+                    op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+                    yield op_tuning_cfg
+                    acc, _ = self.last_tune_result
+                    op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc
+
+                # Fallback OPs accumulated according to the order in the previous stage
+                if len(op_fallback_acc_impact) > 0:
+                    ordered_ops = 
sorted(op_fallback_acc_impact.keys(), + key=lambda key: op_fallback_acc_impact[key], + reverse=self.higher_is_better) + op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst))) + logger.info(f"Start to accumulate fallback to {target_dtype}.") + initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=True) + for op_tuning_cfg in fallback_sampler: + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + + def _initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg:OpTuningConfig): + op_state = op_static_cfg.get_state() + op_name = op_static_cfg.op_name + op_type = op_static_cfg.op_type + op_name_type = (op_name, op_type) + op_quant_mode = 'dynamic' + tuning_space = self.tuning_space + dynamic_state = {} + for att in ['weight', 'activation']: + if att not in op_state: continue + # Add dtype + full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, op_quant_mode) + dynamic_state[att + '_dtype'] = self.tuning_space.ops_data_type[op_name_type][full_path[att]] + for method_name, method_val in op_state[att].items(): + att_and_method_name = (att, method_name) + if att_and_method_name not in TUNING_ITEMS_LST: continue + if tuning_space.query_item_option(op_name_type, full_path[att], att_and_method_name, method_val): + dynamic_state[att_and_method_name] = method_val + else: + quant_mode_item = tuning_space.get_item_by_path((op_name_type, *full_path[att])) + if quant_mode_item and quant_mode_item.get_option_by_name(att_and_method_name): + tuning_item = quant_mode_item.get_option_by_name(att_and_method_name) + dynamic_state[att_and_method_name] = tuning_item.options[0] if tuning_item else None + return OpTuningConfig(op_name, op_type, op_quant_mode, tuning_space, kwargs=dynamic_state) + + \ No newline at end of file diff --git a/neural_compressor/experimental/strategy/bayesian.py b/neural_compressor/experimental/strategy/bayesian.py new file mode 100644 index 00000000000..58edcdee024 --- /dev/null +++ b/neural_compressor/experimental/strategy/bayesian.py @@ -0,0 +1,444 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The Bayesian tuning strategy.""" + +import copy +import warnings +import numpy as np +from scipy.optimize import minimize +from sklearn.gaussian_process.kernels import Matern +from sklearn.gaussian_process import GaussianProcessRegressor + +from collections import OrderedDict +from copy import deepcopy + +from ...utils import logger +from .strategy import strategy_registry, TuneStrategy +from .utils.tuning_sampler import OpWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig + + +@strategy_registry +class BayesianTuneStrategy(TuneStrategy): + """The Bayesian tuning strategy.""" + + def __init__(self, model, conf, q_dataloader, q_func=None, eval_dataloader=None, + eval_func=None, dicts=None, q_hooks=None): + """Init the BaySian tuning strategy.""" + super().__init__(model, conf, q_dataloader, q_func, eval_dataloader, + eval_func, dicts, q_hooks) + self.bayes_opt = None + + def __getstate__(self): + """Magic method for pickle saving. + + Returns: + dict: Saved dict for resuming + """ + for history in self.tuning_history: + if self._same_yaml(history['cfg'], self.cfg): + history['bayes_opt'] = self.bayes_opt + save_dict = super().__getstate__() + return save_dict + + def _params_to_tune_configs(self, params): + op_tuning_cfg = {} + calib_sampling_size_lst = self.tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for op_name_type, configs in self.op_configs.items(): + if len(configs) == 1: + op_tuning_cfg[op_name_type] = configs[0] + else: + op_tuning_cfg[op_name_type] = configs[min(len(configs) - 1, int(params[op_name_type[0]]))] + if len(calib_sampling_size_lst) > 1: + calib_sampling_size = calib_sampling_size_lst[min(len(configs) - 1, int(params['calib_sampling_size']))] + else: + calib_sampling_size = calib_sampling_size_lst[0] + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + return op_tuning_cfg + + def next_tune_cfg(self): + """Generate the next tuning config according to bayesian search algorithm. + + This strategy comes from the Bayesian optimization package and changed it to a discrete version. + It uses Gaussian processes to define the prior/posterior distribution over the black-box + function with the tuning history and then finds the tuning configuration that maximizes + the expected improvement. + + Returns: + tune_config (dict): A dict containing the tuning configuration for quantization. 
+ """ + params = None + pbounds = {} + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + op_wise_pool = OpWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + self.op_configs = op_wise_pool.get_opwise_candidate() + + for op_name_type, configs in self.op_configs.items(): + if len(configs) > 1: + pbounds[op_name_type[0]] = (0, len(configs)) + if len(calib_sampling_size_lst) > 1: + pbounds['calib_sampling_size'] = (0, len(calib_sampling_size_lst)) + if len(pbounds) == 0: + yield self._params_to_tune_configs(params) + return + if self.bayes_opt is None: + self.bayes_opt = BayesianOptimization( + pbounds=pbounds, random_seed=self.cfg.tuning.random_seed) + while True: + params = self.bayes_opt.gen_next_params() + logger.debug("Dump current bayesian params:") + logger.debug(params) + yield self._params_to_tune_configs(params) + try: + self.bayes_opt._space.register(params, self.last_tune_result[0]) + except KeyError: + logger.debug("Find registered params, skip it.") + pass + +# Util part +# Bayesian opt acq function + + +def acq_max(ac, gp, y_max, bounds, random_seed, n_warmup=10000, n_iter=10): + """Find the maximum of the acquisition function parameters. + + Args: + ac: The acquisition function object that return its point-wise value. + gp: A gaussian process fitted to the relevant data. + y_max: The current maximum known value of the target function. + bounds: The variables bounds to limit the search of the acq max. + random_seed: instance of np.RandomState random number generator + n_warmup: number of times to randomly sample the acquisition function + n_iter: number of times to run scipy.minimize + + Returns: + x_max: The arg max of the acquisition function. + """ + # Warm up with random points + x_tries = np.random.uniform(bounds[:, 0], bounds[:, 1], + size=(n_warmup, bounds.shape[0])) + ys = ac(x_tries, gp=gp, y_max=y_max) + x_max = x_tries[ys.argmax()] + max_acq = ys.max() + + # Explore the parameter space more thoroughly + x_seeds = np.random.uniform(bounds[:, 0], bounds[:, 1], + size=(n_iter, bounds.shape[0])) + for x_try in x_seeds: + # Find the minimum of minus the acquisition function + res = minimize(lambda x: -ac(x.reshape(1, -1), gp=gp, y_max=y_max), + x_try.reshape(1, -1), + bounds=bounds, + method="L-BFGS-B") + + # See if success + if not res.success: + continue + + if isinstance(res.fun, float): + res.fun = np.array([res.fun]) + # Store it if better than previous minimum(maximum). + if max_acq is None or -res.fun[0] >= max_acq: + x_max = res.x + max_acq = -res.fun[0] + + # Clip output to make sure it lies within the bounds. Due to floating + # point technicalities this is not always the case. + return np.clip(x_max, bounds[:, 0], bounds[:, 1]) + + +def _hashable(x): + """Ensure that an point is hashable by a python dict.""" + return tuple(map(float, x)) + +# Target space part +class TargetSpace(object): + """Holds the param-space coordinates (X) and target values (Y). + + Allows for constant-time appends while ensuring no duplicates are added. + """ + + def __init__(self, pbounds, random_seed=9527): + """Construct a TargetSpace. + + Args: + target_func (function): Function to be maximized. + pbounds (dict): Dictionary with parameters names as keys and a tuple with minimum and maximum values. 
+            random_seed (int): Optionally specify a seed for a random number generator
+        """
+        self.random_seed = random_seed
+        # Get the name of the parameters
+        names = list(pbounds.keys())
+        self._keys = deepcopy(names)
+        # Create an array with parameters bounds
+        self._bounds = np.array(
+            [pbounds[name] for name in names],
+            dtype=np.float32
+        )
+
+        # preallocated memory for X and Y points
+        self._params = np.empty(shape=(0, self.dim))
+        self._target = np.empty(shape=(0))
+
+        # keep track of unique points we have seen so far
+        self._cache = {}
+
+    def __contains__(self, x):
+        """Check if param x is cached in this space."""
+        return _hashable(x) in self._cache
+
+    def __len__(self):
+        """Get the total count of stored items."""
+        assert len(self._params) == len(self._target)
+        return len(self._target)
+
+    @property
+    def empty(self):
+        """Check if the space is empty."""
+        return len(self) == 0
+
+    @property
+    def params(self):
+        """Get all params stored in this space."""
+        return self._params
+
+    @property
+    def target(self):
+        """Get all target values in this space."""
+        return self._target
+
+    @property
+    def dim(self):
+        """Get the dimension of this space."""
+        return len(self._keys)
+
+    @property
+    def keys(self):
+        """Get all keys of this space."""
+        return self._keys
+
+    @property
+    def bounds(self):
+        """Get the bounds of this space."""
+        return self._bounds
+
+    def params_to_array(self, params):
+        """Generate an array from params.
+
+        Args:
+            params (Dict): The dict contains keys in `self.keys`, and the
+                corresponding params.
+
+        Returns:
+            np.array: An array that contains all params.
+        """
+        try:
+            assert set(params) == set(self.keys)
+        except AssertionError:
+            raise ValueError(
+                "Parameters' keys ({}) do ".format(list(params.keys())) +
+                "not match the expected set of keys ({}).".format(self.keys)
+            )
+        return np.asarray([params[key] for key in self.keys])
+
+    def array_to_params(self, x):
+        """Generate a params dict from an array.
+
+        Args:
+            x (np.array): The array that contains all params.
+
+        Returns:
+            dict: the dict contains the keys and the params corresponding to them.
+        """
+        try:
+            assert len(x) == len(self.keys)
+        except AssertionError:
+            raise ValueError(
+                "Size of array ({}) is different than the ".format(len(x)) +
+                "expected number of parameters ({}).".format(len(self.keys))
+            )
+        return dict(zip(self.keys, x))
+
+    def _as_array(self, x):
+        """Convert a params dict or sequence to a flat np.ndarray."""
+        try:
+            x = np.asarray(x, dtype=float)
+        except TypeError:
+            x = self.params_to_array(x)
+
+        x = x.ravel()
+        try:
+            assert x.size == self.dim
+        except AssertionError:
+            raise ValueError(
+                "Size of array ({}) is different than the ".format(len(x)) +
+                "expected number of parameters ({}).".format(len(self.keys))
+            )
+        return x
+
+    def register(self, params, target):
+        """Append a point and its target value to the known data.
+
+        Runs in amortized constant time.
+
+        Args:
+            params (ndarray): a single point, with len(params) == self.dim
+            target (float): target function value
+
+        Raises:
+            KeyError: if the point is not unique
+        """
+        x = self._as_array(params)
+        if x in self:
+            raise KeyError('Params point {} is not unique'.format(x))
+
+        # Insert data into unique dictionary
+        self._cache[_hashable(x.ravel())] = target
+
+        self._params = np.concatenate([self._params, x.reshape(1, -1)])
+        self._target = np.concatenate([self._target, [target]])
+
+    def get_target(self, params):
+        """Get the target value of params.
+
+        Args:
+            params (ndarray): a single point, with len(params) == self.dim
+
+        Returns:
+            target (float): target function value.
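+
+        Example:
+            A minimal sketch of the register/lookup round trip::
+
+                space = TargetSpace({'x': (0, 4)})
+                space.register(np.array([2.0]), 0.75)
+                assert space.get_target(np.array([2.0])) == 0.75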
+ """ + x = self._as_array(params) + target = self._cache[_hashable(x)] + return target + + def random_sample(self): + """Create random points within the bounds of the space. + + Returns: + data (ndarray): [num x dim] array points with dimensions corresponding to `self._keys` + """ + # TODO: support integer, category, and basic scipy.optimize constraints + data = np.empty((1, self.dim)) + for col, (lower, upper) in enumerate(self._bounds): + data.T[col] = np.random.uniform( # pylint: disable=unsupported-assignment-operation + lower, upper, size=1) + return data.ravel() + + def max(self): + """Get maximum target value found and corresponding parametes.""" + try: + res = { + 'target': self.target.max(), + 'params': dict( + zip(self.keys, self.params[self.target.argmax()]) + ) + } + except ValueError: + res = {} + return res + + def res(self): + """Get all target values found and corresponding parametes.""" + params = [dict(zip(self.keys, p)) for p in self.params] + + return [ + {"target": target, "params": param} + for target, param in zip(self.target, params) + ] + +# Tuning part +class BayesianOptimization(): + """The class for bayesian optimization. + + This class takes the parameters bounds in order to find which values for + the parameters yield the maximum value using bayesian optimization. + """ + + def __init__(self, pbounds, random_seed=9527, verbose=2): + """Init bayesian optimization. + + Args: + pbounds (dict): Dictionary with parameters names as keys and a tuple with + minimum and maximum values. + random_seed (int, optional): The seed for random searching. Default to 9527. + verbose (int, optional): The level of verbosity. Default to 2. + """ + self._random_seed = random_seed + # Data structure containing the bounds of its domain, + # and a record of the points we have evaluated. + self._space = TargetSpace(pbounds, random_seed) + + # Internal GP regressor + self._gp = GaussianProcessRegressor( + kernel=Matern(nu=2.5), + alpha=1e-6, + normalize_y=True, + n_restarts_optimizer=5, + random_state=self._random_seed, + ) + self._verbose = verbose + + @property + def space(self): + """Get the target space.""" + return self._space + + @property + def max(self): + """Get the maximum value of target space.""" + return self._space.max() + + @property + def res(self): + """Get the minimum value of target space.""" + return self._space.res() + + @staticmethod + def _ucb(x, gp, y_max, kappa=2.576): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + mean, std = gp.predict(x, return_std=True) + return mean + kappa * std + + def suggest(self): + """Suggest the most promising points.""" + if len(set(self._space.target)) < 2: + return self._space.array_to_params(self._space.random_sample()) + + # Sklearn's GP throws a large number of warnings at times, but + # we don't really need to see them here. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self._gp.fit(self._space.params, self._space.target) + + # Finding argmax of the acquisition function. 
+        suggestion = acq_max(
+            ac=self._ucb,
+            gp=self._gp,
+            y_max=self._space.target.max(),
+            bounds=self._space.bounds,
+            random_seed=self._random_seed
+        )
+        return self._space.array_to_params(suggestion)
+
+    def gen_next_params(self):
+        """Get the next parameters."""
+        next_params = self.suggest()
+        return next_params
diff --git a/neural_compressor/experimental/strategy/conservative.py b/neural_compressor/experimental/strategy/conservative.py
new file mode 100644
index 00000000000..7608ca1a894
--- /dev/null
+++ b/neural_compressor/experimental/strategy/conservative.py
@@ -0,0 +1,412 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The conservative tuning strategy for quantization level 0."""
+import copy
+import os
+import numpy as np
+
+from collections import deque
+from collections import OrderedDict as COrderedDict
+from copy import deepcopy
+from typing import Dict, List, Tuple, OrderedDict
+
+from .strategy import strategy_registry, TuneStrategy
+from .utils.tuning_space import TuningItem
+from ...utils import logger
+from ...utils.utility import Statistics
+from ...algorithm import AlgorithmScheduler
+
+@strategy_registry
+class ConservativeTuneStrategy(TuneStrategy):
+    """Tuning strategy with accuracy first, performance second.
+
+    Quantization level O0 is designed for users who want to preserve the
+    accuracy of the model after quantization. It starts with the original
+    (fp32) model and then quantizes the OPs to lower precision, first
+    op-type-wise and then op-wise.
+    """
+
+    def __init__(self, model, conf, q_dataloader, q_func=None, eval_dataloader=None,
+                 eval_func=None, dicts=None, q_hooks=None):
+        """Init conservative tuning strategy."""
+        super().__init__(model, conf, q_dataloader, q_func, eval_dataloader,
+                         eval_func, dicts, q_hooks)
+        self.acc_meet_flag = False
+
+    def next_tune_cfg(self):
+        """Generate and yield the next tuning config in the following order.
+
+        1. Query all quantizable ops and save them as a list of [(op_name, op_type), ...]
+        2. Classify the ops by op type
+        3. Add ops to the quant_queue according to the op type priority
+        4. Go through the quant_queue: quantize each op, keep the change in tune_cfg if
+           accuracy still meets the requirements, and otherwise restore its previous config
+        5. For bf16 and fp16 operators, do the same as for int8 operators
+
+        Returns:
+            tune_config (dict): A dict containing the tuning configuration to run.
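+
+        Example:
+            Illustrative traversal order (hypothetical op types)::
+
+                bf16:  try all Conv2D -> keep or revert -> try all MatMul -> ...
+                int8:  try all Conv2D -> if accuracy drops, retry the Conv2D
+                       ops one by one, keeping only those that still pass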
+ """ + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + calib_sampling_size = calib_sampling_size_lst[0] + tune_cfg = self._initialize_tune_cfg() + tune_cfg['calib_sampling_size'] = calib_sampling_size + op_type_priority = self._get_op_type_priority() + quant_items_pool = self._quant_items_pool(op_type_priority) + logger.info(f"*** Try to convert op into lower precision to improve performance.") + for dtype, op_items in quant_items_pool.items(): + logger.info(f"*** Start to convert op into {dtype}.") + for op_type, items_lst in op_items.items(): + logger.info(f"*** Try to convert all {op_type} ops into {dtype}.") + tmp_tune_cfg = deepcopy(tune_cfg) + for item, quant_mode in items_lst: + op_info = item.name + op_config = tuning_space.get_default_config(op_info, quant_mode) + tmp_tune_cfg[op_info] = op_config + yield tmp_tune_cfg + if self.acc_meet_flag: + logger.info(f"*** Convert all {op_type} ops to {dtype} and accuracy still meet the requirements") + tune_cfg = deepcopy(tmp_tune_cfg) + else: + tmp_tune_cfg = deepcopy(tune_cfg) + logger.info(f"*** Convert all {op_type} ops to {dtype} but accuracy not meet the requirements") + logger.info(f"*** Try to convert {op_type} op into {dtype} one by one.") + for item, quant_mode in items_lst: + op_info = item.name + op_config = tuning_space.get_default_config(op_info, quant_mode) + tmp_tune_cfg[op_info] = op_config + yield tmp_tune_cfg + if self.acc_meet_flag: + tune_cfg[op_info] = op_config + logger.info((f"*** Convert one {op_type} op({op_info}) " + f"into {dtype} and accuracy still meet the requirements")) + else: + tmp_tune_cfg[op_info] = tune_cfg[op_info] + logger.info(f"*** Skip convert {op_info}.") + logger.info(f"*** Ending tuning process due to no quantifiable op left.") + + def traverse(self): + """Traverse the tuning space.""" + self._eval_baseline() + + # Start tuning + trials_count = 0 + for op_tuning_cfg in self.next_tune_cfg(): + tune_cfg = self._tune_cfg_converter(op_tuning_cfg) + trials_count += 1 + tuning_history = self._find_tuning_history(tune_cfg) + if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials: + self.last_tune_result = tuning_history['last_tune_result'] + self.best_tune_result = tuning_history['best_tune_result'] + logger.warn("Find evaluated tuning config, skip.") + continue + logger.debug("Dump current tuning configuration:") + logger.debug(tune_cfg) + self.tuning_times += 1 + # set the parameter for pre quantization algos and run + self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model) + self.model = self.algo_scheduler('pre_quantization') + # quantize + q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func) + assert self.adaptor.pre_optimized_model + # set the parameter for post quantization algos and run + self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg, self.adaptor.pre_optimized_model, + q_model) + self.last_qmodel = self.algo_scheduler('post_quantization') + self.last_tune_cfg = copy.deepcopy(tune_cfg) + # Remove the reference to model + self.algo_scheduler.reset_exec_algorithms() + assert self.last_qmodel + # Return the last quantized model as a result. if performance only. 
+            if self.cfg.tuning.exit_policy.performance_only:
+                self.best_qmodel = self.last_qmodel
+                self._add_tuning_history(copy.deepcopy(tune_cfg), (-1, [0]), q_config=self.last_qmodel.q_config)
+                return
+            self.last_tune_cfg = copy.deepcopy(tune_cfg)
+            self.last_tune_result = self._evaluate(self.last_qmodel)
+            self.acc_meet_flag = self.objectives.accuracy_meets()
+            if self.acc_meet_flag:
+                # For the first tuning
+                if not self.best_tune_result:
+                    self.best_qmodel = self.last_qmodel
+                    self.best_tune_result = self.last_tune_result
+                else:
+                    # Update the current tuning config and model with the best performance
+                    get_better_performance = self._compare_performace(self.last_tune_result, self.best_tune_result)
+                    if get_better_performance:
+                        logger.info(f"*** Update the model with better performance.")
+                        self.best_qmodel = self.last_qmodel
+                        self.best_tune_result = self.last_tune_result
+                    else:
+                        logger.info(f"*** The qmodel was not updated due to not achieving better performance.")
+            # Dump the current state to log
+            self._dump_tuning_state(trials_count, self.last_tune_result, self.best_tune_result, self.baseline)
+            # Judge stop or continue tuning
+            need_stop = self.stop(trials_count)
+            # Record the tuning history
+            saved_tune_cfg = copy.deepcopy(tune_cfg)
+            saved_last_tune_result = copy.deepcopy(self.last_tune_result)
+            self._add_tuning_history(saved_tune_cfg,
+                                     saved_last_tune_result,
+                                     q_config=q_model.q_config)
+            self.tune_result_record.append(copy.deepcopy(self.last_tune_result))
+            self.tune_cfg = tune_cfg
+            self._dump_tuning_process_statistics()
+            if need_stop:
+                if self.cfg.tuning.diagnosis and self.cfg.tuning.diagnosis.diagnosis_after_tuning:
+                    logger.debug(f'*** Start to do diagnosis (inspect tensor).')
+                    self._diagnosis()
+                self._recover_best_qmodel_from_tuning_cfg()
+                if self.use_multi_objective and len(self.tune_result_record) > 1 and \
+                        self.best_tune_result is not None:
+                    best_trail, best_result = self.objectives.best_result(self.tune_result_record,
+                                                                          copy.deepcopy(self.baseline))
+                    if best_result != self.best_tune_result:
+                        from neural_compressor.utils.utility import recover
+                        self.best_qmodel = recover(self.model.model,
+                                                   os.path.join(self.cfg.tuning.workspace.path, 'history.snapshot'),
+                                                   best_trail)
+                        self.best_tune_result = best_result
+                    self._dump_tuning_process_statistics()
+                break
+
+    def stop(self, trials_count):
+        """Check whether the traverse procedure needs to stop.
+
+        Args:
+            trials_count (int): current total count of tuning trials.
+
+        Returns:
+            bool: whether the traverse procedure needs to stop.
+        """
+        need_stop = False
+        if trials_count >= self.cfg.tuning.exit_policy.max_trials:
+            need_stop = True
+        return need_stop
+
+    def _compare_performace(self, last_tune_result, best_tune_result):  # pragma: no cover
+        """Compare two tuning results by performance only.
+
+        Args:
+            last_tune_result (list): The last tuning result.
+            best_tune_result (list): The best tuning result so far.
+
+        Returns:
+            bool: whether the last tuning result outperforms the best tuning result
+                so far in performance.
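+
+        Example:
+            With ``(accuracy, [latency])`` tuples, a lower latency wins::
+
+                self._compare_performace((0.76, [2.1]), (0.75, [2.4]))  # True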
+ """ + _, last_perf = last_tune_result + _, best_perf = best_tune_result + return last_perf[0] < best_perf[0] + + def _dump_tuning_state(self, trials_count, last_tune_result, best_tune_result, baseline): + if last_tune_result: + last_tune = last_tune_result[0] if \ + isinstance(last_tune_result[0], list) else [last_tune_result[0]] + for name, data in zip(self.metric_name, last_tune): + if len(self.tune_data[name]) == 1: + self.tune_data[name].append(data) + else: + self.tune_data[name][1] = data + + if self.metric_weight and len(last_tune) > 1: + weighted_acc = np.mean(np.array(last_tune) * self.metric_weight) + if len(self.tune_data['Weighted accuracy']) == 1: + self.tune_data['Weighted accuracy'].append(weighted_acc) + else: + self.tune_data['Weighted accuracy'][1] = weighted_acc + last_tune = [weighted_acc] + + last_tune_msg = '[Accuracy (int8|fp32):' + \ + ''.join([' {:.4f}|{:.4f}'.format(last, base) for last, base in \ + zip(last_tune, self.tune_data['baseline'])]) + \ + ''.join([', {} (int8|fp32): {:.4f}|{:.4f}'.format( \ + x, y, z) for x, y, z in zip( \ + self.objectives.representation, last_tune_result[1], baseline[1]) \ + if x != 'Accuracy']) + ']' + else: # pragma: no cover + last_tune_msg = 'n/a' + for name in self.tune_data.keys() - {'baseline'}: + if len(self.tune_data[name]) == 1: + self.tune_data[name].append('n/a') + else: + self.tune_data[name][1] = 'n/a' + + if best_tune_result: + best_tune = best_tune_result[0] if isinstance(best_tune_result[0], list) \ + else [best_tune_result[0]] + + for name, data in zip(self.metric_name, best_tune): + if len(self.tune_data[name]) == 2: + self.tune_data[name].append(data) + else: + self.tune_data[name][2] = data + + if self.metric_weight and len(best_tune) > 1: + weighted_acc = np.mean(np.array(best_tune) * self.metric_weight) + + if len(self.tune_data['Weighted accuracy']) == 2: + self.tune_data['Weighted accuracy'].append(weighted_acc) + else: # pragma: no cover + self.tune_data['Weighted accuracy'][2] = weighted_acc + + best_tune = [weighted_acc] + + best_tune_msg = '[Accuracy:' + ''.join([' {:.4f}'.format(best) \ + for best in best_tune]) + ''.join([', {}: {:.4f}'.format(x,y) \ + for x,y in zip(self.objectives.representation, \ + best_tune_result[1]) if x != 'Accuracy']) + ']' + + else: + best_tune_msg = 'n/a' + for name in self.tune_data.keys() - {'baseline'}: + if len(self.tune_data[name]) == 2: + self.tune_data[name].append('n/a') + else: + self.tune_data[name][2] = 'n/a' + + logger.info("Tune {} result is: {}, Best tune result is: {}".format(trials_count, + last_tune_msg, + best_tune_msg)) + output_data = [[info_type, + '{:.4f} '.format(self.tune_data[info_type][0]) if \ + not isinstance(self.tune_data[info_type][0], str) else self.tune_data[info_type][0], + '{:.4f} '.format(self.tune_data[info_type][1]) if \ + not isinstance(self.tune_data[info_type][1], str) else self.tune_data[info_type][1], + '{:.4f} '.format(self.tune_data[info_type][2]) if \ + not isinstance(self.tune_data[info_type][2], str) else self.tune_data[info_type][2]] \ + for info_type in self.tune_data.keys() if info_type != 'baseline'] + + output_data.extend([[obj, + '{:.4f} '.format(baseline[1][i]) if baseline else 'n/a', + '{:.4f} '.format(last_tune_result[1][i]) if last_tune_result else 'n/a', + '{:.4f} '.format(best_tune_result[1][i]) if best_tune_result else 'n/a'] \ + for i, obj in enumerate(self.objectives.representation)]) + + Statistics(output_data, + header='Tune Result Statistics', + field_names=['Info Type', 'Baseline', 'Tune {} 
result'.format(trials_count), \
+                                'Best tune result']).print_stat()
+
+    def _get_op_type_priority(self):
+        """Get the op type list ordered by its tuning capability."""
+        optypewise_cap = self.capability['optypewise']
+        op_type_priority = list(optypewise_cap.keys())
+        return op_type_priority
+
+    def _sorted_item_by_op_type(self,
+                                items_lst: List[Tuple[TuningItem, str]],
+                                op_type_priority: List[str]) -> OrderedDict[str, List]:
+        """Sort the tuning items according to their op type.
+
+        Args:
+            items_lst: The tuning item list.  # [(op_item, quant_mode), ... ]
+            op_type_priority: The ordered op type list.  # [optype_1, optype_2]
+
+        Returns:
+            The tuning items sorted by op type.
+            OrderedDict:
+                # op_type: [(TuningItem, quant_mode), ...]
+                conv2d: [(TuningItem, static), (TuningItem, static)]
+                linear: [(TuningItem, static), (TuningItem, static)]
+        """
+        op_type_lst_from_items_lst = list(set([item[0].name[1] for item in items_lst]))
+        # Items whose op type does not exist in the priority list get the lowest priority.
+        sorted_op_type_lst = [op_type for op_type in op_type_priority if op_type in op_type_lst_from_items_lst]
+        sorted_op_type_lst += list(set(op_type_lst_from_items_lst) - set(op_type_priority))
+        sorted_items = COrderedDict()
+        for op_type in sorted_op_type_lst:
+            sorted_items[op_type] = []
+        for op_item, quant_mode in items_lst:
+            op_type = op_item.name[1]
+            sorted_items[op_type].append((op_item, quant_mode))
+        return sorted_items
+
+    def _initialize_tune_cfg(self):
+        """Initialize the tuning config with fp32 for as many ops as possible.
+
+        Returns:
+            The initialized tuning config.
+        """
+        tuning_space = self.tuning_space
+        quant_mode_wise_items = tuning_space.quant_mode_wise_items
+        # Initialize the tuning config
+        initial_tuning_cfg = {}
+        all_ops = set()
+        fp32_ops = []
+        for quant_mode, items_lst in quant_mode_wise_items.items():
+            items_name_lst = [item.name for item in items_lst]
+            all_ops = all_ops.union(set(items_name_lst))
+            if quant_mode == "fp32":
+                fp32_ops += [item.name for item in items_lst]
+        non_fp32_ops_dtype = {}
+        fp32_ops_set = set(fp32_ops)
+        for quant_mode, items_lst in quant_mode_wise_items.items():
+            items_name_set = set([item.name for item in items_lst])
+            tmp_non_fp32_ops = items_name_set.difference(fp32_ops_set)
+            if tmp_non_fp32_ops:
+                for op_info in tmp_non_fp32_ops:
+                    non_fp32_ops_dtype[op_info] = quant_mode
+        for op_info in fp32_ops:
+            initial_tuning_cfg[op_info] = tuning_space.get_default_config(op_info, "fp32")
+        for op_info, quant_mode in non_fp32_ops_dtype.items():
+            initial_tuning_cfg[op_info] = tuning_space.get_default_config(op_info, quant_mode)
+        return initial_tuning_cfg
+
+    def _quant_items_pool(self, op_type_priority: List[str]) -> OrderedDict[
+            str, OrderedDict[str, List[Tuple[TuningItem, str]]]]:
+        """Create the op queue to be quantized.
+
+        --------------------------------------------------------------------------
+        | Level 1 |        bf16        |        fp16        |  static/dynamic    |
+        | Level 2 | conv2d, linear, ...| conv2d, linear, ...| conv2d, linear, ...|
+
+        Args:
+            op_type_priority: The optype list with priority.
+
+        Returns:
+            The op item pool to convert into lower precision.
+            quant_items_pool (OrderedDict):
+                bf16:
+                    OrderedDict:
+                        conv2d: [(TuningItem, bf16), (TuningItem, bf16)]
+                        linear: [(TuningItem, bf16), (TuningItem, bf16)]
+                int8:
+                    OrderedDict:
+                        # (TuningItem, quant_mode)
+                        conv2d: [(TuningItem, static), (TuningItem, static)]
+                        linear: [(TuningItem, static), (TuningItem, static)]
+        """
+        quant_mode_wise_items = self.tuning_space.quant_mode_wise_items
+        # Add all quantized pairs into the queue
+        quant_items_pool = COrderedDict()
+        # collect and sort all ops that support bf16 and fp16
+        for quant_mode in ['bf16', 'fp16']:
+            if quant_mode in quant_mode_wise_items:
+                op_item_pairs = [(op_item, quant_mode) for op_item in quant_mode_wise_items[quant_mode]]
+                op_item_pairs = self._sorted_item_by_op_type(op_item_pairs, op_type_priority)
+                quant_items_pool[quant_mode] = op_item_pairs
+        op_item_pairs = []
+        quant_ops_name_set = set()
+        # collect and sort all ops that support int8
+        for quant_mode, items_lst in quant_mode_wise_items.items():
+            if "static" in quant_mode or 'dynamic' in quant_mode:
+                _quant_mode = "static" if "static" in quant_mode else "dynamic"
+                op_item_pairs += [(item, _quant_mode) for item in items_lst if item.name not in quant_ops_name_set]
+                quant_ops_name_set = quant_ops_name_set.union([item.name for item in items_lst])
+        op_item_pairs = self._sorted_item_by_op_type(op_item_pairs, op_type_priority)
+        quant_items_pool['int8'] = op_item_pairs
+        return quant_items_pool
diff --git a/neural_compressor/experimental/strategy/exhaustive.py b/neural_compressor/experimental/strategy/exhaustive.py
new file mode 100644
index 00000000000..b40d5b70397
--- /dev/null
+++ b/neural_compressor/experimental/strategy/exhaustive.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The exhaustive tuning strategy."""
+from collections import OrderedDict
+from .strategy import strategy_registry, TuneStrategy
+
+from .utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
+from .utils.tuning_structs import OpTuningConfig
+from ...utils import logger
+
+@strategy_registry
+class ExhaustiveTuneStrategy(TuneStrategy):
+    """The exhaustive tuning strategy."""
+
+    def next_tune_cfg(self):
+        """Generate and yield the next tuning config using exhaustive search in the tuning space.
+
+        It sequentially traverses all possible quantization tuning configurations
+        in a tuning space. From the perspective of the impact on performance,
+        it currently only traverses all possible quantization tuning configs.
+        For the same reason as with the Bayesian strategy, fallback datatypes are not included for now.
+
+        Returns:
+            tune_config (dict): A dict containing the tuning configuration for quantization.
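+
+        Example (an illustrative sketch only; `strategy` is assumed to be an
+        already-constructed ExhaustiveTuneStrategy instance):
+
+            # each yielded config maps (op_name, op_type) -> OpTuningConfig and
+            # carries a 'calib_sampling_size' entry
+            for tune_cfg in strategy.next_tune_cfg():
+                print(tune_cfg['calib_sampling_size'])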
+ """ + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for calib_sampling_size in calib_sampling_size_lst: + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + op_wise_tuning_sampler = OpWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + for op_tuning_cfg in op_wise_tuning_sampler: + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + return diff --git a/neural_compressor/experimental/strategy/hawq_v2.py b/neural_compressor/experimental/strategy/hawq_v2.py new file mode 100644 index 00000000000..1fd76b9b7dd --- /dev/null +++ b/neural_compressor/experimental/strategy/hawq_v2.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The HAWQ_V2 tuning strategy.""" +from collections import OrderedDict +from copy import deepcopy + +from .strategy import strategy_registry, TuneStrategy + +from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig +from .utils.constant import TUNING_ITEMS_LST +from ...utils import logger + +@strategy_registry +class HAWQ_V2TuneStrategy(TuneStrategy): + """The HAWQ V2 tuning strategy. + + HAWQ_V2 implements the "Hawq-v2: Hessian aware trace-weighted quantization of neural networks". + We made a small change to it by using the hessian trace to score the op impact and then + fallback the OPs according to the scoring result. + + """ + + def next_tune_cfg(self): + """Generate and yield the next tuning config using HAWQ v2 search in tuning space. + + Returns: + tune_config (dict): A dict containing the tuning configuration for quantization. 
+ """ + tuning_space = self.tuning_space + calib_size = tuning_space.root_item.get_option_by_name('calib_sampling_size').options[0] + + # Initialize the tuning config for each op according to the quantization approach + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + # Optype-wise tuning tuning items: the algorithm/scheme/granularity of activation(weight) + early_stop_tuning = True + stage1_cnt = 0 + quant_ops = quant_mode_wise_items.get('static', []) + quant_ops += quant_mode_wise_items.get('dynamic', []) + stage1_max = 1 # TODO set a more appropriate value + op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + for op_tuning_cfg in op_wise_tuning_sampler: + stage1_cnt += 1 + if early_stop_tuning and stage1_cnt > stage1_max: + logger.info("Early stopping the stage 1.") + break + op_tuning_cfg['calib_sampling_size'] = calib_size + yield op_tuning_cfg + # Start compute the hessian trace + logger.info(f"************** Start compute the hessian trace *****************") + target_dtype = "fp32" + hawq_v2_criterion =self.cfg.tuning.strategy.hawq_v2_loss + # assert hawq_v2_criterion is not None, "HAWQ-V2 strategy needs model loss function to compute the gradient, \ + # Please assign it by strategy_kwargs({'hawq_v2_loss': hawq_v2_loss})." + op_to_traces = self.adaptor.calculate_hessian_trace(fp32_model = self._fp32_model, + dataloader = self.calib_dataloader, + q_model = self.last_qmodel, + criterion =hawq_v2_criterion, + enable_act = False) + sorted_op_to_traces = dict(sorted(op_to_traces.items(), key=lambda item: item[1], reverse=True)) + logger.info(f"************** Hessian Trace *****************") + for op_name, trace in sorted_op_to_traces.items(): + logger.info(f"*** op: {op_name}, hessian trace : {trace}") + logger.info(f"************************************************") + # WA for op mapping + ordered_ops_tmp = {} + for op_info in list(initial_op_tuning_cfg.keys()): + op_name, op_type = op_info + for op_trace_name in op_to_traces.keys(): + if isinstance(op_trace_name, str) and op_trace_name.startswith(op_name): + if op_name in ordered_ops_tmp: + logger.info((f"*** Already assigned the hessian trace to {op_name}", + f"update it with the value of {op_trace_name}")) + ordered_ops_tmp[op_name] = op_to_traces[op_trace_name] + + ordered_ops_tmp = sorted(ordered_ops_tmp.keys(), + key=lambda key: ordered_ops_tmp[key], + reverse=self.higher_is_better) + # WA for add op type + op_info_map = {} + for op_info in list(initial_op_tuning_cfg.keys()): + op_info_map[op_info[0]] = op_info # op_name: (op_name, op_type) + tmp_ordered_ops = [op_info_map[op_name] for op_name in ordered_ops_tmp] + op_dtypes = OrderedDict(zip(tmp_ordered_ops, [target_dtype] * len(ordered_ops_tmp))) + + logger.info(f"Start to accumulate fallback to {target_dtype}.") + initial_op_tuning_cfg = deepcopy(op_tuning_cfg) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=True, + skip_first=False) + for op_tuning_cfg in fallback_sampler: + op_tuning_cfg['calib_sampling_size'] = calib_size + yield op_tuning_cfg + diff --git a/neural_compressor/experimental/strategy/mse.py b/neural_compressor/experimental/strategy/mse.py new file mode 100644 index 00000000000..55955774e74 --- /dev/null +++ b/neural_compressor/experimental/strategy/mse.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MSE tuning strategy."""
+from copy import deepcopy
+import numpy as np
+from collections import OrderedDict
+from typing import Dict, Any, List
+from .strategy import strategy_registry, TuneStrategy
+from ...utils import logger
+from time import time
+
+from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_structs import OpTuningConfig
+
+@strategy_registry
+class MSETuneStrategy(TuneStrategy):
+    """The tuning strategy using the MSE policy in the tuning space.
+
+    The MSE strategy needs to get the tensors for each OP of the raw FP32 model and the quantized model based on
+    the best model-wise tuning configuration. It then calculates the MSE (Mean Squared Error) for each OP, sorts
+    those OPs according to the MSE value, and performs the op-wise fallback in this order.
+    """
+
+    def __init__(self, model, conf, q_dataloader, q_func=None, eval_dataloader=None,
+                 eval_func=None, dicts=None, q_hooks=None):
+        """Init an MSE tuning strategy."""
+        super().__init__(model, conf, q_dataloader, q_func, eval_dataloader,
+                         eval_func, dicts, q_hooks)
+        self.ordered_ops = None
+
+    def __getstate__(self):
+        """Magic method for pickle saving.
+
+        Returns:
+            save_dict: Saved dict for resuming
+        """
+        for history in self.tuning_history:
+            if self._same_yaml(history['cfg'], self.cfg):
+                history['ordered_ops'] = self.ordered_ops
+        save_dict = super().__getstate__()
+        return save_dict
+
+    def _mse_metric_gap(self, fp32_tensor, dequantize_tensor):
+        """Calculate the normalized mean squared error between the FP32 tensor and the INT8 dequantized tensor.
+
+        Both tensors are min-max normalized before the squared differences are averaged.
+
+        Args:
+            fp32_tensor (tensor): The FP32 tensor.
+            dequantize_tensor (tensor): The INT8 dequantized tensor.
+        """
+        fp32_max = np.max(fp32_tensor)
+        fp32_min = np.min(fp32_tensor)
+        dequantize_max = np.max(dequantize_tensor)
+        dequantize_min = np.min(dequantize_tensor)
+        fp32_tensor = (fp32_tensor - fp32_min) / (fp32_max - fp32_min)
+        dequantize_tensor = (dequantize_tensor - dequantize_min) / \
+                            (dequantize_max - dequantize_min)
+        diff_tensor = fp32_tensor - dequantize_tensor
+        euclidean_dist = np.sum(diff_tensor ** 2)
+        return euclidean_dist / fp32_tensor.size
+
+    def mse_impact_lst(self, op_list: List, fp32_model, best_qmodel):
+        """Calculate and generate the MSE impact list.
+
+        Args:
+            op_list (List[Tuple(str, str)]): List of ops in format of [(op_name, op_type), ...].
+            fp32_model (Model): The original FP32 model before quantization.
+            best_qmodel (Model): The currently best quantized model.
+
+        Returns:
+            ordered_op_name_types (List[Tuple(str, str)]): The list of ops sorted by their MSE
+                impact, in the same format as 'op_list'.
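+
+                Example (illustrative only):
+                    [('conv1', 'conv2d'), ('fc', 'linear')]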
+ """ + op_name_lst = [element[0] for element in op_list ] + op_mapping = {} + for (op_name, op_type) in list(op_list): + op_mapping[op_name] = (op_name, op_type) + current_best_tune_cfg = self._tune_cfg_converter(self.cur_best_tuning_cfg) + fp32_dump_content = self.adaptor.inspect_tensor(fp32_model, + self.calib_dataloader, op_name_lst, [1], inspect_type='activation', + save_to_disk=True, save_path="./nc_workspace/", + quantization_cfg=current_best_tune_cfg) + fp32_tensor_dict = fp32_dump_content['activation'][0] + best_qmodel = self.adaptor.quantize(current_best_tune_cfg, self.model, self.calib_dataloader, self.q_func) + quant_dump_content = self.adaptor.inspect_tensor(best_qmodel, + self.calib_dataloader, op_name_lst, [1], inspect_type='activation', + save_to_disk=True, save_path="./nc_workspace/", + quantization_cfg=current_best_tune_cfg) + dequantize_tensor_dict = quant_dump_content['activation'][0] + ops_mse = { + op: self._mse_metric_gap( + list(fp32_tensor_dict[op].values())[0], + list(dequantize_tensor_dict[op].values())[0]) for op in fp32_tensor_dict} + ordered_op_names = sorted(ops_mse.keys(), key=lambda key: ops_mse[key], reverse=self.higher_is_better) + + ordered_op_name_types = [op_mapping[name] for name in ordered_op_names] + return ordered_op_name_types + + + def next_tune_cfg(self): + """Generate and yield the next tuning config. + + Returns: + tune_config (dict): A dict containing the tuning configuration for quantization. + """ + tuning_space = self.tuning_space + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for calib_sampling_size in calib_sampling_size_lst: + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + # Optype-wise tuning + early_stop_tuning = True + stage1_cnt = 0 + int8_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else [] + int8_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else [] + stage1_max = min(5, len(int8_ops)) # TODO set a more appropriate value + op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + for op_tuning_cfg in op_wise_tuning_sampler: + stage1_cnt += 1 + if early_stop_tuning and stage1_cnt > stage1_max: + logger.info("Early stopping the stage 1.") + break + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + + # Fallback the ops supported both static and dynamic from static to dynamic + static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if + item in tuning_space.query_items_by_quant_mode('dynamic')] + if static_dynamic_items: + logger.info("Fallback all ops that support both dynamic and static to dynamic.") + else: + logger.info("No op support both dynamic and static") + + def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): + new_op_tuning_cfg = deepcopy(op_tuning_cfg) + new_op_tuning_cfg.op_quant_mode = 'dynamic' + return new_op_tuning_cfg + + new_op_tuning_cfg = deepcopy(self.cur_best_tuning_cfg) + for item in static_dynamic_items: + new_op_tuning_cfg[item.name] = dynamic_op_tuning_cfg_from_static(new_op_tuning_cfg[item.name]) + new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield new_op_tuning_cfg + + best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg) + + # Fallback to float point datatypes ('bf16' or 'fp32') + for target_dtype in ['bf16', 'fp32']: + fallback_items_lst = [item for item in int8_ops if 
+ item in tuning_space.query_items_by_quant_mode(target_dtype)] + if fallback_items_lst: + logger.info(f"Start to fallback op to {target_dtype} one by one.") + # Replace it with sorted items list + fallback_items_name_lst = [item.name for item in fallback_items_lst] + # TODO check the best_qmodel + ordered_op_name_types = self.mse_impact_lst(fallback_items_name_lst, self.model, self.best_qmodel) + self.ordered_ops = [op_name for (op_name, op_type) in ordered_op_name_types] + op_dtypes = OrderedDict(zip(ordered_op_name_types, [target_dtype] * len(fallback_items_name_lst))) + initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=False) + op_fallback_acc_impact = OrderedDict() + for op_index, op_tuning_cfg in enumerate(fallback_sampler): + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + acc, _ = self.last_tune_result + op_fallback_acc_impact[fallback_items_name_lst[op_index]] = acc + + # Do accumulated fallback according to the order in the previous stage + if len(op_fallback_acc_impact) > 0: + ordered_ops = sorted(op_fallback_acc_impact.keys(), + key=lambda key: op_fallback_acc_impact[key], + reverse=self.higher_is_better) + op_dtypes = OrderedDict(zip(ordered_ops, [target_dtype] * len(fallback_items_name_lst))) + logger.info(f"Start to accumulate fallback to {target_dtype}.") + initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1) + fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[], + initial_op_tuning_cfg=initial_op_tuning_cfg, + op_dtypes=op_dtypes, accumulate=True) + for op_tuning_cfg in fallback_sampler: + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg diff --git a/neural_compressor/experimental/strategy/mse_v2.py b/neural_compressor/experimental/strategy/mse_v2.py new file mode 100644 index 00000000000..6492ae26dca --- /dev/null +++ b/neural_compressor/experimental/strategy/mse_v2.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The MSE_V2 tuning strategy.""" +import copy +from copy import deepcopy +import numpy as np +from collections import OrderedDict +from typing import Dict, Any, List +from .strategy import strategy_registry, TuneStrategy +from ...utils import logger +from time import time + +from .utils.tuning_sampler import OpTypeWiseTuningSampler +from .utils.tuning_structs import OpTuningConfig + +@strategy_registry +class MSE_V2TuneStrategy(TuneStrategy): + """The `mse_v2` tuning strategy. + + MSE_v2 is a strategy with a two stages fallback and revert fallback. + Note that, only tensorflow framework and pytorch FX backend is currently supported for mse_v2 + tuning strategy. 
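+
+    Example (illustrative only; selecting this strategy through a user config):
+
+        tuning:
+          strategy:
+            name: mse_v2
+            confidence_batches: 2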
+ """ + + def _tuning_record_msg(self, records): + records_str_lst = [[str(e) for e in record] for record in records] + record_msg = '\n'.join(','.join(record) for record in records_str_lst) + return record_msg + + def next_tune_cfg(self): + """Generate and yield the next tuning config with below order. + + 1. In the fallback stage, it uses multi-batch data to score the op impact + and then fallback the op with the highest score util found the quantized model + that meets accuracy criteria. + 2. In the revert fallback stage, it also scores + the impact of fallback OPs in the previous stage and selects the op + with the lowest score to revert the fallback until the quantized model + that does not meets accuracy criteria. + + Returns: + tune_config (dict): A dict containing the tuning configuration for quantization. + """ + best_op_tuning_cfg = None + if len(self.metric_name) == 1 or self.metric_weight is not None: + best_acc = float('-inf') if self.higher_is_better else float('inf') + else: + best_acc = [float('-inf') if higher_is_better else float('inf') for \ + higher_is_better in self.metric_criterion] + + from copy import deepcopy + tuning_space = self.tuning_space + initial_op_tuning_cfg = {} + for item in tuning_space.root_item.options: + if item.item_type == 'op': + op_name, op_type = item.name + initial_op_tuning_cfg[item.name] = OpTuningConfig(op_name, op_type, 'fp32', tuning_space) + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + for calib_sampling_size in calib_sampling_size_lst: + # Collect the ops that support static and dynamic + quant_mode_wise_items = OrderedDict() + query_order = ['static', 'dynamic', 'bf16', 'fp16', 'fp32'] + pre_items = set() + for quant_mode in query_order: + items = tuning_space.query_items_by_quant_mode(quant_mode) + filtered_items = [item for item in items if item not in pre_items] + pre_items = pre_items.union(set(items)) + quant_mode_wise_items[quant_mode] = filtered_items + + def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict): + for item in items_lst: + op_item_dtype_dict[item.name] = target_quant_mode + + op_item_dtype_dict = OrderedDict() + for quant_mode, quant_mode_items in quant_mode_wise_items.items(): + initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict) + + # Optype-wise tuning + early_stop_tuning = True + stage1_cnt = 0 + int8_ops = quant_mode_wise_items['dynamic'] + quant_mode_wise_items['static'] + stage1_max = 2 # TODO set a more appropriate value + op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + for op_tuning_cfg in op_wise_tuning_sampler: + stage1_cnt += 1 + if early_stop_tuning and stage1_cnt > stage1_max: + logger.info("Early stopping the stage 1.") + break + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + + # Fallback the ops supported both static and dynamic from static to dynamic + static_dynamic_items = [item for item in tuning_space.query_items_by_quant_mode('static') if + item in tuning_space.query_items_by_quant_mode('dynamic')] + if static_dynamic_items: + logger.info("Fallback all ops that support both dynamic and static to dynamic.") + else: + logger.info("No op support both dynamic and static") + + def dynamic_op_tuning_cfg_from_static(op_tuning_cfg: OpTuningConfig): + new_op_tuning_cfg = deepcopy(op_tuning_cfg) + new_op_tuning_cfg.op_quant_mode = 'dynamic' + return new_op_tuning_cfg + + new_op_tuning_cfg = 
+            for item in static_dynamic_items:
+                new_op_tuning_cfg[item.name] = dynamic_op_tuning_cfg_from_static(new_op_tuning_cfg[item.name])
+            new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
+            yield new_op_tuning_cfg
+
+            # Fall back ops one by one by op sensitivity (MSE)
+            # 1. while the accuracy requirements are not met:  # to improve the accuracy
+            #     1) calculate the sensitivity of int8 ops in the current state.
+            #     2) fall back the op with higher sensitivity accumulatively
+            # 2. after the accuracy requirements are met:  # to improve the performance
+            #     1) calculate the sensitivity of fp32 ops in the current state
+            #     2) re-quantize the op with lower sensitivity accumulatively
+            tune_cfg = deepcopy(self.cur_best_tuning_cfg)
+            requantize_cfg = deepcopy(self._tune_cfg_converter(self.cur_best_tuning_cfg))
+            self.output_op_names = self.adaptor.get_output_op_names(self.last_qmodel)
+            self.confidence_batches = (self.cfg.tuning.strategy.confidence_batches
+                                       if self.cfg.tuning.strategy.confidence_batches is not None else 2)
+            tune_cfg_backup = deepcopy(tune_cfg)
+            quant_ops_in_tune_cfg = self._collect_ops_by_quant_mode(tune_cfg, 'dynamic') + \
+                                    self._collect_ops_by_quant_mode(tune_cfg, 'static')
+            op_quant_cfgs = {op_info: tune_cfg_backup[op_info] for op_info in quant_ops_in_tune_cfg}
+            fallback_records = []
+            self.re_quant = True
+            while not self.objectives.compare(self.last_tune_result, self.baseline):
+                # Record the time spent calculating the sensitivity
+                start = time()
+                ops_lst = self.adaptor.calculate_op_sensitivity(self.model,
+                                                                self.calib_dataloader,
+                                                                deepcopy(self._tune_cfg_converter(tune_cfg)),
+                                                                self.output_op_names,
+                                                                self.confidence_batches,
+                                                                fallback=True)
+                logger.debug(f"*** The op sensitivity analysis took {time() - start:.2f}s.")
+                select_op_info = ops_lst[0]
+                logger.info(f"*** The op {select_op_info} has the highest sensitivity in the current state, \
+                    fall back it to fp32.")
+                tune_cfg[select_op_info] = OpTuningConfig(select_op_info[0],
+                                                          select_op_info[1],
+                                                          'fp32',
+                                                          self.tuning_space)
+                # Record the fallback history
+                if not fallback_records:
+                    fallback_records = [[select_op_info]]
+                else:
+                    fallback_records.append(fallback_records[-1] + [select_op_info])
+                logger.debug(f"*** The fallback ops record: \n{self._tuning_record_msg(fallback_records)}")
+                yield tune_cfg
+
+            logger.info(f"*** The accuracy now meets the accuracy requirements; stop falling back ops.")
+            while self.objectives.compare(self.last_tune_result, self.baseline):
+                if len(fallback_records) == 0 or len(fallback_records[-1]) <= 1:
+                    logger.info(f"*** Stop re-quant due to no int8 op or only 1 int8 op left.")
+                    break
+                logger.info(f"*** Start to re-quant the ops that were fallen back in the previous stage.")
+                # Track the current fallback ops
+                tmp_fallback_ops = fallback_records[-1] if fallback_records else []
+                start = time()
+                ops_lst = self.adaptor.calculate_op_sensitivity(self.model,
+                                                                self.calib_dataloader,
+                                                                deepcopy(self._tune_cfg_converter(tune_cfg)),
+                                                                self.output_op_names,
+                                                                self.confidence_batches,
+                                                                fallback=False,
+                                                                requantize_cfgs=requantize_cfg['op'])
+                logger.debug(f"*** The op sensitivity analysis took {time() - start:.2f}s.")
+                if not ops_lst:
+                    logger.warning("No op to be requantized")
+                    break
+                for select_op_info in ops_lst:
+                    # assert select_op_info in tmp_fallback_ops, f"{select_op_info} not in fallback list."
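+                    # Illustrative note: only ops that were actually fallen back in
+                    # the previous stage are valid candidates for re-quantization.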
+                    if select_op_info not in tmp_fallback_ops:
+                        logger.debug(f"{select_op_info} not in fallback list.")
+                        continue
+
+                    new_fallback_ops = deepcopy(tmp_fallback_ops)
+                    new_fallback_ops.remove(select_op_info)
+                    if new_fallback_ops not in fallback_records:
+                        logger.info(f"*** The op {select_op_info} has the lowest sensitivity in the current state, \
+                            re-quantize it.")
+                        tune_cfg[select_op_info] = op_quant_cfgs[select_op_info]
+                        fallback_records.append(new_fallback_ops)
+                        logger.debug(f"*** The fallback ops record: \n{self._tuning_record_msg(fallback_records)}")
+                        yield tune_cfg
+                        break
+                    else:
+                        logger.debug(f"*** Skip re-quantizing {select_op_info}, because this config has already been evaluated.")
+                        continue
+            self.re_quant = False
+            logger.info(f"*** The accuracy no longer meets the accuracy requirements; stop re-quantizing ops.")
\ No newline at end of file
diff --git a/neural_compressor/experimental/strategy/random.py b/neural_compressor/experimental/strategy/random.py
new file mode 100644
index 00000000000..7148100a76a
--- /dev/null
+++ b/neural_compressor/experimental/strategy/random.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The random tuning strategy."""
+import numpy as np
+from .strategy import strategy_registry, TuneStrategy
+from collections import OrderedDict
+
+from .utils.tuning_sampler import OpWiseTuningSampler, FallbackTuningSampler
+from .utils.tuning_structs import OpTuningConfig
+from ...utils import logger
+
+@strategy_registry
+class RandomTuneStrategy(TuneStrategy):
+    """The random tuning strategy."""
+
+    def next_tune_cfg(self):
+        """Generate and yield the next tuning config by random search in the tuning space.
+
+        The random strategy randomly chooses quantization tuning configurations
+        from the tuning space. As with the exhaustive strategy, it also only considers
+        quantization tuning configs to generate a better-performing quantized model.
+
+        Returns:
+            tune_config (dict): A dict containing the tuning configuration for quantization.
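+
+        Example (an illustrative sketch only; `strategy` is assumed to be an
+        already-constructed RandomTuneStrategy instance):
+
+            gen = strategy.next_tune_cfg()
+            tune_cfg = next(gen)  # a random op-wise config with a random 'calib_sampling_size'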
+ """ + tuning_space = self.tuning_space + op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg = self.initial_tuning_cfg() + op_wise_tuning_sampler = OpWiseTuningSampler(tuning_space, [], [], + op_item_dtype_dict, initial_op_tuning_cfg) + op_tuning_cfg_lst = list(op_wise_tuning_sampler) + op_tuning_cfg_cnt = len(op_tuning_cfg_lst) + calib_sampling_size_lst = tuning_space.root_item.get_option_by_name('calib_sampling_size').options + calib_sampling_size_cnt = len(calib_sampling_size_lst) + while True: + calib_index = np.random.choice(calib_sampling_size_cnt) + calib_sampling_size = calib_sampling_size_lst[calib_index] + op_tuning_cfg_index = np.random.choice(op_tuning_cfg_cnt) + op_tuning_cfg = op_tuning_cfg_lst[op_tuning_cfg_index] + op_tuning_cfg['calib_sampling_size'] = calib_sampling_size + yield op_tuning_cfg + return diff --git a/neural_compressor/experimental/strategy/strategy.py b/neural_compressor/experimental/strategy/strategy.py new file mode 100644 index 00000000000..b7abb59ed5f --- /dev/null +++ b/neural_compressor/experimental/strategy/strategy.py @@ -0,0 +1,1556 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The base class for tuning strategy.""" + +from abc import abstractmethod +from enum import EnumMeta +import os +import math +import copy +from copy import deepcopy +import pickle +from collections import OrderedDict, defaultdict +from pathlib import Path +import yaml +import numpy as np +from typing import OrderedDict as T_OrderedDict + +from neural_compressor.adaptor.tensorflow import TensorFlowAdaptor +from ...objective import MultiObjective +from ...adaptor import FRAMEWORKS +from ...utils.utility import Statistics, dump_data_to_local +from ...utils.utility import fault_tolerant_file, equal_dicts, GLOBAL_STATE, MODE +from ...utils.create_obj_from_config import create_eval_func, create_train_func +from ...utils.utility import LazyImport +from ...utils import logger +from ...version import __version__ +from ...conf.dotdict import DotDict, deep_get, deep_set +from ...algorithm import AlgorithmScheduler, ALGORITHMS + +import copy +import numpy as np +from collections import OrderedDict +from time import time +from ...utils import logger +import sys + + +from .utils.tuning_space import TuningItem, TuningSpace +from .utils.tuning_structs import OpTuningConfig +from .utils.constant import FALLBACK_RECIPES_SET + + +EXP_STRATEGIES = {} + + +def strategy_registry(cls): + """Class decorator used to register all TuneStrategy subclasses. + + Args: + cls (class): The class of register. + + Returns: + cls: The class of register. + """ + assert cls.__name__.endswith( + 'TuneStrategy' + ), "The name of subclass of TuneStrategy should end with \'TuneStrategy\' substring." 
+ if cls.__name__[:-len('TuneStrategy')].lower() in EXP_STRATEGIES: + raise ValueError('Cannot have two strategies with the same name') + EXP_STRATEGIES[cls.__name__[:-len('TuneStrategy')].lower()] = cls + return cls + +@strategy_registry +class TuneStrategy(object): + """Basic class for tuning strategy.""" + + def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader=None, + eval_func=None, resume=None, q_hooks=None): + """Init the TuneStrategy. + + Args: + model: The FP32 model specified for low precision tuning. + conf: The Conf class instance includes all user configurations. + q_dataloader: Data loader for calibration, mandatory for post-training quantization. Defaults to None. + q_func: Training function for quantization aware training. Defaults to None. Defaults to None. + eval_dataloader: Data loader for evaluation. Defaults to None. + eval_func: The evaluation function provided by user. This function takes model as parameter, and + evaluation dataset and metrics should be encapsulated in this function implementation and + outputs a higher-is-better accuracy scalar value. + resume: The dict containing resume information. Defaults to None. + q_hooks: The dict of training hooks, supported keys are: on_epoch_begin, on_epoch_end, on_step_begin, + on_step_end. Their values are functions to be executed in adaptor layer.. Defaults to None. + last_qmodel: The quantized model that generated from the last tuning. + best_qmodel: The best quantized model that generated during the tuning process. + """ + self.model = model + self.cfg = conf.usr_cfg + self.cfg_bk = copy.deepcopy(self.cfg) + self.history_path = self._create_path(self.cfg.tuning.workspace.path, './history.snapshot') + self.deploy_path = self._create_path(self.cfg.tuning.workspace.path, 'deploy.yaml') + self.eval_dataloader = eval_dataloader + self.calib_dataloader = q_dataloader + self.q_func = q_func + self.q_hooks = q_hooks + self.eval_func = eval_func + GLOBAL_STATE.STATE = MODE.QUANTIZATION + framework, framework_specific_info = self._set_framework_info(q_dataloader, q_func) + self.adaptor = FRAMEWORKS[framework](framework_specific_info) + self.framework = framework + + self.set_q_func() + self._set_objectives() + self.tune_data = {} + self.tune_result_record = [] + self.tuning_history = [] + self.tuning_result_data = [] + # The tuning history ever made, structured like below: + # [ + # { + # 'version': __version__, + # 'cfg': cfg1, + # 'framework': tensorflow + # 'baseline': baseline1, + # 'last_tune_result': last_tune_result1, + # 'best_tune_result': best_tune_result1, + # 'history': [ + # # tuning history under same yaml config + # {'tune_cfg': tune_cfg1, 'tune_result': \ + # tune_result1, 'q_config': q_config1, ...}, + + # ..., + # ], + # # new fields added by subclass for resuming + # ..., + # }, + # # tuning history under different yaml configs + # ..., + # ] + + self.baseline = None + self.last_tune_result = None + self.last_qmodel = None + self.last_tune_cfg = None + self.best_qmodel = None + self.best_tune_result = None + self.best_tuning_cfg = None # track the best tuning config correspondence to the best quantized model + self.cur_best_acc = self.initial_best_acc() # track the current best accuracy + self.cur_best_tuning_cfg = {} # track tuning cfg with the current best accuracy + self.re_quant = False + + self.capability = self.adaptor.query_fw_capability(model) + logger.debug(self.capability) + self.set_tuning_space(conf) + + #For algo scheduler + self.algo_scheduler = 
+        AlgorithmScheduler(self.cfg.quantization.recipes)
+        self.algo_scheduler.dataloader = self.calib_dataloader  # reuse the calibration iteration
+        self.algo_scheduler.origin_model = self.model
+        self.algo_scheduler.adaptor = self.adaptor
+
+        self._optype_statistics = None
+        self.fallback_stats_baseline = None
+        self.fallback_stats = None
+        self.tuning_times = 0
+        self.fallback_start_point = 0
+        self.metric_met_point = 0
+
+        # for recipes
+        # {recipe name: the list of supported values}
+        self._tuning_recipes = OrderedDict()
+        # {recipe name: the default value when not tuning}
+        self._tuning_recipes_default_values = {}
+        # {recipe name: the value specified by user}
+        self._not_tuning_recipes_values = {}
+        self._initialize_recipe()
+        self.applied_all_recipes_flag = False
+        if resume is not None: self.setup_resume(resume)
+
+    @abstractmethod
+    def next_tune_cfg(self):
+        """Interface for generating the next tuning config.
+
+        A generator that yields the next tuning config to traverse, driven by the concrete strategy
+        or quantization level according to the last tuning result and the traverse logic.
+
+        It should be implemented by the sub-class.
+
+        Yields:
+            tune_config (dict): It's a dict containing the tuning configuration to traverse.
+        """
+        raise NotImplementedError
+
+    def _initialize_recipe(self):
+        """Divide the recipes into two categories: tuning and not tuning."""
+        from .utils.utility import get_adaptor_name
+        from ...utils.constant import RECIPES as fwk_recipes
+        from ...utils.constant import RECIPES_PRIORITY as fwk_recipes_priority
+        # get all recipes supported by the adaptor.
+        adaptor_name = get_adaptor_name(self.adaptor)
+        adaptor_recipes = fwk_recipes['common']
+        # TODO WA because smooth quant is only supported by ort/pt currently.
+        if adaptor_name not in ['onnx', 'pytorch']:
+            adaptor_recipes.pop('smooth_quant', None)
+        for adaptor_name_key, adaptor_recipes_val in fwk_recipes.items():
+            if adaptor_name_key.startswith(adaptor_name):
+                adaptor_recipes.update(adaptor_recipes_val)
+        # divide the recipes into two categories:
+        # tuning list: the value is equal to the default value
+        # not tuning list: the value is not equal to the default value
+        logger.info(f"Adaptor has {len(adaptor_recipes)} recipes.")
+        logger.debug(adaptor_recipes)
+        usr_recipes_cfg = self.cfg_bk.quantization.recipes if self.cfg_bk.quantization.recipes else {}
+        for recipe_name, recipe_val in usr_recipes_cfg.items():
+            # for not-tuning recipes, use the value specified by the user.
+            if recipe_name in adaptor_recipes and recipe_val != adaptor_recipes[recipe_name][0]:
+                self._not_tuning_recipes_values[recipe_name] = recipe_val
+        # sort the recipes and set the default value to be used before recipe tuning
+        for recipe_name in fwk_recipes_priority:
+            if recipe_name in adaptor_recipes and recipe_name not in self._not_tuning_recipes_values:
+                # TODO skip tuning smooth_quant first
+                if recipe_name == 'smooth_quant': continue
+                self._tuning_recipes[recipe_name] = adaptor_recipes[recipe_name]
+                self._tuning_recipes_default_values[recipe_name] = adaptor_recipes[recipe_name][0]
+        logger.info(f"{len(self._not_tuning_recipes_values)} recipes specified by user.")
+        logger.debug(self._not_tuning_recipes_values)
+        logger.info(f"{len(self._tuning_recipes)} recipes require future tuning.")
+        logger.debug(self._tuning_recipes)
+
+    def distributed_next_tune_cfg_lst(self, comm):
+        """Interface for generating the distributed next tuning config list.
+
+        A generator that yields the next list of tuning configs for distributed traversal, driven by the
+        concrete strategy or quantization level according to the tuning result and the traverse logic.
+
+        It should be implemented by the sub-class. Currently, it is only implemented in the BasicTuneStrategy.
+        """
+        pass
+
+    def meet_acc_req(self, eval_res):
+        """Compare the result of the last tuning with the baseline to check whether the result meets the requirements.
+
+        Args:
+            eval_res: The evaluation result of tuning.
+
+        Returns:
+            Return True if the accuracy meets the requirements, else False.
+        """
+        self.last_tune_result = eval_res
+        return self.objectives.accuracy_meet_req(deepcopy(self.last_tune_result))
+
+    def master_worker_handle(self, comm):
+        """Master worker handles the task assignment and result management.
+
+        The master node sends task ids to all free nodes and waits for any result.
+        When receiving a result, it directly sends a new task id to the sender (which is now free).
+
+        Args:
+            comm (MPI.COMM): The instance of communication for MPI.
+        """
+        MPI = LazyImport("mpi4py.MPI")
+        size = comm.Get_size()
+        for process_id in range(1, min(len(self.tune_cfg_lst) + 1, size)):
+            tune_cfg_id = process_id - 1
+            logger.info("~~~~~~master sending tune cfg: {} to rank {}".format(tune_cfg_id, process_id))
+            comm.send(
+                obj=tune_cfg_id,  # just sending the tune cfg id is enough
+                dest=process_id,  # rank 0 sends to rank 1, 2, ...
+                tag=tune_cfg_id   # tag, the index of tune cfg 0,1,2,3
+            )
+            import time as ttime
+            ttime.sleep(0.5)  # WA for UT
+
+        cur_cfg_id = min(len(self.tune_cfg_lst), size - 1)  # the master should be aware of the next config id to send
+        self.eval_results = {}  # record all results
+        self.num_acks = 0  # number of all response acks, break when it equals len(self.tune_cfg_lst)
+        status = MPI.Status()  # used to obtain the source and the tag for each received message
+
+        self.already_ack_id_lst = set()
+        self.requirements_met_min_cfg_id = sys.maxsize
+
+        # block here to receive any result
+        while True:
+            eval_res = comm.recv(
+                source=MPI.ANY_SOURCE,
+                tag=MPI.ANY_TAG,
+                status=status  # get MPI status object
+            )
+            self.num_acks += 1
+            sender_rank = status.Get_source()  # sender rank
+            tag = status.Get_tag()  # the task id that is finished
+
+            logger.info("~~~~~~master receiving eval result: {} from rank {}".format(eval_res, sender_rank))
+
+            self.last_tune_result = eval_res  # for context coordination of stage 3
+            self.eval_results[tag] = eval_res
+
+            self.overall_trials += 1
+            self.best_tune_cfg_id = None
+            self.already_ack_id_lst.add(tag)
+
+            # if the accuracy requirement is met, update the minimum id that met the requirement
+            if self.meet_acc_req(eval_res):
+                logger.info("~~~~~~master has one tuning cfg meet acc: {}".format(tag))
+                self.met_flag = True
+                self.requirements_met_min_cfg_id = min(self.requirements_met_min_cfg_id, tag)
+
+                # must ensure every id lower than the current min_id has been acknowledged,
+                # because a tune cfg (not acked yet) with a lower id can have better acc
+                for i in range(self.requirements_met_min_cfg_id):
+                    if i not in self.already_ack_id_lst:
+                        logger.info("~~~~~~master has one tuning cfg meet acc: {} but has not collected all acks before"\
+                            .format(tag))
+                        self.met_flag = False  # not completely collected yet!
+                        break
+
+                if self.met_flag:
+                    # found the best tune cfg!
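+                    # the smallest acknowledged cfg id wins, since configs yielded
+                    # earlier by the strategy are preferred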
+ logger.info("~~~~~~master has one tuning cfg meet acc: {} and also collect all acks before"\ + .format(tag)) + self.best_tune_cfg_id = self.requirements_met_min_cfg_id + else: + # get the current best acc but not meet requirements + logger.info("~~~~~~master gets the current best acc: {} but not meet requirements".format(tag)) + self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(self.tune_cfg_lst[tag]) + + if self.best_tune_cfg_id is not None: + #### we find the best tune cfg id that meet requirements!! + logger.info("~~~~~~master finds best tune cfg id~~~~~~~") + logger.info(self.best_tune_cfg_id) + logger.info(self.tune_cfg_lst[self.best_tune_cfg_id]) + break + + # send the next cfg if not exceed max trials + if self.overall_trials > self.cfg.tuning.exit_policy.max_trials: + self.max_trial_flag = True + # elif time.time() - self.overall_time_start > self.cfg.tuning.exit_policy.timeout: + # self.max_time_flag = True + elif cur_cfg_id < len(self.tune_cfg_lst): + logger.info("~~~~~~master sends new tuning cfg {} to rank: {}".format(cur_cfg_id, sender_rank)) + comm.send(obj=cur_cfg_id, dest=sender_rank, tag=cur_cfg_id) + cur_cfg_id += 1 + else: + logger.info("All tune configs are sent, no more sending, just collecting...") + + if len(self.tune_cfg_lst) == self.num_acks: # all collected (ack should collected == acks) + # all processes ended + # return self.requirements_met_min_cfg_id if it has been updated + if self.requirements_met_min_cfg_id == sys.maxsize: + logger.info("~~~~~~Not found any tune cfg that meet requirements~~~~~~") + self.cur_best_tuning_cfg = self.tune_cfg_lst[0] # TODO select cur_best_tuning_cfg + else: + logger.info("~~~~~~Find best tune cfg id~~~~~~") + logger.info(self.requirements_met_min_cfg_id) + self.met_flag = True + self.best_tune_cfg_id = self.requirements_met_min_cfg_id + logger.info(self.tune_cfg_lst[self.best_tune_cfg_id]) + break + + # send END signal to all other slaves + logger.info("~~~~~~master sends END signal to all other slaves~~~~") + for process_id in range(1, size): + logger.info("~~~~~~master sends END signal to rank: {}".format(process_id)) + comm.send( + obj="MET" if self.met_flag else "NOT MET", # send whether met criterion in the current stage + dest=process_id, # rank 0 send to rank 1, 2, ... + tag=len(self.tune_cfg_lst) + ) + + if self.best_tune_cfg_id is not None: + self.best_qmodel = self.adaptor.quantize( + copy.deepcopy(self.tune_cfg_lst[self.best_tune_cfg_id]), self.model, self.calib_dataloader, \ + self.q_func) + + + def slave_worker_handle(self, comm): + """Slave worker handles the task processing. + + When receiving any task id, slave node finds it in self.tune_cfg_lst and run it. + Then slave node sends back the tune result to master node. + + Args: + comm (MPI.COMM): The instance of comunication for MPI. 
+ """ + MPI = LazyImport("mpi4py.MPI") + status = MPI.Status() + while True: + task = comm.recv( + source=MPI.ANY_SOURCE, + tag=MPI.ANY_TAG, + status=status # sender (master) + ) + cfg_idx = status.Get_tag() + if status.Get_tag() >= len(self.tune_cfg_lst): + logger.info("~~~~~~slave {} receiving END signal in the current stage".format(comm.Get_rank())) + if task == "MET": + logger.info("~~~~~~met criterion in this stage!") + self.met_flag = True + break + tune_cfg = self.tune_cfg_lst[cfg_idx] + + # set the parameter for pre quantization algos and run + self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model) + self.model = self.algo_scheduler('pre_quantization') + # quantize + q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func) + assert self.adaptor.pre_optimized_model + # set the parameter for post quantization algos and run + self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg, self.adaptor.pre_optimized_model, + q_model) + self.last_qmodel = self.algo_scheduler('post_quantization') + self.last_tune_cfg = copy.deepcopy(tune_cfg) + # Remove the reference to model + self.algo_scheduler.reset_exec_algorithms() + assert self.last_qmodel + self.last_tune_result = self._evaluate(self.last_qmodel) + + ##### send back the tuning statistics ######### + logger.debug("##### Slave sends back the tuning statistics #########") + logger.debug(self.last_tune_result) + comm.send( + obj=self.last_tune_result, + dest=0, # rank 0 send to rank 1, 2, ... + tag=cfg_idx + ) + + def distributed_traverse(self): + """Disributed traverse the tuning space. + + The main traverse logic which could be override by some concrete strategy which needs more hooks. + """ + MPI = LazyImport("mpi4py.MPI") + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + + self.met_flag = False + self.max_trial_flag = False # whether exceed max trials + self.max_time_flag = False # whether exceed max time + self.overall_trials = 0 + self.overall_time_start = time() + + # for all the stages, handle the tune cfg lst + # the tune cfg lst is generated/yielded each time by distributed_next_self.tune_cfg_lst + # we must pass the comm to the specific strategy because slaves may not know + # contexts such as the best_tune_cfg + # master should make sure slaves have all the contexts needed before going to the next computation stage + for op_tuning_cfg_lst in self.distributed_next_tune_cfg_lst(comm): + self.tune_cfg_lst = [self._tune_cfg_converter(op_tuning_cfg) for op_tuning_cfg in op_tuning_cfg_lst] + if self.tune_cfg_lst == []: + # skip empty list at some stages + continue + if rank == 0: + self.master_worker_handle(comm) + else: + self.slave_worker_handle(comm) + logger.debug("# if self.met_flag or self.max_trial_flag or self.max_time_flag:" \ + .format(self.met_flag or self.max_trial_flag or self.max_time_flag)) + if self.met_flag or self.max_trial_flag or self.max_time_flag: + break + + def _open_all_recipes(self): + """Open all tunable recipes.""" + opened_recipes = {} + for recipe_name, recipe_val_lst in self._tuning_recipes.items(): + opened_recipes[recipe_name] = recipe_val_lst[-1] + logger.info("Opened all recipes.") + logger.info(opened_recipes) + + def _fallback_ops(self, tune_cfg, recipe_op_lst, tuning_space): + """Fallback ops in recipe op list.""" + for op_name_type in recipe_op_lst: + tune_cfg.update({op_name_type: OpTuningConfig(op_name_type[0], \ + op_name_type[1],'fp32', tuning_space)}) + return tune_cfg + + def 
+        """Apply all tunable recipes with their values."""
+        tune_cfg['recipe_cfgs'] = tune_cfg.get('recipe_cfgs', {})
+        for recipe_name, recipe_val_lst in self._tuning_recipes.items():
+            tune_cfg['recipe_cfgs'][recipe_name] = recipe_val_lst[-1]
+            if recipe_name in FALLBACK_RECIPES_SET and 'recipes_ops' in self.capability and \
+                    len(self.capability['recipes_ops'].get(recipe_name, [])) > 0:
+                logger.info(f"Applied recipe {recipe_name}.")
+                tune_cfg = self._fallback_ops(tune_cfg, self.capability['recipes_ops'][recipe_name],\
+                    self.tuning_space)
+        return tune_cfg
+
+    def apply_recipe_one_by_one(self, tune_cfg):
+        """Apply the tunable recipes one by one.
+
+        For recipes that only have two options, apply the last one.
+        For recipes with multiple values, such as the alpha of smooth quant, apply them one by one.
+        """
+        from .utils.tuning_sampler import TuningSamplerRegistry
+        all_registered_samplers = TuningSamplerRegistry.sampler_dict
+        for recipe_name, recipe_vals in self._tuning_recipes.items():
+            if recipe_name in FALLBACK_RECIPES_SET and 'recipes_ops' in self.capability and \
+                    len(self.capability['recipes_ops'].get(recipe_name, [])) > 0:
+                logger.info(f"Applied recipe {recipe_name} with value {recipe_vals[-1]}")
+                new_tune_cfg = self._fallback_ops(copy.deepcopy(tune_cfg), \
+                    self.capability['recipes_ops'][recipe_name], self.tuning_space)
+                yield new_tune_cfg
+            if recipe_name in all_registered_samplers:
+                recipe_sampler = all_registered_samplers[recipe_name](tuning_space=None,
+                                                                      tuning_order_lst=[],
+                                                                      initial_op_tuning_cfg=copy.deepcopy(tune_cfg),
+                                                                      kwargs={recipe_name: recipe_vals})
+                for new_tune_cfg in recipe_sampler:
+                    yield new_tune_cfg
+
+    def set_param_for_pre_quantization_algos(self, algo_scheduler, tune_cfg, fp32_model) -> None:
+        """Set the parameters for pre-quantization algos, such as smooth quantization.
+
+        Args:
+            algo_scheduler: algo scheduler
+            tune_cfg: the tuning config
+            fp32_model: the fp32 model
+        """
+        algo_scheduler.origin_model = fp32_model
+        algo_scheduler.calib_iter = tune_cfg['calib_iteration']
+        algo_scheduler.q_model = fp32_model
+
+        recipe_cfgs = tune_cfg.get('recipe_cfgs', None)
+        algo_scheduler.reset_exec_algorithms()
+        if recipe_cfgs and recipe_cfgs.get('smooth_quant', False):
+            # skip assigning alpha to sq first.
+            # the alpha is set to 0.5 by default
+            # smooth_quant_args = recipe_cfgs.get('smooth_quant_args', {'alpha': 0.5})
+            sq_algo = ALGORITHMS()['smooth_quant']
+            # sq_algo.alpha = smooth_quant_args['alpha']
+            # logger.debug(f"Set smooth quant with alpha {smooth_quant_args['alpha']} as the pre-quantization algo.")
+            algo_scheduler.append_algorithm('pre_quantization', sq_algo)
+
+    def set_param_for_post_quantization_algos(self, algo_scheduler, tune_cfg, pre_optimized_model, q_model) -> None:
+        """Set the parameters for post-quantization algos, such as bias correction and weight correction.
+
+        Args:
+            algo_scheduler: algo scheduler
+            tune_cfg: the tuning config.
+            pre_optimized_model: the pre-optimized model
+            q_model: the quantized model
+        """
+        algo_scheduler.origin_model = pre_optimized_model
+        # if no post-quantization algos are appended, the scheduler returns q_model directly.
+        algo_scheduler.q_model = q_model
+
+        algo_scheduler.reset_exec_algorithms()
+        recipe_cfgs = tune_cfg.get('recipe_cfgs', None)
+        # for fast_bias_correction
+        if recipe_cfgs and recipe_cfgs.get('fast_bias_correction', False):
+            fbc_algo = ALGORITHMS()['fast_bias_correction']
+            fbc_algo.quantization_cfg = deepcopy(tune_cfg)
+            algo_scheduler.append_algorithm('post_quantization', fbc_algo)
+            logger.debug(f"Add fast bias correction as the post quantization algo.")
+        # for weight correction
+        if recipe_cfgs and recipe_cfgs.get('weight_correction', False):
+            w_algo = ALGORITHMS()['weight_correction']
+            w_algo.quantization_cfg = deepcopy(tune_cfg)
+            algo_scheduler.append_algorithm('post_quantization', w_algo)
+            logger.debug(f"Add weight correction as the post quantization algo.")
+
+    def traverse(self):
+        """Traverse the tuning space.
+
+        The main traverse logic, which can be overridden by a concrete strategy that needs more hooks.
+        """
+        self._eval_baseline()
+        logger.info("use distributed traverse: {}".format(self.cfg.tuning.use_distributed_tuning))
+        if self.cfg.tuning.use_distributed_tuning:
+            return self.distributed_traverse()
+        trials_count = 0
+        traverse_start_time = time()
+        for op_tuning_cfg in self.next_tune_cfg():
+            tuning_start_time = time()
+            tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
+            trials_count += 1
+            tuning_history = self._find_tuning_history(tune_cfg)
+            if tuning_history and trials_count < self.cfg.tuning.exit_policy.max_trials:
+                self.last_tune_result = tuning_history['last_tune_result']
+                self.best_tune_result = tuning_history['best_tune_result']
+                logger.warn("Found an already-evaluated tuning config, skip.")
+                continue
+            self._remove_redundant_qmodel()
+            logger.debug("Dump current tuning configuration:")
+            logger.debug(tune_cfg)
+            self.tuning_times += 1
+            # set the parameters for the pre-quantization algos and run
+            self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model)
+            self.model = self.algo_scheduler('pre_quantization')
+            # quantize
+            q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func)
+            assert self.adaptor.pre_optimized_model
+            # set the parameters for the post-quantization algos and run
+            self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg, self.adaptor.pre_optimized_model,
+                                                       q_model)
+            self.last_qmodel = self.algo_scheduler('post_quantization')
+            self.last_tune_cfg = copy.deepcopy(tune_cfg)
+            # Remove the reference to model
+            self.algo_scheduler.reset_exec_algorithms()
+            assert self.last_qmodel
+            # Return the last quantized model as the result if performance_only is set.
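+            # (i.e., the tuning loop exits on the first successful quantization,
+            # without computing a baseline or comparing accuracy)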
+            if self.cfg.tuning.exit_policy.performance_only:
+                self.best_qmodel = self.last_qmodel
+                self._add_tuning_history(copy.deepcopy(tune_cfg), (-1, [0]), q_config=self.last_qmodel.q_config)
+                return
+            self.last_tune_result = self._evaluate(self.last_qmodel)
+            self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(op_tuning_cfg)
+            need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, trials_count)
+
+            # record the tuning history
+            saved_tune_cfg = copy.deepcopy(tune_cfg)
+            saved_last_tune_result = copy.deepcopy(self.last_tune_result)
+            self._add_tuning_history(saved_tune_cfg,
+                                     saved_last_tune_result,
+                                     q_config=q_model.q_config)
+            self.tune_result_record.append(copy.deepcopy(self.last_tune_result))
+            self.tune_cfg = tune_cfg
+            now_time = time()
+            acc_res_msg = ""
+            performance_res_msg = ""
+            if self.tuning_result_data:
+                acc_res_msg = "[ " + "| ".join(self.tuning_result_data[0]) + " ]"
+                performance_res_msg = "[ " + "| ".join(self.tuning_result_data[1]) + " ]"
+            logger.debug(f"*** The accuracy of the last tuning is: {acc_res_msg}")
+            logger.debug(f"*** The performance of the last tuning is: {performance_res_msg}")
+            logger.debug(f"*** The last tuning time: {(now_time - tuning_start_time):.2f} s")
+            logger.debug(f"*** The tuning process lasted time: {(now_time - traverse_start_time):.2f} s")
+
+            self._dump_tuning_process_statistics()
+            if need_stop:
+                if self.re_quant:
+                    logger.info("*** Do not stop the tuning process, re-quantize the ops.")
+                    continue
+                # recover the best quantized model from the tuning config
+                self._recover_best_qmodel_from_tuning_cfg()
+                if self.cfg.tuning.diagnosis and self.cfg.tuning.diagnosis.diagnosis_after_tuning:
+                    logger.debug(f'*** Start to do diagnosis (inspect tensor).')
+                    self._diagnosis()
+                if self.use_multi_objective and len(self.tune_result_record) > 1 and \
+                        self.best_tune_result is not None:
+                    best_trail, best_result = self.objectives.best_result(self.tune_result_record,
+                                                                          copy.deepcopy(self.baseline))
+                    if best_result != self.best_tune_result:
+                        from neural_compressor.utils.utility import recover
+                        self.best_qmodel = recover(self.model.model,
+                                                   os.path.join(self.cfg.tuning.workspace.path, 'history.snapshot'),
+                                                   best_trail)
+                        logger.debug(f"*** Update the best qmodel by recovering from history.")
+                        self.best_tune_result = best_result
+                    self._dump_tuning_process_statistics()
+                break
+        self._recover_best_qmodel_from_tuning_cfg()
+
+    def _remove_redundant_qmodel(self):
+        """Remove the redundant quantized model to reduce memory use.
+
+        During the tuning process, the strategy only keeps the best tuning config
+        instead of the best quantized model to reduce memory use.
+        """
+        self.last_qmodel = None
+        self.best_qmodel = None
+
+    def _can_create_eval_func_from_cfg(self):
+        """Determine whether an eval function can be created from cfg.
+
+        Returns:
+            True if the eval func can be created from the config, False otherwise.
+        """
+        if self.cfg.evaluation and self.cfg.evaluation.accuracy and \
+                (self.cfg.evaluation.accuracy.metric or self.cfg.evaluation.accuracy.multi_metrics)\
+                and self.eval_dataloader:
+            return True
+        return False
+
+    def _eval_baseline(self):
+        """Evaluate the fp32 model if needed."""
+        if not self._can_create_eval_func_from_cfg() and not self.eval_func:
+            logger.info("Neither evaluation function nor metric is defined." \
+    def _eval_baseline(self):
+        """Evaluate the fp32 model if needed."""
+        if not self._can_create_eval_func_from_cfg() and not self.eval_func:
+            logger.info("Neither evaluation function nor metric is defined."
+                        " Generate a quantized model with the default quantization configuration.")
+            self.cfg.tuning.exit_policy.performance_only = True
+            logger.info("Force setting 'tuning.exit_policy.performance_only = True'.")
+
+        if not self.cfg.tuning.exit_policy.performance_only:
+            # get the fp32 model baseline
+            if self.baseline is None:
+                logger.info("Get FP32 model baseline.")
+                self._fp32_model = self.model
+                self.baseline = self._evaluate(self.model)
+                self.objectives.baseline = self.baseline
+                # record the FP32 baseline
+                self._add_tuning_history()
+            self.show_baseline_info()
+
+    def _recover_best_qmodel_from_tuning_cfg(self):
+        """Recover the best quantized model from the tuning config."""
+        if self.best_tuning_cfg and not self.best_qmodel:
+            self.best_qmodel = self.adaptor.quantize(copy.deepcopy(self.best_tuning_cfg), self.model,
+                                                     self.calib_dataloader, self.q_func)
+
+    def _fallback_started(self):
+        self.fallback_start_point = self.tuning_times
+
+    def _update_optype_statistics(self):
+        self._optype_statistics = defaultdict(lambda: defaultdict(int))
+
+        for op_name_type, op_tune_cfg in self.tune_cfg['op'].items():
+            optype = op_name_type[1]
+            quant_mode = op_tune_cfg['activation']['quant_mode']
+            if isinstance(quant_mode, tuple) or isinstance(quant_mode, list):
+                quant_mode = quant_mode[0]
+            dtype = 'INT8' if quant_mode in ('static', 'dynamic') \
+                else quant_mode.upper()
+            self._optype_statistics[optype]['Total'] += 1
+            self._optype_statistics[optype][dtype] += 1
+        return
+
+    def _dump_tuning_process_statistics(self):
+        self._update_optype_statistics()
+
+        logger.debug("Current tuning process statistics:")
+        logger.debug(f"Total Tuning Times: {self.tuning_times}")
+        logger.debug("Fallback started at Tune {}".format(self.fallback_start_point))
+        logger.debug("Objective(s) met at Tune {}".format(self.metric_met_point))
+
+        fallback_stats = self._calculate_fallback_op_count()
+        if self.fallback_stats_baseline is None:
+            self.fallback_stats_baseline = fallback_stats
+        logger.debug(f"Fallback op count: {self.fallback_stats_baseline - fallback_stats}")
+
+        if isinstance(self.adaptor, TensorFlowAdaptor):
+            self._compare_optype_statistics()
+
+        return
+
+    def _calculate_fallback_op_count(self, target_dtype='INT8'):
+        fallback_stats = defaultdict(int)
+
+        for optype in self._optype_statistics:
+            for dtype, count in self._optype_statistics[optype].items():
+                fallback_stats[dtype] += count
+
+        return fallback_stats[target_dtype]
+
+    def _compare_optype_statistics(self, fields=None, optypes=None,
+                                   skip_fields=None, skip_optypes=None):
+        assert fields is None or skip_fields is None
+        assert optypes is None or skip_optypes is None
+        if not isinstance(self.adaptor, TensorFlowAdaptor):
+            logger.debug("OpType statistics comparison is only available for the TensorFlow adaptor.")
+            return
+
+        adaptor_statistics = self.adaptor.optype_statistics
+
+        def _field_skipped(field):
+            if fields is not None:
+                return field not in fields
+            elif skip_fields is not None:
+                return field in skip_fields
+            return False
+
+        def _optype_skipped(optype):
+            if optypes is not None:
+                return optype not in optypes
+            elif skip_optypes is not None:
+                return optype in skip_optypes
+            return False
+
+        field_names = adaptor_statistics[0][1:]
+        adaptor_data = {
+            line[0].lower(): {dtype: count for dtype, count in zip(field_names, line[1:])}
+            for line in adaptor_statistics[1]}
+        strategy_data = self._optype_statistics
+
+        # compare adaptor statistics to strategy statistics
+        logger.debug("Statistics difference between adaptor and tuning config:")
+        has_difference = False
+        difference_count = 0
+        for optype in adaptor_data:
+            if optype not in strategy_data or _optype_skipped(optype):
+                continue
+            for field in field_names:
+                if _field_skipped(field):
+                    continue
+                adaptor_count = adaptor_data[optype][field]
+                strategy_count = strategy_data[optype][field]
+                if adaptor_count != strategy_count:
+                    has_difference = True
+                    if field == 'INT8':
+                        difference_count += abs(strategy_count - adaptor_count)
+                    logger.debug("\t{}: [adaptor: {} | tune_cfg: {}]".format(
+                        (optype, field), adaptor_count, strategy_count))
+        if not has_difference:
+            logger.debug("\tNone")
+        logger.debug(f"\tDifference(s) in total: {difference_count}")
+        return
+
+    def initial_tuning_cfg(self):
+        """Init the tuning config.
+
+        Initialize the tuning config according to the quantization approach.
+
+        Returns:
+            op_item_dtype_dict (OrderedDict): key is (op_name, op_type); value is quantization mode.
+            quant_mode_wise_items (OrderedDict): key is quant_mode/precision; value is item list.
+            initial_op_tuning_cfg (OrderedDict): key is (op_name, op_type); value is the initialized tuning config.
+        """
+        from .utils.constant import auto_query_order, static_query_order, dynamic_query_order
+        from .utils.tuning_space import initial_tuning_cfg_with_quant_mode
+        if self.cfg.quantization.approach == 'post_training_auto_quant':
+            query_order = auto_query_order
+        elif self.cfg.quantization.approach == 'post_training_dynamic_quant':
+            query_order = dynamic_query_order
+        elif self.cfg.quantization.approach == 'post_training_static_quant':
+            query_order = static_query_order
+        elif self.cfg.quantization.approach == 'quant_aware_training':
+            logger.info("!!! Currently, the qat tuning is not supported by strategy.")
+            query_order = auto_query_order
+
+        quant_mode_wise_items = OrderedDict()  # mode, op_item_lst
+        pre_items = set()
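+        # For example, with auto_query_order = ['static', 'dynamic', 'bf16', 'fp16', 'fp32']
+        # (see .utils.constant below), an op that supports both 'static' and 'fp32' is
+        # initialized to 'static': earlier modes win and later duplicates are filtered out.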
+        # Collect the op items that support the specified mode.
+        for quant_mode in query_order:
+            items = self.tuning_space.query_items_by_quant_mode(quant_mode)
+            filtered_items = list(filter(lambda item: item not in pre_items, items))
+            pre_items = pre_items.union(set(items))
+            quant_mode_wise_items[quant_mode] = filtered_items
+
+        def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict):
+            for item in items_lst:
+                op_item_dtype_dict[item.name] = target_quant_mode
+
+        op_item_dtype_dict = OrderedDict()
+        for quant_mode, quant_mode_items in quant_mode_wise_items.items():
+            initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict)
+
+        initial_op_tuning_cfg = {}
+        for op_name_type, quant_mode in op_item_dtype_dict.items():
+            initial_op_tuning_cfg[op_name_type] = initial_tuning_cfg_with_quant_mode(op_name_type,
+                                                                                     quant_mode,
+                                                                                     self.tuning_space)
+        return op_item_dtype_dict, quant_mode_wise_items, initial_op_tuning_cfg
+
+    def show_baseline_info(self):
+        """Display the accuracy and duration of the baseline model."""
+        if self.baseline:
+            self.tune_data['baseline'] = self.baseline[0] if \
+                isinstance(self.baseline[0], list) else [self.baseline[0]]
+            for name, data in zip(self.metric_name, self.tune_data['baseline']):
+                self.tune_data[name] = [data]
+            if self.metric_weight:
+                # the baseline is a weighted accuracy
+                self.tune_data['Weighted accuracy'] = \
+                    [np.mean(np.array(self.tune_data['baseline']) * self.metric_weight)]
+                self.tune_data['baseline'] = self.tune_data['Weighted accuracy']
+            baseline_msg = '[Accuracy:' + \
+                ''.join([' {:.4f}'.format(i) for i in self.tune_data['baseline']]) + \
+                ''.join([', {}: {:.4f}'.format(x, y) for x, y in zip(
+                    self.objectives.representation, self.baseline[1]) if x != 'Accuracy']) + ']'
+        else:  # pragma: no cover
+            if self.metric_weight:
+                self.tune_data['Weighted accuracy'] = ['n/a']
+            self.tune_data['baseline'] = ['n/a']
+
+            for name, data in zip(self.metric_name, self.tune_data['baseline']):
+                self.tune_data[name] = ['n/a']
+            baseline_msg = 'n/a'
+        logger.info("FP32 baseline is: {}".format(baseline_msg))
+
+    def initial_best_acc(self):
+        """Init the best accuracy.
+
+        Returns:
+            The initial value of the best accuracy.
+        """
+        if len(self.metric_name) == 1 or self.metric_weight is not None:
+            best_acc = float('-inf') if self.higher_is_better else float('inf')
+        else:
+            best_acc = [float('-inf') if higher_is_better else float('inf') for
+                        higher_is_better in self.metric_criterion]
+        return best_acc
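+    # Illustrative outcomes of the initialization above:
+    #   single metric, higher_is_better=True           -> best_acc = -inf
+    #   two metrics, no weight, criteria [True, False] -> best_acc = [-inf, inf]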
+    def _tune_cfg_converter(self, op_tuning_cfg):
+        """Convert op_tuning_cfg for the adaptor.
+
+        Args:
+            op_tuning_cfg (Dict): The op tuning config.
+        """
+        tune_cfg = {'op': OrderedDict()}
+        for op_name_type, op_config in op_tuning_cfg.items():
+            if isinstance(op_config, OpTuningConfig):
+                tune_cfg['op'][op_name_type] = op_config.get_state()
+                op_cap_lst = self.capability['opwise'][op_name_type]
+                # Add the pattern for diagnosis
+                for op_cap in op_cap_lst:
+                    if 'pattern' in op_cap:
+                        op_pattern = {}
+                        op_pattern['sequence'] = op_cap['pattern']['sequence'][0] if \
+                            'sequence' in op_cap['pattern'] else None
+                        op_pattern['precision'] = op_cap['pattern']['precision'][0] if \
+                            'precision' in op_cap['pattern'] else None
+                        tune_cfg['op'][op_name_type]['pattern'] = op_pattern
+            else:
+                tune_cfg[op_name_type] = op_config
+        tune_cfg['calib_sampling_size'] = op_tuning_cfg['calib_sampling_size']
+        if self.calib_dataloader is not None:
+            tune_cfg['calib_iteration'] = math.ceil(int(tune_cfg['calib_sampling_size']) /
+                                                    self.calib_dataloader.batch_size)
+        else:
+            tune_cfg['calib_iteration'] = 1
+        tune_cfg['advance'] = self.cfg.quantization.advance
+        tune_cfg['approach'] = self.cfg.quantization.approach
+        # Add the recipe config
+        tune_cfg['recipe_cfgs'] = tune_cfg.get('recipe_cfgs', {})
+        # For recipes that are not tuned, use their values directly.
+        tune_cfg['recipe_cfgs'].update(self._not_tuning_recipes_values)
+        # Workaround to get the smooth quant args.
+        if 'smooth_quant_args' in self.cfg_bk.quantization.recipes:
+            tune_cfg['recipe_cfgs']['smooth_quant_args'] = self.cfg_bk.quantization.recipes['smooth_quant_args']
+        # For tuned recipes, use the default value if not specified by the recipe tuning sampler.
+        for recipe_name, recipe_val in self._tuning_recipes_default_values.items():
+            if recipe_name not in tune_cfg['recipe_cfgs']:
+                tune_cfg['recipe_cfgs'][recipe_name] = recipe_val
+        return tune_cfg
+
+    def set_tuning_space(self, conf):
+        """Create the tuning space.
+
+        Create the tuning space based on the framework capability and the user configuration.
+
+        Args:
+            conf: The Conf class instance that includes all user configurations.
+        """
+        calib_sampling_size_lst = self.cfg.quantization.calibration.sampling_size
+        calib_sampling_size_lst = [int(calib_sampling_size) for calib_sampling_size in calib_sampling_size_lst]
+        if self.calib_dataloader:
+            self.calib_iter = [math.ceil(int(x) / self.calib_dataloader.batch_size)
+                               for x in calib_sampling_size_lst]
+        else:
+            self.calib_iter = 1
+        # create the tuning space
+        adaptor_cap = {
+            'calib': {'calib_sampling_size': calib_sampling_size_lst},
+            'op': self.capability['opwise']
+        }
+        self.tuning_space = TuningSpace(adaptor_cap, conf=conf, framework=self.framework)
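+    # Worked example for the calibration math above: with calib_sampling_size=100
+    # and batch_size=32, calib_iteration = ceil(100 / 32) = 4.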
+    def setup_resume(self, resume):
+        """Resume the best quantized model from tuning history.
+
+        Args:
+            resume: The dict containing resume information.
+        """
+        self.__dict__.update(resume)
+        for history in self.tuning_history:
+            if self._same_yaml(history['cfg'], self.cfg):
+                self.__dict__.update({k: v for k, v in history.items()
+                                      if k not in ['version', 'history']})
+                logger.info("Start to resume tuning process.")
+                # resume the best tuned model if needed
+                try:
+                    index = history['id'] - 1
+                    resume_tuning_cfg = history['history'][index]['tune_cfg']
+                    self.best_qmodel = self.adaptor.quantize(resume_tuning_cfg,
+                                                             self.model,
+                                                             self.calib_dataloader,
+                                                             self.q_func)
+                except Exception:
+                    logger.debug("Cannot resume the best quantized model from history.")
+
+                break
+
+    def set_q_func(self):
+        """Set the training function for quantization aware training."""
+        if self.q_func is None and self.cfg.quantization.approach == 'quant_aware_training':
+            train_cfg = self.cfg.quantization.train
+            assert train_cfg, "The train field of the quantization section in the yaml file must " \
+                              "be configured for quantization aware training if q_func is NOT set."
+            assert self.calib_dataloader, "The dataloader field of the train field of the quantization " \
+                                          "section in the yaml file must be configured."
+            self.q_func = create_train_func(self.framework, self.calib_dataloader,
+                                            self.adaptor, train_cfg, hooks=self.q_hooks)
+
+    def _create_path(self, custom_path, filename):
+        new_path = os.path.join(os.path.abspath(os.path.expanduser(custom_path)), filename)
+        path = Path(os.path.dirname(new_path))
+        path.mkdir(exist_ok=True, parents=True)
+        return new_path
+
+    def _set_framework_info(self, q_dataloader, q_func=None):
+        framework_specific_info = {'device': self.cfg.device,
+                                   'approach': self.cfg.quantization.approach,
+                                   'random_seed': self.cfg.tuning.random_seed,
+                                   'performance_only': self.cfg.tuning.exit_policy.performance_only}
+        framework = self.cfg.model.framework.lower()
+        framework_specific_info.update({'backend': self.cfg.model.get('backend', 'default')})
+        framework_specific_info.update({'format': self.cfg.model.get('quant_format', 'default')})
+        framework_specific_info.update({'domain': self.cfg.model.get('domain', 'auto')})
+
+        self.mixed_precision_mode = bool('mixed_precision' in self.cfg) or \
+            bool('graph_optimization' in self.cfg)
+
+        if 'tensorflow' in framework:
+            framework_specific_info.update(
+                {"inputs": self.cfg.model.inputs,
+                 "outputs": self.cfg.model.outputs,
+                 'workspace_path': self.cfg.tuning.workspace.path,
+                 'recipes': self.cfg.quantization.recipes,
+                 'use_bf16': self.cfg.use_bf16 if self.cfg.use_bf16 is not None else False})
+            for item in ['scale_propagation_max_pooling', 'scale_propagation_concat']:
+                if item not in framework_specific_info['recipes']:
+                    framework_specific_info['recipes'].update({item: True})
+            if self.cfg.model.backend == 'itex':
+                self.cfg.model.framework = 'tensorflow_itex'
+                framework = 'tensorflow_itex'
+        if 'keras' in framework:
+            framework_specific_info.update({
+                'workspace_path': self.cfg.tuning.workspace.path, })
+        if framework == 'mxnet':
+            framework_specific_info.update({"q_dataloader": q_dataloader})
+        if 'onnx' in framework.lower():
+            if self.mixed_precision_mode:
+                framework_specific_info.update({"approach": "post_training_dynamic_quant"})
+            framework_specific_info.update({"deploy_path": os.path.dirname(self.deploy_path)})
+            framework_specific_info.update({'workspace_path': self.cfg.tuning.workspace.path})
+            framework_specific_info.update({'reduce_range': self.cfg.reduce_range})
+            framework_specific_info.update({'recipes': self.cfg.quantization.get('recipes', {})})
+            if framework.lower() == 'onnxrt_qdq' or \
+                    framework_specific_info['backend'] == 'onnxrt_trt_ep':
+                framework_specific_info.update({'format': 'QDQ'})
+                framework = 'onnxrt_qdq'
+        if framework == 'pytorch_ipex' or framework == 'pytorch' or framework == 'pytorch_fx':
+            if self.cfg.model.backend == 'ipex':
+                self.cfg.model.framework = 'pytorch_ipex'
+                framework = 'pytorch_ipex'
+            elif self.cfg.model.backend == 'default':
+                self.cfg.model.framework = 'pytorch_fx'
+                framework = 'pytorch_fx'
+            if self.mixed_precision_mode:
+                framework_specific_info.update({"approach": "post_training_dynamic_quant"})
+            framework_specific_info.update({"q_dataloader": q_dataloader})
+            framework_specific_info.update({"use_bf16": self.cfg.use_bf16
+                                            if self.cfg.use_bf16 is not None else True})
+            framework_specific_info.update(
+                {"workspace_path": os.path.dirname(self.deploy_path)})
+            if self.cfg['quantization']['op_wise'] is not None \
+                    and 'default_qconfig' in self.cfg['quantization']['op_wise']:
+                framework_specific_info.update(
+                    {"default_qconfig": self.cfg['quantization']['op_wise']['default_qconfig']})
+            framework_specific_info.update({"q_func": q_func})
+            framework_specific_info.update({"example_inputs": self.cfg.quantization.example_inputs})
+        return framework, framework_specific_info
+
+    def _set_objectives(self):
+        self.higher_is_better = bool(self.cfg.tuning.accuracy_criterion.higher_is_better)
+        self.use_multi_objective = deep_get(self.cfg, 'tuning.multi_objectives') and \
+            len(self.cfg.tuning.multi_objectives.objective) > 1
+        objectives = [i.lower() for i in self.cfg.tuning.multi_objectives.objective] if \
+            self.use_multi_objective else [self.cfg.tuning.objective.lower()]
+        self.metric_weight = deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.weight')
+        self.metric_name = ['Accuracy'] if \
+            not deep_get(self.cfg, 'evaluation.accuracy.multi_metrics') else \
+            self.cfg.evaluation.accuracy.multi_metrics.keys() - {'weight', 'higher_is_better'}
+        if len(self.metric_name) == 1:
+            self.metric_criterion = [self.higher_is_better]
+        elif not deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.higher_is_better'):
+            # default is True
+            self.metric_criterion = [True] * len(self.metric_name)
+        else:
+            self.metric_criterion = \
+                deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.higher_is_better')
+
+        self.objectives = MultiObjective(objectives,
+                                         self.cfg.tuning.accuracy_criterion,
+                                         self.metric_criterion,
+                                         self.metric_weight,
+                                         deep_get(self.cfg, 'tuning.multi_objectives.higher_is_better'),
+                                         deep_get(self.cfg, 'tuning.multi_objectives.weight'))
+
+    def _same_yaml(self, src_yaml, dst_yaml):
+        """Check if the two yamls are the same.
+
+        The check excludes the keys that do not really impact the tuning result, such as
+        the tensorboard, workspace, and resume options under the tuning section of the YAML.
+        """
+        if equal_dicts(src_yaml, dst_yaml, ignore_keys=['tuning']) and \
+                equal_dicts(src_yaml.tuning, dst_yaml.tuning, compare_keys=['objective',
+                                                                            'accuracy_criterion',
+                                                                            'random_seed',
+                                                                            'exit_policy']):
+            return True
+
+        return False
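+    # Sketch of the comparison above (hypothetical configs): two runs whose yamls
+    # differ only in tuning.workspace or tuning.tensorboard are treated as the
+    # "same" for resuming and history lookup, while a change to tuning.objective
+    # or tuning.accuracy_criterion makes them distinct.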
+    def update_best_op_tuning_cfg(self, op_tuning_cfg):
+        """Track and update the best tuning config with the corresponding accuracy result.
+
+        Args:
+            op_tuning_cfg: The tuning config.
+
+        Returns:
+            The current best accuracy result and the corresponding tuning config.
+        """
+        acc, _ = self.last_tune_result
+        if self.cur_best_tuning_cfg is None:
+            self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+        if not isinstance(acc, list) and ((self.higher_is_better and acc >= self.cur_best_acc)
+                                          or (not self.higher_is_better and acc <= self.cur_best_acc)):
+            self.cur_best_acc = acc
+            self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+        elif len(self.metric_name) > 1 and self.metric_weight is not None:
+            acc = np.mean(np.array(acc) * self.metric_weight)
+            if (self.higher_is_better and acc >= self.cur_best_acc) or \
+                    (not self.higher_is_better and acc <= self.cur_best_acc):
+                self.cur_best_acc = acc
+                self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+        elif len(self.metric_name) > 1 and self.metric_weight is None:
+            if all([acc_i >= best_i if higher_is_better else acc_i <= best_i for
+                    acc_i, best_i, higher_is_better in
+                    zip(acc, self.cur_best_acc, self.metric_criterion)]):
+                self.cur_best_acc = acc
+                self.cur_best_tuning_cfg = copy.deepcopy(op_tuning_cfg)
+        logger.debug(f"Best acc is {self.cur_best_acc}.")
+        return self.cur_best_acc, self.cur_best_tuning_cfg
+
+    def deploy_config(self):
+        """Save the configuration locally for deployment."""
+        acc_dataloader_cfg = deep_get(self.cfg, 'evaluation.accuracy.dataloader')
+        perf_dataloader_cfg = deep_get(self.cfg, 'evaluation.performance.dataloader')
+        # use the acc dataloader if the perf dataloader is not configured
+        if perf_dataloader_cfg is None:
+            perf_dataloader_cfg = acc_dataloader_cfg
+
+        self.deploy_cfg = OrderedDict()
+        # int8 dataloader graph transform
+        if deep_get(perf_dataloader_cfg, 'transform.QuantizedInput') is not None \
+                or deep_get(acc_dataloader_cfg, 'transform.QuantizedInput') is not None:
+            self.best_qmodel, scale = self.adaptor.quantize_input(self.best_qmodel)
+            deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8')
+            deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.scale', scale)
+            deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8')
+            deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.scale', scale)
+
+        self.deploy_cfg['model'] = self.cfg.model
+        self.deploy_cfg['device'] = self.cfg.device
+        if self.cfg.evaluation is not None:
+            deep_set(self.cfg, 'evaluation.performance.dataloader', perf_dataloader_cfg)
+            deep_set(self.cfg, 'evaluation.accuracy.dataloader', acc_dataloader_cfg)
+            self.deploy_cfg['evaluation'] = self.cfg.evaluation
+
+        def setup_yaml():
+            represent_dict_order = lambda self, \
+                data: self.represent_mapping('tag:yaml.org,2002:map', data.items())
+            yaml.add_representer(OrderedDict, represent_dict_order)
+            yaml.add_representer(DotDict, represent_dict_order)
+        setup_yaml()
+        with open(self.deploy_path, 'w+') as f:
+            yaml.dump(self.deploy_cfg, f)
+        logger.info("Save deploy yaml to {}".format(self.deploy_path))
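+    # The dumped deploy yaml takes roughly this shape (illustrative):
+    #
+    #     model: {...}      # self.cfg.model
+    #     device: cpu       # self.cfg.device
+    #     evaluation:       # with the perf/acc dataloaders filled in
+    #       accuracy: {...}
+    #       performance: {...}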
+    def _get_common_cfg(self, model_wise_cfg, op_wise_cfgs):
+        """Get the common parts from the model_wise_cfg.
+
+        This function is focused on composing the configuration that consists of
+        the model-wise field and the op-wise unique field data.
+
+        Args:
+            model_wise_cfg ([DotDict]): The model-wise configuration.
+            op_wise_cfgs ([List]): The list of each op's config in DotDict type.
+
+        Returns:
+            [DotDict]: The combined configuration with the op-wise unique field.
+        """
+        model_wise_keys = model_wise_cfg.keys()
+
+        result = op_wise_cfgs[0]
+        for each_op_wise_cfg in op_wise_cfgs:
+            tmp_cfg = {}
+            for k in model_wise_keys:
+                tmp_cfg[k] = each_op_wise_cfg[k]
+
+            if model_wise_cfg == tmp_cfg:
+                result = each_op_wise_cfg
+                break
+
+        return result
+
+    @property
+    def evaluation_result(self):
+        """Evaluate the given model.
+
+        Returns:
+            The evaluated objective value.
+        """
+        return self._evaluate(self.model)
+
+    def _evaluate(self, model):
+        """Interface for evaluating a model.
+
+        Args:
+            model (object): The model to be evaluated.
+
+        Returns:
+            Objective: The evaluated objective value.
+        """
+        if self.eval_func:
+            if self.cfg.tuning.tensorboard:
+                # PyTorch can insert an observer into the model in this hook.
+                # TensorFlow doesn't support this mode for now.
+                model = self.adaptor._pre_eval_hook(model)
+            val = self.objectives.evaluate(
+                self.eval_func, model if self.framework == "pytorch_ipex" else model.model
+            )
+            if self.cfg.tuning.tensorboard:
+                # post_eval_hook to deal with the tensor
+                self.adaptor._post_eval_hook(model, accuracy=val[0])
+        else:
+            assert self.cfg.evaluation and self.cfg.evaluation.accuracy and \
+                (self.cfg.evaluation.accuracy.metric or
+                 self.cfg.evaluation.accuracy.multi_metrics), \
+                "The metric or multi_metrics field of the accuracy field of the evaluation" \
+                " section should not be empty."
+
+            postprocess_cfg = self.cfg.evaluation.accuracy.postprocess
+            metric_cfg = self.cfg.evaluation.accuracy.metric if \
+                self.cfg.evaluation.accuracy.metric else \
+                self.cfg.evaluation.accuracy.multi_metrics
+            iteration = -1 if self.cfg.evaluation.accuracy.iteration is None \
+                else self.cfg.evaluation.accuracy.iteration
+            eval_func = create_eval_func(self.framework,
+                                         self.eval_dataloader,
+                                         self.adaptor,
+                                         metric_cfg,
+                                         postprocess_cfg,
+                                         iteration,
+                                         tensorboard=self.cfg.tuning.tensorboard,
+                                         fp32_baseline=self.baseline is None)
+
+            if getattr(self.eval_dataloader, 'distributed', False):
+                if 'tensorflow' in self.framework:
+                    import horovod.tensorflow as hvd
+                elif self.framework in ['pytorch_ipex', 'pytorch', 'pytorch_fx']:
+                    import horovod.torch as hvd
+                else:
+                    raise NotImplementedError("Currently only TensorFlow and PyTorch "
+                                              "support distributed inference in PTQ.")
+                hvd.init()
+                try:
+                    len_dataloader = len(self.eval_dataloader)
+                except Exception:
+                    logger.info("The length of the distributed dataloader is unknown. "
+                                "When the iteration of the evaluation dataloader in each "
+                                "process is inconsistent, an error may occur.")
+                else:
+                    list_len_dataloader = hvd.allgather_object(len_dataloader)
+                    if hvd.rank() == 0:
+                        for i in range(len(list_len_dataloader) - 1):
+                            if list_len_dataloader[i] != list_len_dataloader[i + 1]:
+                                raise AttributeError("The evaluation dataloader's iteration count is "
+                                                     "different between processes, please reset "
+                                                     "the dataloader's batch_size.")
+            val = self.objectives.evaluate(eval_func, model)
+        if isinstance(val[0], list):
+            assert all([np.isscalar(i) for i in val[0]]), \
+                "The eval_func should return a scalar or a list of scalars, " \
+                "not {}!".format(str([type(i) for i in val[0]]))
+        else:
+            assert np.isscalar(val[0]), \
+                "The eval_func should return a scalar or a list of scalars, " \
+                "not {}!".format(str(type(val[0])))
+
+        return val
+
+    def __getstate__(self):
+        """Magic method for pickle saving.
+
+        Returns:
+            dict: The saved dict for resuming.
+        """
+        return {'tuning_history': self.tuning_history}
+
+    def __setstate__(self, d):
+        """Magic method for pickle loading.
+
+        Args:
+            d (dict): The dict to load.
+        """
+        self.__dict__.update(d)
+
+    def stop(self, timeout, trials_count):
+        """Check whether the traversal should stop.
+
+        The traversal of the tuning space stops when either the accuracy goal is met or the timeout is reached.
+
+        Returns:
+            bool: True if the traversal should stop, otherwise False.
+        """
+        need_stop = False
+        if self.cfg.tuning.exit_policy.performance_only or \
+                self.objectives.compare(self.best_tune_result, self.baseline):
+            self.best_tune_result = self.last_tune_result
+            self.best_qmodel = self.last_qmodel
+            self.best_tuning_cfg = copy.deepcopy(self.last_tune_cfg)
+            logger.debug(f"*** Update the best qmodel with the result {self.best_tune_result}.")
+            if self.metric_met_point == 0:
+                self.metric_met_point = self.tuning_times
+
+        # track the model with the highest accuracy
+        if self.best_tune_result and self.last_tune_result:  # (acc, [perf])
+            if self.re_quant and self.objectives.accuracy_meets():
+                self.best_tune_result = self.last_tune_result
+                self.best_qmodel = self.last_qmodel
+                self.best_tuning_cfg = copy.deepcopy(self.last_tune_cfg)
+                logger.debug(f"*** Update the best qmodel with the result {self.best_tune_result}.")
+            else:
+                logger.debug("*** Accuracy does not meet the requirement; do not update the best qmodel.")
+
+        if self.last_tune_result:
+            last_tune = self.last_tune_result[0] if \
+                isinstance(self.last_tune_result[0], list) else [self.last_tune_result[0]]
+
+            for name, data in zip(self.metric_name, last_tune):
+                if len(self.tune_data[name]) == 1:
+                    self.tune_data[name].append(data)
+                else:
+                    self.tune_data[name][1] = data
+
+            if self.metric_weight and len(last_tune) > 1:
+                weighted_acc = np.mean(np.array(last_tune) * self.metric_weight)
+
+                if len(self.tune_data['Weighted accuracy']) == 1:
+                    self.tune_data['Weighted accuracy'].append(weighted_acc)
+                else:
+                    self.tune_data['Weighted accuracy'][1] = weighted_acc
+
+                last_tune = [weighted_acc]
+
+            last_tune_msg = '[Accuracy (int8|fp32):' + \
+                ''.join([' {:.4f}|{:.4f}'.format(last, base) for last, base in
+                         zip(last_tune, self.tune_data['baseline'])]) + \
+                ''.join([', {} (int8|fp32): {:.4f}|{:.4f}'.format(
+                    x, y, z) for x, y, z in zip(
+                        self.objectives.representation, self.last_tune_result[1], self.baseline[1])
+                    if x != 'Accuracy']) + ']'
+        else:  # pragma: no cover
+            last_tune_msg = 'n/a'
+            for name in self.tune_data.keys() - {'baseline'}:
+                if len(self.tune_data[name]) == 1:
+                    self.tune_data[name].append('n/a')
+                else:
+                    self.tune_data[name][1] = 'n/a'
+
+        if self.best_tune_result:
+            best_tune = self.best_tune_result[0] if isinstance(self.best_tune_result[0], list) \
+                else [self.best_tune_result[0]]
+
+            for name, data in zip(self.metric_name, best_tune):
+                if len(self.tune_data[name]) == 2:
+                    self.tune_data[name].append(data)
+                else:
+                    self.tune_data[name][2] = data
+
+            if self.metric_weight and len(best_tune) > 1:
+                weighted_acc = np.mean(np.array(best_tune) * self.metric_weight)
+
+                if len(self.tune_data['Weighted accuracy']) == 2:
+                    self.tune_data['Weighted accuracy'].append(weighted_acc)
+                else:  # pragma: no cover
+                    self.tune_data['Weighted accuracy'][2] = weighted_acc
+
+                best_tune = [weighted_acc]
+
+            best_tune_msg = '[Accuracy:' + ''.join([' {:.4f}'.format(best)
+                                                    for best in best_tune]) + \
+                ''.join([', {}: {:.4f}'.format(x, y)
+                         for x, y in zip(self.objectives.representation,
+                                         self.best_tune_result[1]) if x != 'Accuracy']) + ']'
+
+        else:
+            best_tune_msg = 'n/a'
+            for name in self.tune_data.keys() - {'baseline'}:
+                if len(self.tune_data[name]) == 2:
+                    self.tune_data[name].append('n/a')
+                else:
+                    self.tune_data[name][2] = 'n/a'
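+        # After the bookkeeping above, each self.tune_data[name] holds up to three
+        # entries, e.g. {'Accuracy': [baseline, last_tune, best_tune]}, which the
+        # Statistics table below renders as the Baseline / Tune N / Best columns.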
+
+        logger.info("Tune {} result is: {}, Best tune result is: {}".format(trials_count,
+                                                                            last_tune_msg,
+                                                                            best_tune_msg))
+        output_data = [[info_type,
+                        '{:.4f} '.format(self.tune_data[info_type][0]) if
+                        not isinstance(self.tune_data[info_type][0], str) else self.tune_data[info_type][0],
+                        '{:.4f} '.format(self.tune_data[info_type][1]) if
+                        not isinstance(self.tune_data[info_type][1], str) else self.tune_data[info_type][1],
+                        '{:.4f} '.format(self.tune_data[info_type][2]) if
+                        not isinstance(self.tune_data[info_type][2], str) else self.tune_data[info_type][2]]
+                       for info_type in self.tune_data.keys() if info_type != 'baseline']
+
+        output_data.extend([[obj,
+                             '{:.4f} '.format(self.baseline[1][i]) if self.baseline else 'n/a',
+                             '{:.4f} '.format(self.last_tune_result[1][i]) if self.last_tune_result else 'n/a',
+                             '{:.4f} '.format(self.best_tune_result[1][i]) if self.best_tune_result else 'n/a']
+                            for i, obj in enumerate(self.objectives.representation)])
+        self.tuning_result_data = output_data
+        Statistics(output_data,
+                   header='Tune Result Statistics',
+                   field_names=['Info Type', 'Baseline', 'Tune {} result'.format(trials_count),
+                                'Best tune result']).print_stat()
+
+        if self.cfg.tuning.exit_policy.performance_only:
+            need_stop = True
+        elif timeout == 0 and self.best_tune_result:
+            need_stop = True
+        elif trials_count >= self.cfg.tuning.exit_policy.max_trials:
+            need_stop = True
+        else:
+            need_stop = False
+
+        return need_stop
+
+    def _save(self):
+        """Save the current tuning state to a snapshot for resuming."""
+        logger.info("Save tuning history to {}.".format(self.history_path))
+        with fault_tolerant_file(self.history_path) as f:
+            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def _find_tuning_history(self, tune_cfg):
+        """Check whether the specified tune_cfg has been evaluated under the same yaml config.
+
+        Args:
+            tune_cfg (dict): The tune_cfg to check.
+
+        Returns:
+            tuning_history or None: The tuning history containing the evaluated tune_cfg.
+        """
+        for tuning_history in self.tuning_history:
+            # only check whether a tune_cfg was evaluated under the same yaml config, excluding
+            # some fields in the tuning section of the yaml, such as tensorboard, snapshot, resume.
+            if self._same_yaml(tuning_history['cfg'], self.cfg):
+                for history in tuning_history['history']:
+                    if history and history['tune_cfg'] == tune_cfg:
+                        return tuning_history
+
+        return None
+
+    def _find_history(self, tune_cfg):
+        """Check whether the specified tune_cfg has been evaluated under the same yaml config.
+
+        Returns:
+            history or None: The history containing the evaluated tune_cfg.
+        """
+        for tuning_history in self.tuning_history:
+            # only check whether a tune_cfg was evaluated under the same yaml config, excluding
+            # some fields in the tuning section of the yaml, such as tensorboard, snapshot, resume.
+            if self._same_yaml(tuning_history['cfg'], self.cfg):
+                for history in tuning_history['history']:
+                    if history and history['tune_cfg'] == tune_cfg:
+                        return history
+        return None
+
+    def _find_self_tuning_history(self):
+        """Find self history dict.
+
+        Returns:
+            history or None: The history for self.
+        """
+        for tuning_history in self.tuning_history:
+            # only check whether a tune_cfg was evaluated under the same yaml config, excluding
+            # some fields in the tuning section of the yaml, such as tensorboard, snapshot, resume.
+            if self._same_yaml(tuning_history['cfg'], self.cfg):
+                return tuning_history
+
+        return None
+
+    def _add_tuning_history(self, tune_cfg=None, tune_result=None, **kwargs):
+        """Add a tuning config to the tuning history.
+
+        Note: this record is added under the same yaml config.
+        """
+        found = False
+        d = {'tune_cfg': tune_cfg, 'tune_result': tune_result}
+        for tuning_history in self.tuning_history:
+            if self._same_yaml(tuning_history['cfg'], self.cfg):
+                d.update(kwargs)
+                tuning_history['history'].append(d)
+                tuning_history['last_tune_result'] = self.last_tune_result
+                tuning_history['best_tune_result'] = self.best_tune_result
+                tuning_history['cfg'] = self.cfg
+                found = True
+                break
+
+        if not found:
+            tuning_history = {}
+            tuning_history['version'] = __version__
+            tuning_history['cfg'] = self.cfg
+            tuning_history['baseline'] = self.baseline
+            tuning_history['last_tune_result'] = self.last_tune_result
+            tuning_history['best_tune_result'] = self.best_tune_result
+            tuning_history['history'] = []
+            if tune_cfg and tune_result:
+                d.update(kwargs)
+                tuning_history['history'].append(d)
+            self.tuning_history.append(tuning_history)
+
+        self._save()
+
+    def _collect_ops_by_quant_mode(self, tune_cfg, quant_mode):
+        ops_lst = []
+        for op_info, op_config in tune_cfg.items():
+            if isinstance(op_config, OpTuningConfig) and quant_mode in op_config.op_quant_mode:
+                ops_lst.append(op_info)
+        return ops_lst
+
+    def _diagnosis(self):
+        import logging
+        logger = logging.getLogger("neural_compressor")
+        iteration_list = self.cfg.tuning.diagnosis.iteration_list
+        inspect_type = self.cfg.tuning.diagnosis.inspect_type
+        save_to_disk = self.cfg.tuning.diagnosis.save_to_disk
+        save_path = self.cfg.tuning.diagnosis.save_path
+        inspect_node_lst, updated_cfg = self.adaptor.diagnosis_helper(self._fp32_model,
+                                                                      self.last_qmodel,
+                                                                      self.tune_cfg,
+                                                                      save_path=save_path)
+        op_list = self.cfg.tuning.diagnosis.op_list
+        if not op_list:
+            op_list = list(inspect_node_lst)
+        else:
+            op_list = list(set(op_list).intersection(inspect_node_lst))
+
+        logger.debug(f'*** Start to inspect tensor: {op_list} in the fp32 model.')
+        self.adaptor.inspect_tensor(self._fp32_model,
+                                    dataloader=self.calib_dataloader,
+                                    op_list=op_list,
+                                    iteration_list=iteration_list,
+                                    inspect_type=inspect_type,
+                                    save_to_disk=save_to_disk,
+                                    save_path=save_path + '/fp32/',
+                                    quantization_cfg=updated_cfg)
+
+        logger.debug(f'*** Start to inspect tensor: {op_list} in the quantized model.')
+        self.adaptor.inspect_tensor(self.last_qmodel,
+                                    dataloader=self.calib_dataloader,
+                                    op_list=op_list,
+                                    iteration_list=iteration_list,
+                                    inspect_type=inspect_type,
+                                    save_to_disk=save_to_disk,
+                                    save_path=save_path + '/quan/',
+                                    quantization_cfg=updated_cfg)
diff --git a/neural_compressor/experimental/strategy/utils/__init__.py b/neural_compressor/experimental/strategy/utils/__init__.py
new file mode 100644
index 00000000000..1b730c7ded2
--- /dev/null
+++ b/neural_compressor/experimental/strategy/utils/__init__.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Intel Neural Compressor Strategy Utils."""
+
+from .tuning_sampler import TuningSampler, OpWiseTuningSampler, OpTypeWiseTuningSampler, FallbackTuningSampler
+from .tuning_structs import OpTuningConfig
+from .tuning_space import TuningItem, TuningSpace
diff --git a/neural_compressor/experimental/strategy/utils/constant.py b/neural_compressor/experimental/strategy/utils/constant.py
new file mode 100644
index 00000000000..9cbeaa00859
--- /dev/null
+++ b/neural_compressor/experimental/strategy/utils/constant.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Strategy constants."""
+
+PRECISION_SET = {'bf16', 'fp16', 'fp32'}
+QUANT_MODE_SET = {'static', 'dynamic'}
+QUNAT_BIT_SET = {'int8', 'uint8', 'int4', 'uint4'}
+
+TUNING_ITEMS_LST = [('activation', 'scheme'), ('activation', 'algorithm'), ('activation', 'granularity'),
+                    ('weight', 'scheme'), ('weight', 'algorithm'), ('weight', 'granularity'), 'sampling_size']
+
+PRECISION_SET_V2_0 = {'fp32', 'bf16'}
+
+auto_query_order = ['static', 'dynamic', 'bf16', 'fp16', 'fp32']
+static_query_order = ['static', 'bf16', 'fp16', 'fp32']
+dynamic_query_order = ['dynamic', 'bf16', 'fp16', 'fp32']
+
+
+FALLBACK_RECIPES_SET = {'first_conv_or_matmul_quantization', 'last_conv_or_matmul_quantization',
+                        'pre_post_process_quantization'}
\ No newline at end of file
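A hedged usage sketch for the constants above, assuming the module path introduced by this diff:

    from neural_compressor.experimental.strategy.utils.constant import (
        FALLBACK_RECIPES_SET, auto_query_order)

    assert auto_query_order[0] == 'static'  # earlier modes take precedence
    assert 'pre_post_process_quantization' in FALLBACK_RECIPES_SET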
diff --git a/neural_compressor/experimental/strategy/utils/tuning_sampler.py b/neural_compressor/experimental/strategy/utils/tuning_sampler.py
new file mode 100644
index 00000000000..63984f600dd
--- /dev/null
+++ b/neural_compressor/experimental/strategy/utils/tuning_sampler.py
@@ -0,0 +1,463 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tuning sampler."""
+
+from itertools import product
+import copy
+from collections import deque, OrderedDict, defaultdict
+from typing import List, Dict, Any
+from .tuning_space import TuningSpace, pattern_to_internal, pattern_to_path, quant_mode_from_pattern
+from .tuning_structs import OpTuningConfig
+from ....utils import logger
+
+TUNING_ITEM_PRIORITY = [('activation', 'scheme'), ('activation', 'algorithm'), ('activation', 'granularity'),
+                        ('activation', 'compute_dtype'), ('weight', 'scheme'), ('weight', 'algorithm'),
+                        ('weight', 'granularity')]
+
+
+class TuningSamplerRegistry:
+    """Class decorator used to register all TuningSampler subclasses."""
+
+    sampler_dict = {}
+
+    @classmethod
+    def register(cls, name):
+        """Register a new tuning sampler.
+
+        Args:
+            name: the name of the new tuning sampler.
+        """
+        def decorator(sampler):
+            assert name not in cls.sampler_dict, "Cannot have two samplers with the same name."
+            cls.sampler_dict[name] = sampler
+            return sampler
+        return decorator
+
+
+class TuningOrder:
+    """Not displayed in API Docs."""
+
+    def __init__(self):
+        """For future use."""
+        pass
+
+
+class TuningSampler:
+    """Not displayed in API Docs.
+
+    Basic class of tuning sampler.
+    """
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_order_lst: List[TuningOrder],
+                 initial_op_tuning_cfg: Dict,
+                 kwargs: Dict = {}):
+        """Init tuning sampler.
+
+        Args:
+            tuning_space: The tuning space.
+            tuning_order_lst: The traverse orders.
+            initial_op_tuning_cfg: The initialized tuning config.
+            kwargs: other args.
+        """
+        self.tuning_space = tuning_space
+        self.tuning_order_lst = tuning_order_lst
+        self.initial_op_tuning_cfg = initial_op_tuning_cfg
+        self.queue = deque()
+        # (op_name, op_type): [full_path1, full_path2, ...]
+        self.op_complete_path = {}
+
+    def __iter__(self, tune_cfg=None):
+        """Interface for generating the next tuning config."""
+        pass
+
+    def _set_dtype(self, op_name_type, config_args):
+        has_weight = op_name_type in self.tuning_space.ops_attr['weight']
+        path = self.op_complete_path[op_name_type].get('activation', None)
+        config_args['activation_dtype'] = self.tuning_space.ops_data_type[op_name_type][path]
+        if has_weight:
+            path = self.op_complete_path[op_name_type].get('weight', None)
+            config_args['weight_dtype'] = self.tuning_space.ops_data_type[op_name_type][path]
+
+
+class ModelWiseTuningSampler(TuningSampler):
+    """Not displayed in API Docs."""
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_items_priority: List[str],
+                 tuning_order_lst: List[TuningOrder],
+                 op_dtype_dict: Dict[tuple, str],
+                 initial_op_tuning_cfg: Dict[tuple, OpTuningConfig]):
+        """Model-wise tuning sampler.
+
+        Step 1: create a default tuning config for each op.
+        Step 2: collect all tuning items and options, and build the model-wise traverse order.
+        Step 3: yield the tuning items with options one by one; for each op, query whether the
+            tuning item and the specific option exist, and use the op's default tuning config
+            if they do not.
+
+        Args:
+            tuning_space: Tuning space.
+            tuning_items_priority: The priority to traverse the tuning items.
+            tuning_order_lst: The tuning orders.
+            op_dtype_dict: The (op name, op type) and its target data type.
+            initial_op_tuning_cfg: The initial tuning config.
+        """
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)
+
+        self.op_dtype_dict = op_dtype_dict
+        self.tuning_space = tuning_space
+        self.default_op_config = {}
+        tuning_items = defaultdict(set)  # item name: options
+        for op_name_type, quant_mode in op_dtype_dict.items():
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, quant_mode)
+            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+            # step 1, set the default config for each op
+            self.default_op_config[op_name_type] = tuning_space.get_default_config(op_name_type, quant_mode)
+            if quant_mode[0] == 'precision':
+                continue
+            mode_items = copy.deepcopy(full_path)  # TODO refactor the initialization method
+            # step 2, collect all tuning items and their options
+            for att in mode_items:
+                if att not in full_path:
+                    continue
+                quant_mode_item = self.tuning_space.query_quant_mode_item_by_full_path(op_name_type, full_path[att])
+                for tuning_item in quant_mode_item.options:
+                    tuning_items[tuning_item.name] = tuning_items[tuning_item.name].union(tuning_item.options)
+        self.tuning_items = tuning_items
+
+    def __iter__(self):
+        """Yield the next tuning config.
+
+        Yields:
+            The next tuning config.
+        """
+        keys = self.tuning_items.keys()
+        for vals in product(*self.tuning_items.values()):
+            # traverse all possible combinations at the model-wise level
+            tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+            for op_name_type, quant_mode in self.op_dtype_dict.items():
+                if quant_mode[0] == 'precision':
+                    continue
+                all_exist_flag = True
+                for method_name, method_val in zip(keys, vals):
+                    full_path = self.op_complete_path[op_name_type]
+                    if method_name[0] not in full_path:
+                        continue
+                    if not self.tuning_space.query_item_option(op_name_type,
+                                                               full_path[method_name[0]],
+                                                               method_name, method_val):
+                        all_exist_flag = False
+                        tune_cfg[op_name_type] = self.default_op_config[op_name_type]
+                        break
+                if all_exist_flag:
+                    config_args = dict(zip(keys, vals))
+                    self._set_dtype(op_name_type, config_args)
+                    internal_pattern = pattern_to_internal(quant_mode)
+                    quant_mode = quant_mode_from_pattern(internal_pattern)
+                    tune_cfg[op_name_type] = OpTuningConfig(op_name_type[0],
+                                                            op_name_type[1],
+                                                            quant_mode,
+                                                            self.tuning_space,
+                                                            kwargs=config_args)
+            yield tune_cfg
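+    # Model-wise traversal enumerates the Cartesian product of the collected
+    # options; e.g. with (illustrative values)
+    #     tuning_items = {('activation', 'scheme'): {'sym', 'asym'},
+    #                     ('activation', 'algorithm'): {'minmax', 'kl'}}
+    # product(*tuning_items.values()) yields 4 candidate combinations.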
+
+
+class OpTypeWiseTuningSampler(TuningSampler):
+    """Not displayed in API Docs."""
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_items_priority: List[str],
+                 tuning_order_lst: List[TuningOrder],
+                 op_dtype_dict: Dict[tuple, str],
+                 initial_op_tuning_cfg: Dict[tuple, OpTuningConfig]):
+        """Op type wise tuning sampler.
+
+        Args:
+            tuning_space: Tuning space.
+            tuning_items_priority: The priority to traverse the tuning items.
+            tuning_order_lst: The tuning orders.
+            op_dtype_dict: The (op name, op type) and its target data type.
+            initial_op_tuning_cfg: The initial tuning config.
+        """
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)
+        tuning_items_priority = TUNING_ITEM_PRIORITY
+        # (op_type, quant_mode): {tuning_item_name: [option1, option2]}
+        # e.g. {('activation', 'scheme'): ['sym', 'sym'], ('activation', 'algorithm'): ['minmax', 'kl', 'minmax', 'kl']}
+
+        self.optype_quant_mode_option = {}
+        self.optype_quant_mode_items_name = defaultdict(list)
+        self.op_type_quant_mode_wise_combination = {}
+        self.op_dtype_dict = op_dtype_dict
+        self.default_op_config = {}
+
+        for op_name_type, quant_mode in op_dtype_dict.items():
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, quant_mode)
+            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+            self.default_op_config[op_name_type] = self.tuning_space.get_default_config(op_name_type, quant_mode)
+            op_name, op_type = op_name_type
+            if quant_mode[0] == 'precision':
+                continue
+            mode_items = copy.deepcopy(full_path)  # TODO refactor the initialization method
+            op_type_quant_mode = (op_type, quant_mode)
+            filtered_tuning_items = []
+            for item_name in tuning_items_priority:
+                att, method_name = item_name
+                if att not in mode_items:
+                    continue
+                quant_mode_item = self.tuning_space.query_quant_mode_item_by_full_path(op_name_type, full_path[att])
+                item = quant_mode_item.get_option_by_name(item_name)
+                if item:
+                    if op_type_quant_mode not in self.optype_quant_mode_option:
+                        self.optype_quant_mode_option[op_type_quant_mode] = defaultdict(list)
+                    self.optype_quant_mode_option[op_type_quant_mode][item_name] += item.options
+                    filtered_tuning_items.append(item)
+            self.optype_quant_mode_items_name[op_type_quant_mode] = filtered_tuning_items
+
+        for op_type_quant_mode, val in self.optype_quant_mode_option.items():
+            options_lst = []
+            # remove the duplicate options
+            for _, item_options in val.items():
+                seen = set()
+                filter_options = [option for option in item_options if not (option in seen or seen.add(option))]
+                options_lst.append(filter_options)
+            op_type_quant_mode_vals = product(*options_lst)
+            self.op_type_quant_mode_wise_combination[op_type_quant_mode] = op_type_quant_mode_vals
+
+    def __iter__(self):
+        """Yield the next tuning config.
+
+        Yields:
+            The next tuning config.
+        """
+        new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+        for options_lst in product(*self.op_type_quant_mode_wise_combination.values()):
+            for index, op_type_quant_mode in enumerate(self.op_type_quant_mode_wise_combination.keys()):
+                for op_name_type, quant_mode in self.op_dtype_dict.items():
+                    if op_name_type[1] == op_type_quant_mode[0] and quant_mode == op_type_quant_mode[1]:
+                        op_tuning_items = [item.name for item in
+                                           self.optype_quant_mode_items_name[op_type_quant_mode]]
+                        op_tuning_item_vals = options_lst[index]
+                        all_exist_flag = True
+                        for method_name, method_val in zip(op_tuning_items, op_tuning_item_vals):
+                            full_path = self.op_complete_path[op_name_type]
+                            if not self.tuning_space.query_item_option(op_name_type,
+                                                                       full_path[method_name[0]],
+                                                                       method_name,
+                                                                       method_val):
+                                all_exist_flag = False
+                                op_tuning_config = self.default_op_config[op_name_type]
+                                break
+                        if all_exist_flag:
+                            config_args = dict(zip(op_tuning_items, op_tuning_item_vals))
+                            self._set_dtype(op_name_type, config_args)
+                            internal_pattern = pattern_to_internal(quant_mode)
+                            quant_mode = quant_mode_from_pattern(internal_pattern)
+                            op_tuning_config = OpTuningConfig(op_name_type[0],
+                                                              op_name_type[1],
+                                                              quant_mode,
+                                                              self.tuning_space,
+                                                              kwargs=config_args)
+                        new_tune_cfg.update({op_name_type: op_tuning_config})
+            yield new_tune_cfg
+
+
+class OpWiseTuningSampler(TuningSampler):
+    """Not displayed in API Docs."""
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_items_priority: List[str],
+                 tuning_order_lst: List[TuningOrder],
+                 op_dtype_dict: Dict[tuple, str],
+                 initial_op_tuning_cfg: Dict):
+        """Op wise tuning config sampler.
+
+        Args:
+            tuning_space: Tuning space.
+            tuning_items_priority: The priority to traverse the tuning items.
+            tuning_order_lst: The tuning orders.
+            op_dtype_dict: The (op name, op type) and its target data type.
+            initial_op_tuning_cfg: The initial tuning config.
+        """
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)
+        tuning_items_priority = TUNING_ITEM_PRIORITY
+        # query the combinations of tuning items according to the tuning items priority
+        self.op_dtype_dict = op_dtype_dict
+        self.op_options_combination = OrderedDict()
+        self.op_tuning_items = {}
+        for op_name_type, op_quant_mode in op_dtype_dict.items():
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, op_quant_mode)
+            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+            mode_items = copy.deepcopy(full_path)
+            internal_pattern = pattern_to_internal(op_quant_mode)
+            op_quant_mode = quant_mode_from_pattern(internal_pattern)
+            if internal_pattern[0] == 'precision':
+                continue
+            filtered_tuning_items = []
+            for item_name in tuning_items_priority:
+                att, method_name = item_name
+                if att not in mode_items:
+                    continue
+                quant_mode_item = self.tuning_space.query_quant_mode_item_by_full_path(op_name_type, full_path[att])
+                item = quant_mode_item.get_option_by_name(item_name)
+                if item:
+                    filtered_tuning_items.append(item)
+            self.op_tuning_items[op_name_type] = filtered_tuning_items
+            op_options_lst = product(*[item.options for item in filtered_tuning_items])
+            self.op_options_combination[op_name_type] = op_options_lst
+
+    def __iter__(self):
+        """Yield the next tuning config.
+
+        Yields:
+            The next tuning config.
+        """
+        new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+        for op_options_lst in product(*self.op_options_combination.values()):
+            for index, op_name_type in enumerate(self.op_options_combination.keys()):
+                op_quant_mode = self.op_dtype_dict[op_name_type]
+                op_tuning_items = [item.name for item in self.op_tuning_items[op_name_type]]
+                op_tuning_item_vals = op_options_lst[index]
+                config_args = dict(zip(op_tuning_items, op_tuning_item_vals))
+                self._set_dtype(op_name_type, config_args)
+                internal_pattern = pattern_to_internal(op_quant_mode)
+                quant_mode = quant_mode_from_pattern(internal_pattern)
+                op_tuning_config = OpTuningConfig(op_name_type[0], op_name_type[1],
+                                                  quant_mode, self.tuning_space,
+                                                  kwargs=config_args)
+                new_tune_cfg.update({op_name_type: op_tuning_config})
+            yield new_tune_cfg
+
+    def get_opwise_candidate(self):
+        """Collect all op-wise settings.
+
+        Returns:
+            op_wise_configs: all op-wise settings.
+        """
+        op_wise_configs = OrderedDict()
+        for op_name_type, op_quant_mode in self.op_dtype_dict.items():
+            # For static/dynamic/fp32/bf16
+            internal_pattern = pattern_to_internal(op_quant_mode)
+            quant_mode = quant_mode_from_pattern(internal_pattern)
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, op_quant_mode)
+            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+            op_wise_configs[op_name_type] = []
+            # For precision
+            if internal_pattern[0] == 'precision':
+                config_args = {}
+                self._set_dtype(op_name_type, config_args)
+                op_tuning_config = OpTuningConfig(op_name_type[0], op_name_type[1],
+                                                  quant_mode, self.tuning_space,
+                                                  kwargs=config_args)
+                op_wise_configs[op_name_type].append(op_tuning_config)
+                continue
+            # For quantization
+            op_tuning_items = [item.name for item in self.op_tuning_items.get(op_name_type, [])]
+            op_options = self.op_options_combination[op_name_type]
+
+            for op_tuning_item_vals in op_options:
+                config_args = dict(zip(op_tuning_items, op_tuning_item_vals))
+                self._set_dtype(op_name_type, config_args)
+                op_tuning_config = OpTuningConfig(op_name_type[0], op_name_type[1],
+                                                  quant_mode, self.tuning_space,
+                                                  kwargs=config_args)
+                op_wise_configs[op_name_type].append(op_tuning_config)
+        return op_wise_configs
+
+
+class FallbackTuningSampler(TuningSampler):
+    """Not displayed in API Docs."""
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_order_lst: List[TuningOrder],
+                 initial_op_tuning_cfg: Dict[tuple, Any],
+                 op_dtypes: Dict[str, str],
+                 accumulate: bool,
+                 skip_first: bool = True
+                 ):
+        """Sampler that generates the tuning configs for the fallback stage.
+
+        Args:
+            tuning_space: Tuning space.
+            tuning_order_lst: The tuning orders.
+            initial_op_tuning_cfg: The initial tuning config.
+            op_dtypes: The (op name, op type) and its target data type.
+            accumulate: Whether fallbacks accumulate across yields.
+            skip_first: Whether to skip falling back the first op. Defaults to True.
+        """
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg)
+        self.op_dtypes = op_dtypes
+        self.accumulate = accumulate
+        self.skip_first = skip_first
+
+    def __iter__(self):
+        """Yield the next tuning config.
+
+        Yields:
+            The next tuning config.
+        """
+        new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+        skip_first = self.skip_first
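+        # Illustrative behaviour (hypothetical ops): with op_dtypes mapping opA
+        # and opB to 'fp32' and accumulate=False, the sampler yields one config
+        # with only opA fallen back, then one with only opB. With accumulate=True
+        # (and skip_first), the first op is skipped and later fallbacks pile up
+        # across yields instead of being reset each iteration.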
+        for op_name_type, target_dtype in self.op_dtypes.items():
+            # Only fallback to lower precision is supported.
+            if not self.accumulate:
+                new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg)
+            full_path = self.tuning_space.get_op_default_path_by_pattern(op_name_type, target_dtype)
+            self.op_complete_path[op_name_type] = copy.deepcopy(full_path)
+            config_args = {}
+            self._set_dtype(op_name_type, config_args)
+            internal_pattern = pattern_to_internal(target_dtype)
+            quant_mode = quant_mode_from_pattern(internal_pattern)
+            new_op_config = OpTuningConfig(op_name_type[0], op_name_type[1],
+                                           quant_mode, self.tuning_space,
+                                           kwargs=config_args)
+
+            new_tune_cfg.update({op_name_type: new_op_config})
+            if self.accumulate and skip_first:  # skip the first op
+                skip_first = False
+                continue
+            logger.debug(f"Fallback {op_name_type} to {target_dtype}.")
+            yield new_tune_cfg
+
+
+@TuningSamplerRegistry.register("smooth_quant")
+class SmoothQuantSampler(TuningSampler):
+    """Sampler for the hyperparameter tuning of smooth quantization."""
+
+    def __init__(self,
+                 tuning_space: TuningSpace,
+                 tuning_order_lst: List[TuningOrder],
+                 initial_op_tuning_cfg: Dict,
+                 kwargs: Dict = {}):
+        """Initialize the sampler."""
+        super().__init__(tuning_space, tuning_order_lst, initial_op_tuning_cfg, kwargs)
+        # TODO use the alpha list specified by the user
+        self._kwargs = kwargs
+        self._alpha_lst = [0.5]
+        if kwargs.get('smooth_quant_args', {}):
+            self._alpha_lst = kwargs['smooth_quant_args'].get('alpha_lst', [0.5])
+
+    def __iter__(self, tune_cfg=None) -> OpTuningConfig:
+        """Yield the next tuning config with an updated alpha.
+
+        Args:
+            tune_cfg: tuning config. Defaults to None.
+        """
+        for alpha in self._alpha_lst:
+            new_tune_cfg = copy.deepcopy(self.initial_op_tuning_cfg) if not tune_cfg else copy.deepcopy(tune_cfg)
+            sq_args = {'smooth_quant': True, 'smooth_quant_args': {'alpha': alpha}}
+            if 'recipe_cfgs' not in new_tune_cfg:
+                new_tune_cfg['recipe_cfgs'] = sq_args
+            else:
+                new_tune_cfg['recipe_cfgs'].update(sq_args)
+            yield new_tune_cfg
\ No newline at end of file
diff --git a/neural_compressor/experimental/strategy/utils/tuning_space.py b/neural_compressor/experimental/strategy/utils/tuning_space.py
new file mode 100644
index 00000000000..6ea1998dbb8
--- /dev/null
+++ b/neural_compressor/experimental/strategy/utils/tuning_space.py
@@ -0,0 +1,728 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tuning space."""
+
+from collections import defaultdict, OrderedDict
+import os
+import re
+from typing import Dict, Tuple
+from copy import deepcopy
+from ....utils import logger
+from .utility import OrderedDefaultDict
+from .tuning_structs import OpTuningConfig
+
+from .constant import TUNING_ITEMS_LST
+
+
+class TuningItem:
+    """Not displayed in API Docs."""
+
+    def __init__(self, name, options=[], item_type=None):
+        """Init the tuning item.
+
+        Args:
+            name: tuning item name.
+            options: The options. Defaults to [].
+            item_type: The item type. Defaults to None.
+ """ + self.name = name + self._options = options + self.item_type = item_type + + @property + def options(self): + """Return all options. + + Returns: + All options. + """ + return self._options + + def get_options_name(self): + """Return the name list of the options.""" + return [o.name for o in self.options] + + def append(self, option): + """Append option. + + Args: + option: The option to add. + """ + self._options.append(option) + + def remove(self, option): + """Remove option. + + Args: + option: The option to remove. + """ + if option in self._options: + self._options.remove(option) + + def get_option_by_name(self, option_name): + """Get the option item by name. + + Args: + option_name: option name. + + Returns: + option: the queried option. + """ + for option in self.options: + if isinstance(option, TuningItem) and option.name == option_name: + return option + return None + + def get_details(self, depth=0): + """Get the tuning item and its options recursively. + + Args: + depth: recursion depth. Defaults to 0. + + Returns: + The tuning item and its options as a string. + """ + details = ['\t' * depth + f"{self.name}, {self.item_type}"] + for option in self.options: + if isinstance(option, int) or isinstance(option, str): + details.append("\t" * depth + str(option)) + else: + details.append(option.get_details(depth + 1)) + return "\n".join(details) + + +class TuningSpace: + """Not displayed in API Docs. + + 1) capability -> internal format -> merge -> tuning space (tree) + + """ + + def __init__(self, capability, conf, framework=None): + """Init the tuning space. + + Args: + capability: framework capability. + conf: user configuration + framework: framework name. Defaults to None. + """ + self.capability = capability + self.conf = conf + self.root_item = TuningItem(name='root', options=[], item_type='root') + self.quant_mode_wise_items = defaultdict(list) # quant_mode/precision_name: {(op_name, op_type),...} + self.op_type_wise_items = defaultdict(list) # op_type: {(op_name, op_type), ...} + self.framework = framework + self.ops_dtype = defaultdict(OrderedDict) + usr_cfg = conf.usr_cfg if conf else None + self.op_items = {} + # {(op_name, op_type): {(path): data type}} + self.ops_data_type = OrderedDefaultDict() + self.ops_attr = {'activation': set(), 'weight': set()} + # {(op_name, op_type): {path1, path2, ...} + self.ops_path_set = defaultdict(set) + + self._create_tuning_space(capability, usr_cfg) + + def _parse_capability(self, capability: Dict) -> None: + """Parse the capability and construct the tuning space(a tree). + + Args: + capability: merged framework capability. 
+ """ + calib = TuningItem(name='calib_sampling_size', + options=capability['calib']['calib_sampling_size'], + item_type='calib_sampling_size') + self.root_item.append(calib) + def _parse(cap, root, path, op_name_type): + if isinstance(cap, dict): + for key, val in cap.items(): + if isinstance(val, dict): + if len(path) > 1 and path[-2] == 'precision': + self.ops_path_set[op_name_type].add(tuple(path + [key])) + tuning_item = TuningItem(name=key, options=[], item_type=key) + root.append(tuning_item) + _parse(val, tuning_item, path + [key], op_name_type) + elif isinstance(val, list): + new_key = ('activation', key) if 'activation' in path else ('weight', key) + tuning_item = TuningItem(name=new_key, options=val, item_type='method') + self.ops_path_set[op_name_type].add(tuple(path)) + root.append(tuning_item) + else: + return + + for op_name_type, op_cap in capability['op'].items(): + op_name, op_type = op_name_type + op_item = TuningItem(name=op_name_type, options=[], item_type='op') + self.op_type_wise_items[op_type].append(op_item) + self.root_item.append(op_item) + self.op_items[op_name_type] = op_item + _parse(op_cap, op_item, [], op_name_type) + for q_option in op_item.options: + if q_option and q_option.name == 'precision': + acc_item = q_option.get_option_by_name('activation') + if acc_item and acc_item.options: + for dtype_item in acc_item.options: + self.quant_mode_wise_items[dtype_item.name].append(op_item) + else: + self.quant_mode_wise_items[q_option.name].append(op_item) + + def _create_tuning_item(self, tuning_items: Dict, attr_name: str, quant_mode_item: TuningItem): + for tuning_item_name, options in tuning_items.items(): + if tuning_item_name not in ['dtype', 'quant_mode']: + name = (attr_name, tuning_item_name) + tuning_item = TuningItem(name=name, options=options, item_type=name) + quant_mode_item.append(tuning_item) + + def _merge_op_cfg(self, cur_op_cap, op_user_cfg, fw_op_cap): + """Merge the op cfg with user cfg. + + op_user_cfg:{ + 'activation':{ + 'dtype': ['fp32'] + }, + 'weight':{ + 'dtype': ['fp32'] + } + } + + Step1. merge dtype, get the intersection between fw_op_cap and op_user_cfg. + Step2. merge method options. + + # if dtype and type intersection with precision set -> only keep the intersection precision + # and remove the quantization. + # else(no dtype, or no intersection) -> merge the method + + Args: + cur_op_cap: current capability. + op_user_cfg: The user capability. + fw_op_cap: The fwk capability(baseline). + + Returns: + Return the merged capability. + """ + from .utility import extract_data_type, reverted_data_type + fw_op_cap = deepcopy(fw_op_cap) + new_op_cap = deepcopy(cur_op_cap) + for att in ['activation', 'weight']: + if op_user_cfg.get(att, None) is not None: + user_dtype_lst = op_user_cfg[att]['dtype'] if op_user_cfg[att]['dtype'] is not None else [] + # Merge the precision part. + fwk_att_precision_cap = fw_op_cap['precision'].get(att, {}) + fwk_precision_set = set(fwk_att_precision_cap.keys()) + # The intersection of user cfg and fwk capability. 
+ valid_precision_set = set(fwk_precision_set).intersection(set(user_dtype_lst)) + if len(valid_precision_set) != 0: + new_op_cap = dict(filter(lambda item: item[0] == 'precision', new_op_cap.items())) + new_op_cap['precision'][att] = dict(filter(lambda item: item[0] in valid_precision_set,\ + fw_op_cap['precision'][att].items())) + else: + # Filter the valid options for tuning item + for quant_mode in fw_op_cap: + if quant_mode not in new_op_cap: + new_op_cap[quant_mode] = deepcopy(fw_op_cap[quant_mode]) + if quant_mode == 'precision': continue + for data_type in new_op_cap[quant_mode][att]: + for signed_flag in new_op_cap[quant_mode][att][data_type]: + cur_items = new_op_cap[quant_mode][att][data_type][signed_flag] + fwk_items = fw_op_cap[quant_mode][att][data_type][signed_flag] + for method_name, method_options in op_user_cfg[att].items(): + if method_name not in ['dtype', 'quant_mode'] and method_options: + # filter the method options + options_intersection = set(fwk_items[method_name]\ + ).intersection(set(method_options)) + # merge with fwk, if intersection -> use intersection + if len(options_intersection) > 0: + cur_items[method_name] = [option for option in fwk_items[method_name] if\ + option in options_intersection] + return new_op_cap + + def _merge_optype_wise_cfg(self, cap: Dict, optype_wise_usr_cfg: Dict, fw_cap: Dict): + for op_type, op_user_cfg in optype_wise_usr_cfg.items(): + op_lst = [op_name_type for op_name_type in cap['op'] if op_name_type[1] == op_type] + for op_name_type in op_lst: + cap['op'][op_name_type] = self._merge_op_cfg(cap['op'][op_name_type], + op_user_cfg, + fw_cap['op'][op_name_type]) + + def _merge_model_wise_cfg(self, cap: Dict, model_wise_usr_cfg: Dict, fw_cap: Dict): + for op_name_type in cap['op'].keys(): + cap['op'][op_name_type] = self._merge_op_cfg(cap['op'][op_name_type], + model_wise_usr_cfg, + fw_cap['op'][op_name_type]) + + def _merge_op_wise_cfg(self, cap: Dict, op_wise_usr_cfg: Dict, fw_cap: Dict): + op_name_types = {key[0]: key for key in cap['op'].keys()} + for op_name_pattern, op_user_cfg in op_wise_usr_cfg.items(): + op_name_pattern = re.compile(op_name_pattern) + for op_name in op_name_types: + if op_name_pattern.fullmatch(op_name): + op_name_type = op_name_types[op_name] + cap['op'][op_name_type] = self._merge_op_cfg(cap['op'][op_name_type], + op_user_cfg, + fw_cap['op'][op_name_type]) + + def _merge_with_user_cfg(self, capability: Dict, user_cfg: Dict): + """Merge the capability with user config. + + Merge the capability queried from the adaptor with user config in the order of + model-wise, optype-wise, and op-wise if needed. + The optype-wise user config will override the model-wise user config for their + intersection parts, the same as the op-wise and optype-wise. + + Here is an example: + capability:{ + ('op1','type1'): { + 'item1': [item1_option1, item1_option2, item1_option3], + 'item2': [item2_option1, item2_option2, item2_option3], + } + ('op2','type1'): { + 'item1': [item1_option1, item1_option2, item1_option3], + 'item2': [item2_option1, item2_option2, item2_option3], + } + ('op3','type2'): { + 'item1': [item1_option1, item1_option2], + 'item2': [item2_option1, item2_option2], + } + ('op4','type2'): { + 'item1': [item1_option1, item1_option2], + 'item2': [item2_option1, item2_option2], + } + } + + user_config{ + model-wise:{ + 'item1': [item1_option1] + } + optype-wise: { + 'type1': { + 'item1': [item1_option1, item1_option2] + }} + op-wise: { + ('op3','type2'): { + 'item2': [item2_option1] + }} + } + + # step1. 
merged with model-wise
+        capability:{
+            ('op1','type1'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op2','type1'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op3','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2],
+            }
+            ('op4','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2],
+            }
+        }
+
+        # step2. merged with optype-wise
+        capability:{
+            ('op1','type1'): {
+                'item1': [item1_option1, item1_option2],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op2','type1'): {
+                'item1': [item1_option1, item1_option2],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op3','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2],
+            }
+            ('op4','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2],
+            }
+        }
+
+        # step3. merged with op-wise
+        capability:{
+            ('op1','type1'): {
+                'item1': [item1_option1, item1_option2],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op2','type1'): {
+                'item1': [item1_option1, item1_option2],
+                'item2': [item2_option1, item2_option2, item2_option3],
+            }
+            ('op3','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1],
+            }
+            ('op4','type2'): {
+                'item1': [item1_option1],
+                'item2': [item2_option1, item2_option2],
+            }
+        }
+        :param capability:
+        :param user_cfg:
+        :return:
+        """
+        fw_capability = deepcopy(capability)
+        if user_cfg['model_wise'] is not None:
+            self._merge_model_wise_cfg(capability, user_cfg['model_wise'], fw_capability)
+        if user_cfg['optype_wise'] is not None:
+            self._merge_optype_wise_cfg(capability, user_cfg['optype_wise'], fw_capability)
+        if user_cfg['op_wise'] is not None:
+            self._merge_op_wise_cfg(capability, user_cfg['op_wise'], fw_capability)
+
+    def _parse_cap_helper(self, cap):
+        """Convert the cap to the internal format.
+
+        Parsed result:
+            (op_name, op_type):
+                {
+                'static':{
+                    'act':{
+                        'int8':{
+                            'signed':{  # (op_name, op_type): ('static', (('int8', 'signed'),(...)))
+                                'dtype': 'int8',
+                                'scheme': ['sym'],
+                                'algorithm': ['minmax', 'kl'],
+                                'granularity': ['per_channel','per_tensor'],
+                            }
+                        }
+                        'int4':{
+                            ...
+                        }
+                    },
+                    'weight':{
+                        'int8':{
+                            ...
+                        }
+                        'int4':{
+                            'signed':{
+                                'dtype': 'int4',
+                                'scheme': ['asym'],
+                                ...
+                            }
+                        }
+                    }
+                },
+                'dynamic':{
+                    ...
+                }
+                'precision':{
+                    'act':{
+                        'fp32':{}
+                        'bf16':{}
+                    },
+                    'weight':{
+                        'fp32':{
+                            'dtype': 'fp32',
+                        },
+                        'bf16':{
+                            'dtype': 'bf16',
+                        },
+                    }
+
+                }
+            }
+        """
+        from .utility import OrderedDefaultDict, extract_data_type
+        cap = deepcopy(cap)
+        parsed_cap = OrderedDict()  # {(op_name, op_type): parsed_op_cap}
+        for op_name_type, op_cap_lst in cap.items():
+            parsed_op_cap = OrderedDefaultDict()  # {ptq_type/precision, {}}
+            parsed_op_cap['precision'] = OrderedDefaultDict()
+            # WA for some ops that have an extra weight dtype.
+            has_weight = all(['weight' in op_cap for op_cap in op_cap_lst])
+            if has_weight: self.ops_attr['weight'].add(op_name_type)
+            for op_cap in op_cap_lst:
+                if 'activation' in op_cap:
+                    self.ops_attr['activation'].add(op_name_type)
+                attrs_lst = ['activation', 'weight'] if has_weight else ['activation']
+                for att in attrs_lst:
+                    # Parse the data info for item that has options.
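+                    # e.g., a static int8 activation capability ends up keyed under the
+                    # nested path ('static', 'activation', 'int8', 'signed') after parsing.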
+                    if 'activation' in op_cap and 'quant_mode' in op_cap['activation']:
+                        quant_mode = op_cap['activation']['quant_mode']
+                        att_dtype = op_cap[att]['dtype'][0]
+                        signed_flag, _data_type = extract_data_type(att_dtype)
+                        for item_name, item_options in op_cap[att].items():
+                            if item_name == 'dtype':
+                                # The dtype should be a string, need to align with fwk.yaml.
+                                self.ops_data_type[op_name_type][(quant_mode, att, _data_type, signed_flag)] = \
+                                    item_options[0] if isinstance(item_options, list) else item_options
+                            if item_name not in ['dtype', 'quant_mode']:
+                                parsed_op_cap[quant_mode][att][_data_type][signed_flag][item_name] = item_options
+                    else:
+                        # Parse the data info for item with unique value.
+                        att_dtype = op_cap[att]['dtype']
+                        if isinstance(att_dtype, list):
+                            att_dtype = att_dtype[0]
+                        parsed_op_cap['precision'][att][att_dtype] = {'dtype': att_dtype}
+                        self.ops_data_type[op_name_type][('precision', att, att_dtype)] = att_dtype
+
+            parsed_cap[op_name_type] = parsed_op_cap
+        return parsed_cap
+
+    def _create_tuning_space(self, capability, usr_cfg):
+        """Create tuning space.
+
+        step1. convert the capability into the internal format.
+        step2. merge the capability with the usr_cfg.
+        step3. create the tuning space.
+        :param capability:
+        :param usr_cfg:
+        :return:
+        """
+        capability['op'] = self._parse_cap_helper(deepcopy(capability['op']))
+        if usr_cfg:
+            self._merge_with_user_cfg(capability, usr_cfg['quantization'])
+            logger.debug("*********** After merging with user cfg ***********")
+            logger.debug(capability)
+        self._parse_capability(capability)
+
+    def query_item_option(self, op_name_type, path, method_name, method_val):
+        """Query the method value, such as scheme, algorithm.
+
+        Args:
+            op_name_type: (op_name, op_type)
+            path: full path
+            method_name: method name
+            method_val: method value
+
+        Returns:
+            True if the queried method value exists in the options, False otherwise.
+        """
+        mode_item = self.get_item_by_path((op_name_type, *path))
+        if not mode_item: return None
+        method_item = mode_item.get_option_by_name(method_name)
+        return method_item is not None and method_val in method_item.options
+
+    def get_default_config(self, op_name_type, quant_mode):
+        """Get the default tuning config.
+
+        Args:
+            op_name_type: (op_name, op_type)
+            quant_mode: quantization mode.
+
+        Returns:
+            op_tuning_config: the default config according to the specified quantization mode.
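+
+        Example (illustrative; ``space`` is a hypothetical, fully built TuningSpace):
+            >>> cfg = space.get_default_config(('conv2d_1', 'conv2d'), 'static')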
+ """ + from .tuning_structs import OpTuningConfig + # For quant_mode static/dynamic/((static, int8), (dynamic, int4)) + # set the first option as the default if the not support the required quant mode + full_path = self.get_op_default_path_by_pattern(op_name_type, quant_mode) + config_args = {} + has_weight = op_name_type in self.ops_attr['weight'] + config_args['activation_dtype'] = self.ops_data_type[op_name_type].get(full_path['activation']) + if has_weight: + config_args['weight_dtype'] = self.ops_data_type[op_name_type].get(full_path['weight']) + for att in full_path: + mode_item = self.query_quant_mode_item_by_full_path(op_name_type ,full_path[att]) + if mode_item: + method_args = {method_item.name: method_item.options[0] for method_item in mode_item.options \ + if method_item.name in TUNING_ITEMS_LST} + config_args.update(method_args) + + quant_mode = quant_mode if isinstance(quant_mode, str) else quant_mode[0] + # set the first option as the default for each tuning item + op_tuning_config = OpTuningConfig(op_name_type[0], + op_name_type[1], + quant_mode, + self, + kwargs=config_args) + return op_tuning_config + + def get_item_by_path(self, path, default=None): + """Get the item according to the path.""" + item = self.root_item + for val in path: + if item is None: + logger.warning(f"Did not found the item according to the path {path}") + return default + item = item.get_option_by_name(val) + if item is None: + logger.warning(f"Did not found the item according to the path {path}") + return item + + def get_default_full_path(self, op_name_type, path): + """Complete the path. + + Args: + op_name_type: (op_name, op_path) + path: incomplete path. + + Returns: + new_path: the complete path. + """ + # For precision + if path[0] == 'precision': + # If the path is ('precision', 'activation', dtype), return it directly. + if len(path) == 3: return path + assert len(path) == 2, f"Got the path: {path}, please provide the path include activation or weight." + att_item = self.get_item_by_path((op_name_type, *path)) + if not att_item or len(att_item.options) == 0: + logger.debug(f"Could not found item for {op_name_type} with path {path}") + return None + dtype = att_item.options[0].name + return (*path, dtype) + else: + # For quantization + assert len(path) >= 2, f"Got the path: {path}, please provide the path include activation or weight." + if path[-1] == None: path = path[:-1] + item = self.get_item_by_path((op_name_type, *path)) + new_path = path + # For path ('static', 'activation', ...) + while item: + item_options = item.options + if len(item_options) > 0 and isinstance(item_options[0], TuningItem) and \ + item_options[0].item_type != 'method': + new_path = new_path + (item_options[0].name,) + item = item_options[0] + else: + break + return new_path + + def query_quant_mode_item_by_full_path(self, op_name_type, path) -> Tuple[TuningItem, Tuple]: + """Query the mode item by full path.""" + new_path = (op_name_type, *path) + item = self.get_item_by_path(new_path) + return item + + def query_items_by_quant_mode(self, quant_mode): + """Collect all op items that support the specified mode. + + Args: + quant_mode: dynamic/static/bf16/fp32/fp16 + + Returns: + The op item set that support quant model. + """ + return self.quant_mode_wise_items.get(quant_mode, []) + + def get_op_default_path_by_pattern(self, op_name_type, pattern): + """Get the default path by quant mode. 
+
+        Args:
+            op_name_type: (op_name, op_type)
+            pattern: 'static', 'dynamic', ('static', 'int8'), ('precision', 'fp32')
+
+        Returns:
+            result(Dict): The default full path of activation and weight, if present.
+        """
+        internal_pattern = pattern_to_internal(pattern)
+        full_path = {'activation': None, 'weight': None}
+        full_path['activation'], full_path['weight'] = pattern_to_path(internal_pattern)
+        result = {}
+        has_weight = op_name_type in self.ops_attr['weight']
+        att_lst = ['activation', 'weight'] if has_weight else ['activation']
+        for att in att_lst:
+            result[att] = self.get_default_full_path(op_name_type, full_path[att])
+        return result
+
+def get_op_mode_by_query_order(tuning_space: TuningSpace, query_order):
+    """Get the op mode according to the query order."""
+    quant_mode_wise_items = OrderedDict()  # mode, op_item_lst
+    pre_items = set()
+    # Collect op items that support the specified mode.
+    for quant_mode in query_order:
+        items = tuning_space.query_items_by_quant_mode(quant_mode)
+        filtered_items = list(filter(lambda item: item not in pre_items, items))
+        pre_items = pre_items.union(set(items))
+        quant_mode_wise_items[quant_mode] = filtered_items
+
+    def initial_op_quant_mode(items_lst, target_quant_mode, op_item_dtype_dict):
+        for item in items_lst:
+            op_item_dtype_dict[item.name] = target_quant_mode
+
+    op_item_dtype_dict = OrderedDict()
+    for quant_mode, quant_mode_items in quant_mode_wise_items.items():
+        initial_op_quant_mode(quant_mode_items, quant_mode, op_item_dtype_dict)
+
+    return op_item_dtype_dict
+
+def pattern_to_internal(pattern, default_dtype='int8'):
+    """Convert pattern to the internal format.
+
+    'static' -> ('static', (('int8'),('int8')))
+    'dynamic' -> ('dynamic', (('int8'),('int8')))
+    'fp32' -> ('precision', (('fp32'), ('fp32')))
+    'bf16' -> ('precision', (('bf16'), ('bf16')))
+    ('static', 'int8') -> ('static', (('int8'),('int8')))
+    ('dynamic', 'int8') -> ('dynamic', (('int8'),('int8')))
+    ('precision', 'fp32') -> ('precision', (('fp32'), ('fp32')))  # (('fp32'), ('fp32')) or ('fp32', 'fp32')
+    # TODO: add support for mixed data types of weight and activation
+    """
+    from .constant import PRECISION_SET_V2_0
+    pattern_bk = pattern
+    if isinstance(pattern, str):
+        pattern = ('precision', pattern) if pattern in PRECISION_SET_V2_0 else (pattern, (None))
+    internal_pattern = (pattern[0], ((pattern[1],), (pattern[1],)))
+    return internal_pattern
+
+def pattern_to_path(pattern):
+    """Convert pattern to path."""
+    act_path = (pattern[0], 'activation', *pattern[1][0])
+    weight_path = (pattern[0], 'weight', *pattern[1][1])
+    return act_path, weight_path
+
+def quant_mode_from_pattern(internal_pattern):
+    """Get quant mode from internal pattern."""
+    if internal_pattern[0] == 'precision':
+        return internal_pattern[1][0]
+    else:
+        return internal_pattern[0]
+
+def initial_tuning_cfg_with_quant_mode(op_name_type, quant_mode, tuning_space: TuningSpace) -> OpTuningConfig:
+    """Initialize the tuning cfg.
+
+    Args:
+        op_name_type: (op name, op type)
+        quant_mode: dynamic/static/fp32/bf16/fp16
+        tuning_space: tuning space.
+
+    step1, convert the quant_mode into the internal format.
+    step2, complete the path based on the tuning space.
+    step3, get the mode item.
+    step4, use the first option as the value for each method.
+    step5, create the op tuning config.
+
+    Returns:
+        The initial tuning config.
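+
+    Example (illustrative; ``space`` is a hypothetical TuningSpace that contains the op):
+        >>> cfg = initial_tuning_cfg_with_quant_mode(('conv2d_1', 'conv2d'), 'static', space)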
+ """ + internal_pattern = pattern_to_internal(quant_mode) + full_path = {'activation': None, 'weight': None} + full_path['activation'], full_path['weight'] = pattern_to_path(internal_pattern) + has_weight = op_name_type in tuning_space.ops_attr['weight'] + + config_args = {} + att_lst = ['activation', 'weight'] if has_weight else ['activation'] + for att in att_lst: + att_full_path = tuning_space.get_default_full_path(op_name_type, full_path[att]) + config_args[att + '_dtype'] = tuning_space.ops_data_type[op_name_type].get(att_full_path, None) + mode_item = tuning_space.get_item_by_path((op_name_type, *att_full_path)) + if mode_item: + method_args = {method_item.name: method_item.options[0] for method_item in mode_item.options \ + if method_item.name in TUNING_ITEMS_LST} + config_args.update(method_args) + quant_mode = internal_pattern[0] + # set the first option as the default for each tuning item + op_tuning_config = OpTuningConfig(op_name_type[0], + op_name_type[1], + quant_mode, + tuning_space, + kwargs=config_args) + return op_tuning_config \ No newline at end of file diff --git a/neural_compressor/experimental/strategy/utils/tuning_structs.py b/neural_compressor/experimental/strategy/utils/tuning_structs.py new file mode 100644 index 00000000000..b13f27cf0cd --- /dev/null +++ b/neural_compressor/experimental/strategy/utils/tuning_structs.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tuning structure.""" + +from typing import Dict +from .constant import QUANT_MODE_SET, TUNING_ITEMS_LST, PRECISION_SET +from ....utils import logger + +class OpTuningConfig: + """Op tuning config.""" + + def __init__(self, op_name, op_type, op_quant_mode, tuning_space, kwargs={}): + """Create the tuning config. + + Args: + op_name: op name. + op_type: op type. + op_quant_mode: quantization mode. + tuning_space: tuning space. + kwargs: other parameters. Defaults to {}. 
+ """ + self.op_name = op_name + self.op_type = op_type + self.op_name_type = (self.op_name, self.op_type) + self.op_quant_mode = op_quant_mode # static/dynamic/fp32/bf16/fp16 + self.kwargs = kwargs + self.act_dtype = None + self.weight_dtype = None + self.has_weight = self.op_name_type in tuning_space.ops_attr['weight'] + self._set_dtype() + + def _set_dtype(self): + """Set the date type.""" + if self.op_quant_mode in PRECISION_SET: + self.act_dtype, self.weight_dtype = self.op_quant_mode, self.op_quant_mode + else: + self.act_dtype = self.kwargs.get('activation_dtype', None) + self.weight_dtype = self.kwargs.get('weight_dtype', None) + assert self.act_dtype and isinstance(self.act_dtype, str),\ + (f"Didn't assign the activation data type for {self.op_name, self.op_type}", \ + f"with quant_mode {self.op_quant_mode}") + # if self.has_weight: + # assert self.weight_dtype, \ + # (f"Didn't assign the weight data type for {self.op_name, self.op_type}", \ + # f"with quant_mode {self.op_quant_mode}") + + + def __str__(self) -> str: + """Display the tuning config as string. + + Returns: + msg: the tuning config as string. + """ + msg = f"op name: {self.op_name}, op type : {self.op_type} \n" + msg += f"\t activation dtype: {self.act_dtype} \n" + msg += f"\t weight dtype: {self.weight_dtype} \n" if self.has_weight else "" + for key, val in self.kwargs.items(): + if key in TUNING_ITEMS_LST: + msg += f"\t {key[0]} {key[1]}: {val}\n" + return msg + + def get_state(self): + """Return the op tuning configuration. + + Returns: + Dict: The op tuning state. + """ + result = {} + if self.has_weight: + result['weight'] = { + 'dtype': self.weight_dtype, + } + result['activation'] = { + 'dtype': self.act_dtype, + 'quant_mode': self.op_quant_mode, + } + for key, val in self.kwargs.items(): + if key in TUNING_ITEMS_LST: + result[key[0]][key[1]] = val + return result + + @classmethod + def from_state(cls, config: Dict): + """Create the tuning config from dict. + + Args: + config: A dict includes the tuning config. + """ + cls(**config) diff --git a/neural_compressor/experimental/strategy/utils/utility.py b/neural_compressor/experimental/strategy/utils/utility.py new file mode 100644 index 00000000000..22b95176e59 --- /dev/null +++ b/neural_compressor/experimental/strategy/utils/utility.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tuning utility.""" + + +from collections import OrderedDict + +class OrderedDefaultDict(OrderedDict): + """Ordered default dict.""" + + def __missing__(self, key): + """Initialize value for the missing key.""" + self[key] = value = OrderedDefaultDict() + return value + +def extract_data_type(data_type: str) -> str: + """Extract data type and signed from data type. + + Args: + data_type: The original data type such as uint8, int8. 
+ + Returns: + (signed or unsigned, data type without signed) + """ + return ('signed', data_type) if data_type[0] != 'u' else ('unsigned', data_type[1:]) + +def reverted_data_type(signed_flag: str, data_type: str) -> str: + """Revert the data type.""" + return data_type if signed_flag == 'signed' else 'u' + data_type + +def get_adaptor_name(adaptor): + """Get adaptor name. + + Args: + adaptor: adaptor instance. + """ + adaptor_name = type(adaptor).__name__.lower() + adaptor_name_lst = ['onnx', 'tensorflow', 'pytorch'] + for name in adaptor_name_lst: + if adaptor_name.startswith(name): + return name + return "" \ No newline at end of file diff --git a/neural_compressor/quantization.py b/neural_compressor/quantization.py index 545b7de90cb..8061442b77d 100644 --- a/neural_compressor/quantization.py +++ b/neural_compressor/quantization.py @@ -116,8 +116,9 @@ def pre_proccess(self): self.conf, self._calib_dataloader, self._train_func, - self._eval_dataloader, self._eval_func, + self._eval_dataloader, + self._eval_metric, _resume, self.callbacks.hooks if self.callbacks is not None else None) @@ -289,6 +290,7 @@ def metric(self, user_metric): Multi-metrics: {topk: 1, MSE: {compare_label: False}, + } For the built-in metrics, please refer to below link: https://github.com/intel/neural-compressor/blob/master/docs/source/metric.md#supported-built-in-metric-matrix. diff --git a/neural_compressor/strategy/auto.py b/neural_compressor/strategy/auto.py index 3e8a1ef3072..26048f9aa30 100644 --- a/neural_compressor/strategy/auto.py +++ b/neural_compressor/strategy/auto.py @@ -34,8 +34,16 @@ class AutoTuneStrategy(TuneStrategy): and the tuning process ends once the condition meets the exit policy. """ - def __init__(self, model, conf, q_dataloader=None, q_func=None, \ - eval_dataloader=None, eval_func=None, resume=None, q_hooks=None): + def __init__(self, + model, + conf, + q_dataloader=None, + q_func=None, + eval_func=None, + eval_dataloader=None, + eval_metric=None, + resume=None, + q_hooks=None): """Init an auto tuning strategy. 
Args: diff --git a/neural_compressor/strategy/auto_mixed_precision.py b/neural_compressor/strategy/auto_mixed_precision.py index 04f0bc39307..44471670626 100644 --- a/neural_compressor/strategy/auto_mixed_precision.py +++ b/neural_compressor/strategy/auto_mixed_precision.py @@ -117,7 +117,7 @@ def traverse(self): tune_cfg = self._tune_cfg_converter(op_tuning_cfg) self.trials_count += 1 tuning_history = self._find_tuning_history(tune_cfg) - if tuning_history and self.trials_count < self.cfg.tuning.exit_policy.max_trials: + if tuning_history and self.trials_count < self.conf.quantization.tuning_criterion.max_trials: self.last_tune_result = tuning_history['last_tune_result'] self.best_tune_result = tuning_history['best_tune_result'] logger.warn("Find evaluated tuning config, skip.") diff --git a/neural_compressor/strategy/bayesian.py b/neural_compressor/strategy/bayesian.py index d267e8cdbc4..12ce0a23429 100644 --- a/neural_compressor/strategy/bayesian.py +++ b/neural_compressor/strategy/bayesian.py @@ -101,7 +101,7 @@ def next_tune_cfg(self): return if self.bayes_opt is None: self.bayes_opt = BayesianOptimization( - pbounds=pbounds, random_seed=self.cfg.tuning.random_seed) + pbounds=pbounds, random_seed=self.conf.options.random_seed) while True: params = self.bayes_opt.gen_next_params() logger.debug("Dump current bayesian params:") diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index b9e3187cb41..57cf6ec106f 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -18,7 +18,6 @@ """The base class for tuning strategy.""" from abc import abstractmethod -from enum import EnumMeta import os import math import copy @@ -31,6 +30,8 @@ from typing import OrderedDict as T_OrderedDict from neural_compressor.adaptor.tensorflow import TensorFlowAdaptor +from neural_compressor.config import PostTrainingQuantConfig +from ..config import MixedPrecisionConfig from ..objective import MultiObjective from ..adaptor import FRAMEWORKS from ..utils.utility import Statistics, dump_data_to_local @@ -78,9 +79,25 @@ def strategy_registry(cls): @strategy_registry class TuneStrategy(object): """Basic class for tuning strategy.""" - - def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader=None, - eval_func=None, resume=None, q_hooks=None): + + def _check_tuning_status(self): + if self.eval_func: + self._not_tuning = False + return + elif self.eval_dataloader and self.eval_metric: + self._not_tuning = False + return + + def __init__(self, + model, + conf: PostTrainingQuantConfig, + q_dataloader=None, + q_func=None, + eval_func=None, + eval_dataloader=None, + eval_metric=None, + resume=None, + q_hooks=None): """Init the TuneStrategy. Args: @@ -99,15 +116,18 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader= best_qmodel: The best quantized model that generated during the tuning process. 
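+            eval_metric: The metric used together with eval_dataloader to build an
+                evaluation function; assumed to follow the same dict format as the
+                metric setter in neural_compressor.quantization (illustrative note).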
""" self.model = model - self.cfg = conf.usr_cfg - self.cfg_bk = copy.deepcopy(self.cfg) - self.history_path = self._create_path(self.cfg.tuning.workspace.path, './history.snapshot') - self.deploy_path = self._create_path(self.cfg.tuning.workspace.path, 'deploy.yaml') - self.eval_dataloader = eval_dataloader + self.conf = conf + self.history_path = self._create_path(self.conf.options.workspace, './history.snapshot') + self.deploy_path = self._create_path(self.conf.options.workspace, 'deploy.yaml') self.calib_dataloader = q_dataloader + self.eval_dataloader = eval_dataloader + self.eval_metric = eval_metric + self.eval_func = eval_func + # not tuning equals to performance only + self._not_tuning = True + self._check_tuning_status() self.q_func = q_func self.q_hooks = q_hooks - self.eval_func = eval_func GLOBAL_STATE.STATE = MODE.QUANTIZATION framework, framework_specific_info = self._set_framework_info(q_dataloader, q_func) self.adaptor = FRAMEWORKS[framework](framework_specific_info) @@ -119,29 +139,7 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader= self.tune_result_record = [] self.tuning_history = [] self.tuning_result_data = [] - # The tuning history ever made, structured like below: - # [ - # { - # 'version': __version__, - # 'cfg': cfg1, - # 'framework': tensorflow - # 'baseline': baseline1, - # 'last_tune_result': last_tune_result1, - # 'best_tune_result': best_tune_result1, - # 'history': [ - # # tuning history under same yaml config - # {'tune_cfg': tune_cfg1, 'tune_result': \ - # tune_result1, 'q_config': q_config1, ...}, - - # ..., - # ], - # # new fields added by subclass for resuming - # ..., - # }, - # # tuning history under different yaml configs - # ..., - # ] - + self.baseline = None self.last_tune_result = None self.last_qmodel = None @@ -158,7 +156,7 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader= self.set_tuning_space(conf) #For algo scheduler - self.algo_scheduler = AlgorithmScheduler(self.cfg.quantization.recipes) + self.algo_scheduler = AlgorithmScheduler(self.conf.quantization.recipes) self.algo_scheduler.dataloader = self.calib_dataloader # reuse the calibration iteration self.algo_scheduler.origin_model = self.model self.algo_scheduler.adaptor = self.adaptor @@ -179,6 +177,8 @@ def __init__(self, model, conf, q_dataloader=None, q_func=None, eval_dataloader= self._not_tuning_recipes_values = {} self._initialize_recipe() self.applied_all_recipes_flag = False + + if resume is not None: self.setup_resume(resume) @@ -215,7 +215,7 @@ def _initialize_recipe(self): # not tuning list: the value is not equal to the default value logger.info(f"Adaptor has {len(adaptor_recipes)} recipes.") logger.debug(adaptor_recipes) - usr_recipes_cfg = self.cfg_bk.quantization.recipes if self.cfg_bk.quantization.recipes else {} + usr_recipes_cfg = self.conf.quantization.recipes if self.conf.quantization.recipes else {} for recipe_name, recipe_val in usr_recipes_cfg.items(): # for not tuning recipes, use the value specified by user. 
if recipe_name in adaptor_recipes and recipe_val != adaptor_recipes[recipe_name][0]: @@ -350,9 +350,9 @@ def master_worker_handle(self, comm): break # send the next cfg if not exceed max trials - if self.overall_trials > self.cfg.tuning.exit_policy.max_trials: + if self.overall_trials > self.conf.quantization.tuning_criterion.max_trials: self.max_trial_flag = True - # elif time.time() - self.overall_time_start > self.cfg.tuning.exit_policy.timeout: + # elif time.time() - self.overall_time_start > self.conf.quantization.tuning_criterion.timeout: # self.max_time_flag = True elif cur_cfg_id < len(self.tune_cfg_lst): logger.info("[Rank {}]master sends new tuning cfg {} to rank: {}".format(comm.Get_rank(), \ @@ -590,8 +590,8 @@ def traverse(self): The main traverse logic which could be override by some concrete strategy which needs more hooks. """ self._eval_baseline() - if self.cfg.tuning.use_distributed_tuning: - logger.info("use distributed traverse: {}".format(self.cfg.tuning.use_distributed_tuning)) + if self.conf.quantization.use_distributed_tuning: + logger.info("use distributed traverse: {}".format(self.conf.quantization.use_distributed_tuning)) return self.distributed_traverse() traverse_start_time = time() for op_tuning_cfg in self.next_tune_cfg(): @@ -599,7 +599,7 @@ def traverse(self): tune_cfg = self._tune_cfg_converter(op_tuning_cfg) self.trials_count += 1 tuning_history = self._find_tuning_history(tune_cfg) - if tuning_history and self.trials_count < self.cfg.tuning.exit_policy.max_trials: + if tuning_history and self.trials_count < self.conf.quantization.tuning_criterion.max_trials: self.last_tune_result = tuning_history['last_tune_result'] self.best_tune_result = tuning_history['best_tune_result'] logger.warn("Find evaluated tuning config, skip.") @@ -623,13 +623,13 @@ def traverse(self): self.algo_scheduler.reset_exec_algorithms() assert self.last_qmodel # Return the last quantized model as a result. if performance only. 
- if self.cfg.tuning.exit_policy.performance_only: + if self._not_tuning: self.best_qmodel = self.last_qmodel self._add_tuning_history(copy.deepcopy(tune_cfg), (-1, [0]), q_config=self.last_qmodel.q_config) return self.last_tune_result = self._evaluate(self.last_qmodel) self.cur_best_acc, self.cur_best_tuning_cfg = self.update_best_op_tuning_cfg(op_tuning_cfg) - need_stop = self.stop(self.cfg.tuning.exit_policy.timeout, self.trials_count) + need_stop = self.stop(self.conf.quantization.tuning_criterion.timeout, self.trials_count) # record the tuning history saved_tune_cfg = copy.deepcopy(tune_cfg) @@ -657,7 +657,7 @@ def traverse(self): continue # recover the best quantized model from tuning config self._recover_best_qmodel_from_tuning_cfg() - if self.cfg.tuning.diagnosis and self.cfg.tuning.diagnosis.diagnosis_after_tuning: + if self.conf.options.diagnosis: logger.debug(f'*** Start to do diagnosis (inspect tensor).') self._diagnosis() if self.use_multi_objective and len(self.tune_result_record) > 1 and \ @@ -667,7 +667,7 @@ def traverse(self): if best_result != self.best_tune_result: from neural_compressor.utils.utility import recover self.best_qmodel = recover(self.model.model, - os.path.join(self.cfg.tuning.workspace.path, 'history.snapshot'), + os.path.join(self.conf.options.workspace, 'history.snapshot'), best_trail) logger.debug(f"*** Update the best qmodel by recovering from history.") self.best_tune_result = best_result @@ -683,33 +683,20 @@ def _remove_redundant_qmodel(self): """ self.last_qmodel = None self.best_qmodel = None - - def _can_create_eval_func_from_cfg(self): - """Determine whether an eval function can be created from cfg. - - Returns: - Returns True if the eval func can be created from config, False otherwise. - """ - if self.cfg.evaluation and self.cfg.evaluation.accuracy and \ - (self.cfg.evaluation.accuracy.metric or self.cfg.evaluation.accuracy.multi_metrics)\ - and self.eval_dataloader: - return True - return False def _eval_baseline(self): """Evaluate the fp32 model if needed.""" - if not self._can_create_eval_func_from_cfg() and not self.eval_func: + if self._not_tuning: logger.info("Neither evaluation function nor metric is defined." 
\ " Generate a quantized model with default quantization configuration.") - self.cfg.tuning.exit_policy.performance_only = True - logger.info("Force setting 'tuning.exit_policy.performance_only = True'.") + self._not_tuning = True - if not self.cfg.tuning.exit_policy.performance_only: + if not self._not_tuning: # get fp32 model baseline if self.baseline is None: logger.info("Get FP32 model baseline.") self._fp32_model = self.model - self.baseline = self._evaluate(self.model) + self.baseline = self._evaluate(self.model) self.objectives.baseline = self.baseline # record the FP32 baseline self._add_tuning_history() @@ -828,14 +815,13 @@ def initial_tuning_cfg(self): """ from .utils.constant import auto_query_order, static_query_order, dynamic_query_order from .utils.tuning_space import initial_tuning_cfg_with_quant_mode - if self.cfg.quantization.approach == 'post_training_auto_quant': + if self.conf.quantization.approach == 'post_training_auto_quant': query_order = auto_query_order - elif self.cfg.quantization.approach == 'post_training_dynamic_quant': + elif self.conf.quantization.approach == 'post_training_dynamic_quant': query_order = dynamic_query_order - elif self.cfg.quantization.approach == 'post_training_static_quant': + elif self.conf.quantization.approach == 'post_training_static_quant': query_order = static_query_order - elif self.cfg.quantization.approach == 'quant_aware_training': - logger.info("!!! Currently, the qat tuning is not supported by strategy.") + elif self.conf.quantization.approach == 'quant_aware_training': query_order = auto_query_order quant_mode_wise_items = OrderedDict() # mode, op_item_lst @@ -929,15 +915,14 @@ def _tune_cfg_converter(self, op_tuning_cfg): self.calib_dataloader.batch_size) else: tune_cfg['calib_iteration'] = 1 - tune_cfg['advance'] = self.cfg.quantization.advance - tune_cfg['approach'] = self.cfg.quantization.approach + tune_cfg['approach'] = self.conf.quantization.approach # Add the recipe config tune_cfg['recipe_cfgs'] = tune_cfg.get('recipe_cfgs', {}) # For not tuning recipe, tune cfg use it directly tune_cfg['recipe_cfgs'].update(self._not_tuning_recipes_values) # WA for get the smooth quant args - if 'smooth_quant_args' in self.cfg_bk.quantization.recipes: - tune_cfg['recipe_cfgs']['smooth_quant_args'] = self.cfg_bk.quantization.recipes['smooth_quant_args'] + if 'smooth_quant_args' in self.conf.quantization.recipes: + tune_cfg['recipe_cfgs']['smooth_quant_args'] = self.conf.quantization.recipes['smooth_quant_args'] # For tuning recipe, use the default value if it not specified by recipe tuning sampler. for recipe_name, recipe_val in self._tuning_recipes_default_values.items(): if recipe_name not in tune_cfg['recipe_cfgs']: @@ -952,7 +937,7 @@ def set_tuning_space(self, conf): Args: conf: The Conf class instance includes all user configurations. 
""" - calib_sampling_size_lst = self.cfg.quantization.calibration.sampling_size + calib_sampling_size_lst = self.conf.quantization.calibration_sampling_size calib_sampling_size_lst = [int(calib_sampling_size) for calib_sampling_size in calib_sampling_size_lst] if self.calib_dataloader: self.calib_iter = [math.ceil(int(x) / self.calib_dataloader.batch_size) \ @@ -974,7 +959,7 @@ def setup_resume(self, resume): """ self.__dict__.update(resume) for history in self.tuning_history: - if self._same_yaml(history['cfg'], self.cfg): + if self._same_yaml(history['cfg'], self.conf): self.__dict__.update({k: v for k, v in history.items() \ if k not in ['version', 'history']}) logger.info("Start to resume tuning process.") @@ -993,14 +978,8 @@ def setup_resume(self, resume): def set_q_func(self): """Set the training function for quantization aware training.""" - if self.q_func == None and self.cfg.quantization.approach == 'quant_aware_training': - train_cfg = self.cfg.quantization.train - assert train_cfg, "train field of quantization section in yaml file must " \ - "be configured for quantization aware training if q_func is NOT set." - assert self.calib_dataloader, "dataloader field of train field of quantization " \ - "section in yaml file must be configured." - self.q_func = create_train_func(self.framework, self.calib_dataloader, \ - self.adaptor, train_cfg, hooks=self.q_hooks) + if self.conf.quantization.approach == 'quant_aware_training': + assert self.q_func != None, "Please set train func for quantization aware training" def _create_path(self, custom_path, filename): new_path = os.path.join(os.path.abspath(os.path.expanduser(custom_path)),filename) @@ -1009,95 +988,113 @@ def _create_path(self, custom_path, filename): return new_path def _set_framework_info(self, q_dataloader, q_func=None): - framework_specific_info = {'device': self.cfg.device, - 'approach': self.cfg.quantization.approach, - 'random_seed': self.cfg.tuning.random_seed, - 'performance_only': self.cfg.tuning.exit_policy.performance_only,} - framework = self.cfg.model.framework.lower() - framework_specific_info.update({'backend': self.cfg.model.get('backend', 'default')}) - framework_specific_info.update({'format': self.cfg.model.get('quant_format', 'default')}) - framework_specific_info.update({'domain': self.cfg.model.get('domain', 'auto')}) - - self.mixed_precision_mode = bool('mixed_precision' in self.cfg) or \ - bool('graph_optimization' in self.cfg) + framework_specific_info = {'device': self.conf.quantization.device, + 'approach': self.conf.quantization.approach, + 'random_seed': self.conf.options.random_seed, + 'performance_only': self._not_tuning} + framework = self.conf.quantization.framework.lower() + framework_specific_info.update({'backend': self.conf.quantization.backend}) + framework_specific_info.update({'format': self.conf.quantization.quant_format}) + framework_specific_info.update({'domain': self.conf.quantization.quant_format}) + + self.mixed_precision_mode = isinstance(self.conf.quantization, MixedPrecisionConfig) if 'tensorflow' in framework: framework_specific_info.update( - {"inputs": self.cfg.model.inputs, - "outputs": self.cfg.model.outputs, - 'workspace_path': self.cfg.tuning.workspace.path, - 'recipes': self.cfg.quantization.recipes, - 'use_bf16': self.cfg.use_bf16 if self.cfg.use_bf16 is not None else False}) + {"inputs": self.conf.quantization.inputs, + "outputs": self.conf.quantization.outputs, + 'workspace_path': self.conf.options.workspace, + 'recipes': self.conf.quantization.recipes, + 'use_bf16': 
+                     self.conf.quantization.use_bf16 if self.conf.quantization.use_bf16 is not None else False})
             for item in ['scale_propagation_max_pooling', 'scale_propagation_concat']:
                 if item not in framework_specific_info['recipes']:
                     framework_specific_info['recipes'].update({item: True})
-            if self.cfg.model.backend == 'itex':
-                self.cfg.model.framework = 'tensorflow_itex'
+            if self.conf.quantization.backend == 'itex':
+                # TODO: replace this workaround once the new config is ready
                 framework = 'tensorflow_itex'
         if 'keras' in framework:
             framework_specific_info.update({
-                'workspace_path': self.cfg.tuning.workspace.path, })
+                'workspace_path': self.conf.options.workspace, })
         if framework == 'mxnet':
             framework_specific_info.update({"q_dataloader": q_dataloader})
         if 'onnx' in framework.lower():
            if self.mixed_precision_mode:
                framework_specific_info.update({"approach": "post_training_dynamic_quant"})
            framework_specific_info.update({"deploy_path": os.path.dirname(self.deploy_path)})
-           framework_specific_info.update({'workspace_path': self.cfg.tuning.workspace.path})
-           framework_specific_info.update({'recipes': self.cfg.quantization.recipes})
-           framework_specific_info.update({'reduce_range': self.cfg.reduce_range})
-           framework_specific_info.update({'recipes': self.cfg.quantization.get('recipes', {})})
+           framework_specific_info.update({'workspace_path': self.conf.options.workspace})
+           framework_specific_info.update({'recipes': self.conf.quantization.recipes})
+           framework_specific_info.update({'reduce_range': self.conf.quantization.reduce_range})
+           framework_specific_info.update({'recipes': self.conf.quantization.recipes})
            if framework.lower() == 'onnxrt_qdq' or \
                    framework_specific_info['backend'] == 'onnxrt_trt_ep':
                framework_specific_info.update({'format': 'QDQ'})
                framework = 'onnxrt_qdq'
        if framework == 'pytorch_ipex' or framework == 'pytorch' or framework == 'pytorch_fx':
-           if self.cfg.model.backend == 'ipex':
-               self.cfg.model.framework = 'pytorch_ipex'
+           if self.conf.quantization.backend == 'ipex':
                framework = 'pytorch_ipex'
-           elif self.cfg.model.backend == 'default':
-               self.cfg.model.framework = 'pytorch_fx'
+           elif self.conf.quantization.backend == 'default':
               framework = 'pytorch_fx'
           if self.mixed_precision_mode:
              framework_specific_info.update({"approach": "post_training_dynamic_quant"})
          framework_specific_info.update({"q_dataloader": q_dataloader})
-         framework_specific_info.update({"use_bf16": self.cfg.use_bf16 \
-             if self.cfg.use_bf16 is not None else True})
+         framework_specific_info.update({"use_bf16": self.conf.quantization.use_bf16 \
+             if self.conf.quantization.use_bf16 is not None else True})
          framework_specific_info.update(
              {"workspace_path": os.path.dirname(self.deploy_path)})
-         if self.cfg['quantization']['op_wise'] is not None \
-             and 'default_qconfig' in self.cfg['quantization']['op_wise']:
+         if self.conf.quantization.op_name_dict is not None \
+             and 'default_qconfig' in self.conf.quantization.op_name_dict:
             framework_specific_info.update(
-                {"default_qconfig": self.cfg['quantization']['op_wise']['default_qconfig']})
+                {"default_qconfig": self.conf.quantization.op_name_dict['default_qconfig']})
            framework_specific_info.update({"q_func": q_func})
-           framework_specific_info.update({"example_inputs": self.cfg.quantization.example_inputs})
+           framework_specific_info.update({"example_inputs": self.conf.quantization.example_inputs})
        return framework, framework_specific_info

    def _set_objectives(self):
-        self.higher_is_better = bool(self.cfg.tuning.accuracy_criterion.higher_is_better)
-        self.use_multi_objective = deep_get(self.cfg, 
'tuning.multi_objectives') and \ - len(self.cfg.tuning.multi_objectives.objective) > 1 - objectives = [i.lower() for i in self.cfg.tuning.multi_objectives.objective] if \ - self.use_multi_objective else [self.cfg.tuning.objective.lower()] - self.metric_weight = deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.weight') - self.metric_name = ['Accuracy'] if \ - not deep_get(self.cfg, 'evaluation.accuracy.multi_metrics') else \ - self.cfg.evaluation.accuracy.multi_metrics.keys()-{'weight','higher_is_better'} - if len(self.metric_name) == 1: - self.metric_criterion = [self.higher_is_better] - elif not deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.higher_is_better'): - # default is True - self.metric_criterion = [True] * len(self.metric_name) + # set objectives + self.higher_is_better = bool(self.conf.quantization.accuracy_criterion.higher_is_better) + obj_higher_is_better = None + obj_weight = None + if self.conf.quantization.tuning_criterion.multi_objectives: + obj_higher_is_better = self.conf.quantization.tuning_criterion.multi_objectives.get('higher_is_better', None) + obj_weight = self.conf.quantization.tuning_criterion.multi_objectives.get('weight', None) + obj_lst = self.conf.quantization.tuning_criterion.multi_objectives.get('objective', []) + self.use_multi_objective = len(obj_lst) > 0 + if self.use_multi_objective: + objectives = [i.lower() for i in obj_lst] else: - self.metric_criterion = \ - deep_get(self.cfg, 'evaluation.accuracy.multi_metrics.higher_is_better') - - self.objectives = MultiObjective(objectives, - self.cfg.tuning.accuracy_criterion, - self.metric_criterion, - self.metric_weight, - deep_get(self.cfg, 'tuning.multi_objectives.higher_is_better'), - deep_get(self.cfg, 'tuning.multi_objectives.weight')) + objectives = [self.conf.quantization.tuning_criterion.objective.lower()] + + # set metric + self.metric_name = ['Accuracy'] + self.metric_criterion = [self.higher_is_better] + self.metric_weight = None + use_multi_metrics = False + if self.eval_metric: + # metric name + # 'weight','higher_is_better', 'metric1', 'metric2', ... + if len(self.eval_metric.keys()) >= 4: + self.metric_name = self.eval_metric.keys() - {'weight','higher_is_better'} + use_multi_metrics = True + metric_higher_is_better = self.eval_metric.get('higher_is_better', None) + # metric criterion + if use_multi_metrics: + if metric_higher_is_better is not None: + self.metric_criterion = [metric_higher_is_better] * len(self.metric_name) + else: + self.metric_criterion = [True] * len(self.metric_name) + # metric weight + self.metric_weight = self.eval_metric.get('weight', None) + + accuracy_criterion = {'relative': 0.01, 'higher_is_better': True} + accuracy_criterion_conf = self.conf.quantization.accuracy_criterion + accuracy_criterion[accuracy_criterion_conf.criterion] = accuracy_criterion_conf.tolerable_loss + accuracy_criterion['higher_is_better'] = accuracy_criterion_conf.higher_is_better + self.objectives = MultiObjective(objectives=objectives, + accuracy_criterion=accuracy_criterion, + metric_criterion=self.metric_criterion, + metric_weight=self.metric_weight, + obj_criterion=obj_higher_is_better, + obj_weight=obj_weight) def _same_yaml(self, src_yaml, dst_yaml): """Check if the two yamls are the same. 
@@ -1146,41 +1143,43 @@ def update_best_op_tuning_cfg(self, op_tuning_cfg): return self.cur_best_acc, self.cur_best_tuning_cfg def deploy_config(self): - """Save the configuration locally for deployment.""" - acc_dataloader_cfg = deep_get(self.cfg, 'evaluation.accuracy.dataloader') - perf_dataloader_cfg = deep_get(self.cfg, 'evaluation.performance.dataloader') - # use acc dataloader if perf dataloader is not configured - if perf_dataloader_cfg is None: - perf_dataloader_cfg = acc_dataloader_cfg - - self.deploy_cfg = OrderedDict() - # int8 dataloader graph transform - if deep_get(perf_dataloader_cfg, 'transform.QuantizedInput') is not None \ - or deep_get(acc_dataloader_cfg, 'transform.QuantizedInput') is not None: - self.best_qmodel, scale = self.adaptor.quantize_input(self.best_qmodel) - deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8') - deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.scale', scale) - deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8') - deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.scale', scale) - - self.deploy_cfg['model'] = self.cfg.model - self.deploy_cfg['device'] = self.cfg.device - if self.cfg.evaluation is not None: - deep_set(self.cfg, 'evaluation.performance.dataloader',\ - perf_dataloader_cfg) - deep_set(self.cfg, 'evaluation.accuracy.dataloader', \ - acc_dataloader_cfg) - self.deploy_cfg['evaluation'] = self.cfg.evaluation - - def setup_yaml(): - represent_dict_order = lambda self, \ - data: self.represent_mapping('tag:yaml.org,2002:map', data.items()) - yaml.add_representer(OrderedDict, represent_dict_order) - yaml.add_representer(DotDict, represent_dict_order) - setup_yaml() - with open(self.deploy_path, 'w+') as f: - yaml.dump(self.deploy_cfg, f) - logger.info("Save deploy yaml to {}".format(self.deploy_path)) + return + #TODO uncomment it after config ready + # """Save the configuration locally for deployment.""" + # acc_dataloader_cfg = deep_get(self.cfg, 'evaluation.accuracy.dataloader') + # perf_dataloader_cfg = deep_get(self.cfg, 'evaluation.performance.dataloader') + # # use acc dataloader if perf dataloader is not configured + # if perf_dataloader_cfg is None: + # perf_dataloader_cfg = acc_dataloader_cfg + + # self.deploy_cfg = OrderedDict() + # # int8 dataloader graph transform + # if deep_get(perf_dataloader_cfg, 'transform.QuantizedInput') is not None \ + # or deep_get(acc_dataloader_cfg, 'transform.QuantizedInput') is not None: + # self.best_qmodel, scale = self.adaptor.quantize_input(self.best_qmodel) + # deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8') + # deep_set(perf_dataloader_cfg, 'transform.QuantizedInput.scale', scale) + # deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.dtype', 'int8') + # deep_set(acc_dataloader_cfg, 'transform.QuantizedInput.scale', scale) + + # self.deploy_cfg['model'] = self.cfg.model + # self.deploy_cfg['device'] = self.conf.quantization.device + # if self.cfg.evaluation is not None: + # deep_set(self.cfg, 'evaluation.performance.dataloader',\ + # perf_dataloader_cfg) + # deep_set(self.cfg, 'evaluation.accuracy.dataloader', \ + # acc_dataloader_cfg) + # self.deploy_cfg['evaluation'] = self.cfg.evaluation + + # def setup_yaml(): + # represent_dict_order = lambda self, \ + # data: self.represent_mapping('tag:yaml.org,2002:map', data.items()) + # yaml.add_representer(OrderedDict, represent_dict_order) + # yaml.add_representer(DotDict, represent_dict_order) + # setup_yaml() + # with open(self.deploy_path, 'w+') as f: + # 
yaml.dump(self.deploy_cfg, f)
+    #     logger.info("Save deploy yaml to {}".format(self.deploy_path))
 
     def _get_common_cfg(self, model_wise_cfg, op_wise_cfgs):
         """Get the common parts from the model_wise_cfg.
@@ -1228,36 +1227,29 @@ def _evaluate(self, model):
             Objective: The objective value evaluated.
         """
         if self.eval_func:
-            if self.cfg.tuning.tensorboard:
+            if self.conf.options.tensorboard:
                 # Pytorch can insert observer to model in this hook.
                 # Tensorflow don't support this mode for now
                 model = self.adaptor._pre_eval_hook(model)
             val = self.objectives.evaluate(
                 self.eval_func, model if self.framework == "pytorch_ipex" else model.model
             )
-            if self.cfg.tuning.tensorboard:
+            if self.conf.options.tensorboard:
                 # post_eval_hook to deal the tensor
                 self.adaptor._post_eval_hook(model, accuracy=val[0])
         else:
-            assert self.cfg.evaluation and self.cfg.evaluation.accuracy and \
-                (self.cfg.evaluation.accuracy.metric or \
-                 self.cfg.evaluation.accuracy.multi_metrics), \
-                "metric or multi_metrics field of accuracy field of evaluation" \
-                " section should not be empty"
-
-            postprocess_cfg = self.cfg.evaluation.accuracy.postprocess
-            metric_cfg = self.cfg.evaluation.accuracy.metric if \
-                self.cfg.evaluation.accuracy.metric else \
-                self.cfg.evaluation.accuracy.multi_metrics
-            iteration = -1 if self.cfg.evaluation.accuracy.iteration is None \
-                else self.cfg.evaluation.accuracy.iteration
+            assert not self._not_tuning, "Please set eval_dataloader and eval_metric to create the eval_func."
+
+            postprocess_cfg = None
+            metric_cfg = self.eval_metric
+            iteration = -1
             eval_func = create_eval_func(self.framework,
                                          self.eval_dataloader,
                                          self.adaptor,
                                          metric_cfg,
                                          postprocess_cfg,
                                          iteration,
-                                         tensorboard = self.cfg.tuning.tensorboard,
+                                         tensorboard = self.conf.options.tensorboard,
                                          fp32_baseline = self.baseline == None)
 
             if getattr(self.eval_dataloader, 'distributed', False):
@@ -1320,7 +1312,7 @@ def stop(self, timeout, trials_count):
             bool: True if need stop, otherwise False
         """
         need_stop = False
-        if self.cfg.tuning.exit_policy.performance_only or \
+        if self._not_tuning or \
             self.objectives.compare(self.best_tune_result, self.baseline):
             self.best_tune_result = self.last_tune_result
             self.best_qmodel = self.last_qmodel
@@ -1431,11 +1423,11 @@ def stop(self, timeout, trials_count):
                     'Best tune result']).print_stat()
 
-        if self.cfg.tuning.exit_policy.performance_only:
+        if self._not_tuning:
             need_stop = True
         elif timeout == 0 and self.best_tune_result:
             need_stop = True
-        elif self.trials_count >= self.cfg.tuning.exit_policy.max_trials:
+        elif self.trials_count >= self.conf.quantization.tuning_criterion.max_trials:
             need_stop = True
         else:
             need_stop = False
@@ -1460,7 +1452,8 @@ def _find_tuning_history(self, tune_cfg):
         for tuning_history in self.tuning_history:
             # only check if a tune_cfg is evaluated under same yam config, excluding
             # some fields in tuning section of yaml, such as tensorboard, snapshot, resume.
-            if self._same_yaml(tuning_history['cfg'], self.cfg):
+            # TODO double check
+            if self._same_yaml(tuning_history['cfg'], self.conf):
                 for history in tuning_history['history']:
                     if history and history['tune_cfg'] == tune_cfg:
                         return tuning_history
@@ -1476,7 +1469,8 @@ def _find_history(self, tune_cfg):
         for tuning_history in self.tuning_history:
             # only check if a tune_cfg is evaluated under same yam config, excluding
             # some fields in tuning section of yaml, such as tensorboard, snapshot, resume.
@@ -1476,7 +1469,8 @@ def _find_history(self, tune_cfg):
         for tuning_history in self.tuning_history:
             # only check if a tune_cfg is evaluated under same yaml config, excluding
             # some fields in tuning section of yaml, such as tensorboard, snapshot, resume.
-            if self._same_yaml(tuning_history['cfg'], self.cfg):
+            # TODO: double-check
+            if self._same_yaml(tuning_history['cfg'], self.conf):
                 for history in tuning_history['history']:
                     if history and history['tune_cfg'] == tune_cfg:
                         return history
@@ -1491,7 +1485,7 @@ def _find_self_tuning_history(self):
         for tuning_history in self.tuning_history:
             # only check if a tune_cfg is evaluated under same yaml config, excluding
             # some fields in tuning section of yaml, such as tensorboard, snapshot, resume.
-            if self._same_yaml(tuning_history['cfg'], self.cfg):
+            if self._same_yaml(tuning_history['cfg'], self.conf):
                 return tuning_history
 
         return None
@@ -1499,24 +1493,48 @@ def _find_self_tuning_history(self):
     def _add_tuning_history(self, tune_cfg=None, tune_result=None, **kwargs):
         """Add tuning config to tuning history.
+
+        All tuning history made so far, structured as below:
+        [
+          {
+            'version': __version__,
+            'cfg': cfg1,
+            'framework': 'tensorflow',
+            'baseline': baseline1,
+            'last_tune_result': last_tune_result1,
+            'best_tune_result': best_tune_result1,
+            'history': [
+                # tuning history under same yaml config
+                {'tune_cfg': tune_cfg1, 'tune_result': \
+                    tune_result1, 'q_config': q_config1, ...},
+
+                ...,
+            ],
+            # new fields added by subclass for resuming
+            ...,
+          },
+          # tuning history under different yaml configs
+          ...,
+        ]
+
         Note this record is added under the same yaml config.
         """
         found = False
         d = {'tune_cfg': tune_cfg, 'tune_result': tune_result}
         for tuning_history in self.tuning_history:
-            if self._same_yaml(tuning_history['cfg'], self.cfg):
+            if self._same_yaml(tuning_history['cfg'], self.conf):
                 d.update(kwargs)
                 tuning_history['history'].append(d)
                 tuning_history['last_tune_result'] = self.last_tune_result
                 tuning_history['best_tune_result'] = self.best_tune_result
-                tuning_history['cfg'] = self.cfg
+                tuning_history['cfg'] = self.conf
                 found = True
                 break
         if not found:
             tuning_history = {}
             tuning_history['version'] = __version__
-            tuning_history['cfg'] = self.cfg
+            tuning_history['cfg'] = self.conf
             tuning_history['baseline'] = self.baseline
             tuning_history['last_tune_result'] = self.last_tune_result
             tuning_history['best_tune_result'] = self.best_tune_result
@@ -1538,15 +1556,15 @@ def _collect_ops_by_quant_mode(self, tune_cfg, quant_mode):
     def _diagnosis(self):
         import logging
         logger = logging.getLogger("neural_compressor")
-        iteration_list = self.cfg.tuning.diagnosis.iteration_list
-        inspect_type = self.cfg.tuning.diagnosis.inspect_type
-        save_to_disk = self.cfg.tuning.diagnosis.save_to_disk
-        save_path = self.cfg.tuning.diagnosis.save_path
+        iteration_list = [1]
+        inspect_type = 'all'
+        save_to_disk = True
+        save_path = './nc_workspace/inspect_saved/'
         inspect_node_lst, updated_cfg = self.adaptor.diagnosis_helper(self._fp32_model,
                                                                       self.last_qmodel,
                                                                       self.tune_cfg,
                                                                       save_path = save_path)
-        op_list = self.cfg.tuning.diagnosis.op_list
+        op_list = []
        if not op_list:
            op_list = list(inspect_node_lst)
        else:
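For reference, the record layout documented in _add_tuning_history above, written out as a standalone sketch; every value here is a placeholder, not output from a real run:

    # One entry per config; _same_yaml() decides whether a new trial joins
    # an existing entry or starts a fresh one.
    tuning_history_entry = {
        'version': '2.0',                    # __version__ when the record was written
        'cfg': None,                         # the conf object the trials ran under
        'framework': 'tensorflow',
        'baseline': (0.80, [1.00]),          # fp32 (accuracy, objectives) pair
        'last_tune_result': (0.79, [0.52]),
        'best_tune_result': (0.79, [0.52]),
        'history': [
            {'tune_cfg': {}, 'tune_result': (0.79, [0.52])},  # one trial
        ],
    }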
diff --git a/neural_compressor/strategy/utils/tuning_space.py b/neural_compressor/strategy/utils/tuning_space.py
index 505fdef7a15..07909d7f711 100644
--- a/neural_compressor/strategy/utils/tuning_space.py
+++ b/neural_compressor/strategy/utils/tuning_space.py
@@ -127,7 +127,7 @@ def __init__(self, capability, conf, framework=None):
         self.op_type_wise_items = defaultdict(list)  # op_type: {(op_name, op_type), ...}
         self.framework = framework
         self.ops_dtype = defaultdict(OrderedDict)
-        usr_cfg = conf.usr_cfg if conf else None
+        self._usr_cfg = self._init_usr_cfg()
         self.op_items = {}
         # {(op_name, op_type): {(path): data type}}
         self.ops_data_type = OrderedDefaultDict()
@@ -135,7 +135,15 @@ def __init__(self, capability, conf, framework=None):
         # {(op_name, op_type): {path1, path2, ...}
         self.ops_path_set = defaultdict(set)
 
-        self._create_tuning_space(capability, usr_cfg)
+        self._create_tuning_space(capability, self._usr_cfg)
+
+    def _init_usr_cfg(self):
+        """Init user config."""
+        usr_cfg = {'quantization': {}}
+        usr_cfg['quantization']['model_wise'] = None
+        usr_cfg['quantization']['optype_wise'] = self.conf.quantization.op_type_dict
+        usr_cfg['quantization']['op_wise'] = self.conf.quantization.op_name_dict
+        return usr_cfg
 
     def _parse_capability(self, capability: Dict) -> None:
         """Parse the capability and construct the tuning space (a tree).
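The new _init_usr_cfg above maps the user-facing op_type_dict and op_name_dict onto the optype_wise/op_wise fields the tuning space already understands. A hedged sketch of the user side, assuming the 2.x PostTrainingQuantConfig kwargs (op names and dtypes are illustrative):

    from neural_compressor.config import PostTrainingQuantConfig

    conf = PostTrainingQuantConfig(
        # consumed as usr_cfg['quantization']['optype_wise']
        op_type_dict={'Conv2D': {'weight': {'dtype': ['fp32']}}},
        # consumed as usr_cfg['quantization']['op_wise']
        op_name_dict={'op_to_store': {'activation': {'dtype': ['fp32']}}},
    )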
diff --git a/test/strategy/test_basic.py b/test/strategy/test_basic.py
index b110238c3d1..781a7ee333f 100644
--- a/test/strategy/test_basic.py
+++ b/test/strategy/test_basic.py
@@ -2,121 +2,6 @@
 import numpy as np
 import unittest
 import shutil
-import os
-import yaml
-
-def build_fake_yaml():
-    fake_yaml = '''
-        model:
-          name: fake_yaml
-          framework: tensorflow
-          inputs: x
-          outputs: op2_to_store
-        device: cpu
-        evaluation:
-          accuracy:
-            metric:
-              topk: 1
-        tuning:
-          strategy:
-            name: basic
-          accuracy_criterion:
-            relative: 0.01
-          workspace:
-            path: saved
-        '''
-    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
-    with open('fake_yaml.yaml',"w",encoding="utf-8") as f:
-        yaml.dump(y,f)
-    f.close()
-
-def build_fake_yaml2():
-    fake_yaml = '''
-        model:
-          name: fake_yaml
-          framework: tensorflow
-          inputs: x
-          outputs: op2_to_store
-        device: cpu
-        evaluation:
-          accuracy:
-            metric:
-              topk: 1
-        tuning:
-          strategy:
-            name: basic
-          exit_policy:
-            max_trials: 3
-          accuracy_criterion:
-            relative: -0.01
-          workspace:
-            path: saved
-        '''
-    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
-    with open('fake_yaml2.yaml',"w",encoding="utf-8") as f:
-        yaml.dump(y,f)
-    f.close()
-
-def build_fake_yaml3():
-    fake_yaml = '''
-        model:
-          name: fake_yaml
-          framework: tensorflow
-          inputs: x
-          outputs: op2_to_store
-        device: cpu
-        evaluation:
-          accuracy:
-            multi_metrics:
-              topk: 1
-              MSE:
-                compare_label: False
-        tuning:
-          strategy:
-            name: basic
-          exit_policy:
-            max_trials: 3
-            timeout: 50
-          accuracy_criterion:
-            relative: -0.01
-          workspace:
-            path: saved
-        '''
-    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
-    with open('fake_yaml3.yaml',"w",encoding="utf-8") as f:
-        yaml.dump(y,f)
-    f.close()
-
-def build_fake_yaml4():
-    fake_yaml = '''
-        model:
-          name: fake_yaml
-          framework: tensorflow
-          inputs: x
-          outputs: op2_to_store
-        device: cpu
-        evaluation:
-          accuracy:
-            multi_metrics:
-              topk: 1
-              MSE:
-                compare_label: False
-              weight: [1, 0]
-        tuning:
-          strategy:
-            name: basic
-          exit_policy:
-            max_trials: 3
-            timeout: 50
-          accuracy_criterion:
-            relative: -0.01
-          workspace:
-            path: saved
-        '''
-    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
-    with open('fake_yaml4.yaml',"w",encoding="utf-8") as f:
-        yaml.dump(y,f)
-    f.close()
 
 def build_fake_model():
     import tensorflow as tf
@@ -160,63 +45,10 @@ class TestBasicTuningStrategy(unittest.TestCase):
     @classmethod
     def setUpClass(self):
         self.constant_graph = build_fake_model()
-        build_fake_yaml()
-        build_fake_yaml2()
-        build_fake_yaml3()
-        build_fake_yaml4()
 
     @classmethod
     def tearDownClass(self):
-        os.remove('fake_yaml.yaml')
-        os.remove('fake_yaml2.yaml')
-        os.remove('fake_yaml3.yaml')
-        os.remove('fake_yaml4.yaml')
         shutil.rmtree('saved', ignore_errors=True)
 
-    def test_run_basic_one_trial(self):
-        from neural_compressor.experimental import Quantization, common
-
-        quantizer = Quantization('fake_yaml.yaml')
-        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
-        quantizer.calib_dataloader = common.DataLoader(dataset)
-        quantizer.eval_dataloader = common.DataLoader(dataset)
-        quantizer.model = self.constant_graph
-        quantizer.fit()
-
-        # resume tuning history
-        quantizer.conf.usr_cfg.tuning.workspace.resume = 'saved/history.snapshot'
-        quantizer.fit()
-
-    def test_run_basic_max_trials(self):
-        from neural_compressor.experimental import Quantization, common
-
-        quantizer = Quantization('fake_yaml2.yaml')
-        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
-        quantizer.calib_dataloader = common.DataLoader(dataset)
-        quantizer.eval_dataloader = common.DataLoader(dataset)
-        quantizer.model = self.constant_graph
-        quantizer.fit()
-
-    def test_run_basic_max_trials_multimetric(self):
-        from neural_compressor.experimental import Quantization, common
-
-        quantizer = Quantization('fake_yaml3.yaml')
-        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
-        quantizer.calib_dataloader = common.DataLoader(dataset)
-        quantizer.eval_dataloader = common.DataLoader(dataset)
-        quantizer.model = self.constant_graph
-        quantizer.fit()
-
-    def test_run_basic_max_trials_multimetric_weight(self):
-        from neural_compressor.experimental import Quantization, common
-
-        quantizer = Quantization('fake_yaml4.yaml')
-        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
-        quantizer.calib_dataloader = common.DataLoader(dataset)
-        quantizer.eval_dataloader = common.DataLoader(dataset)
-        quantizer.model = self.constant_graph
-        quantizer.fit()
-
     def test_run_basic_one_trial_new_api(self):
         from neural_compressor.quantization import fit
@@ -227,9 +59,12 @@ def test_run_basic_one_trial_new_api(self):
         dataset = Datasets("tensorflow")["dummy"](((100, 3, 3, 1)))
         dataloader = DATALOADERS["tensorflow"](dataset)
 
+        def fake_eval(model):
+            return 1
+
         # tuning and accuracy criterion
         conf = PostTrainingQuantConfig()
-        q_model = fit(model=self.constant_graph, conf=conf, calib_dataloader=dataloader, eval_dataloader=dataloader)
+        q_model = fit(model=self.constant_graph, conf=conf, calib_dataloader=dataloader, eval_func=fake_eval)
         self.assertIsNotNone(q_model)
 
     def test_no_tuning(self):
diff --git a/test/strategy/test_basic_1.x.py b/test/strategy/test_basic_1.x.py
new file mode 100644
index 00000000000..89b47ffa722
--- /dev/null
+++ b/test/strategy/test_basic_1.x.py
@@ -0,0 +1,221 @@
+"""Tests for quantization"""
+import numpy as np
+import unittest
+import shutil
+import os
+import yaml
+
+def build_fake_yaml():
+    fake_yaml = '''
+        model:
+          name: fake_yaml
+          framework: tensorflow
+          inputs: x
+          outputs: op2_to_store
+        device: cpu
+        evaluation:
+          accuracy:
+            metric:
+              topk: 1
+        tuning:
+          strategy:
+            name: basic
+          accuracy_criterion:
+            relative: 0.01
+          workspace:
+            path: saved
+        '''
+    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
+    with open('fake_yaml.yaml',"w",encoding="utf-8") as f:
+        yaml.dump(y,f)
+    f.close()
+
+def build_fake_yaml2():
+    fake_yaml = '''
+        model:
+          name: fake_yaml
+          framework: tensorflow
+          inputs: x
+          outputs: op2_to_store
+        device: cpu
+        evaluation:
+          accuracy:
+            metric:
+              topk: 1
+        tuning:
+          strategy:
+            name: basic
+          exit_policy:
+            max_trials: 3
+          accuracy_criterion:
+            relative: -0.01
+          workspace:
+            path: saved
+        '''
+    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
+    with open('fake_yaml2.yaml',"w",encoding="utf-8") as f:
+        yaml.dump(y,f)
+    f.close()
+
+def build_fake_yaml3():
+    fake_yaml = '''
+        model:
+          name: fake_yaml
+          framework: tensorflow
+          inputs: x
+          outputs: op2_to_store
+        device: cpu
+        evaluation:
+          accuracy:
+            multi_metrics:
+              topk: 1
+              MSE:
+                compare_label: False
+        tuning:
+          strategy:
+            name: basic
+          exit_policy:
+            max_trials: 3
+            timeout: 50
+          accuracy_criterion:
+            relative: -0.01
+          workspace:
+            path: saved
+        '''
+    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
+    with open('fake_yaml3.yaml',"w",encoding="utf-8") as f:
+        yaml.dump(y,f)
+    f.close()
+
+def build_fake_yaml4():
+    fake_yaml = '''
+        model:
+          name: fake_yaml
+          framework: tensorflow
+          inputs: x
+          outputs: op2_to_store
+        device: cpu
+        evaluation:
+          accuracy:
+            multi_metrics:
+              topk: 1
+              MSE:
+                compare_label: False
+              weight: [1, 0]
+        tuning:
+          strategy:
+            name: basic
+          exit_policy:
+            max_trials: 3
+            timeout: 50
+          accuracy_criterion:
+            relative: -0.01
+          workspace:
+            path: saved
+        '''
+    y = yaml.load(fake_yaml, Loader=yaml.SafeLoader)
+    with open('fake_yaml4.yaml',"w",encoding="utf-8") as f:
+        yaml.dump(y,f)
+    f.close()
+
+def build_fake_model():
+    import tensorflow as tf
+    try:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID', )
+            last_identity = tf.identity(op2, name='op2_to_store')
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    except:
+        graph = tf.Graph()
+        graph_def = tf.compat.v1.GraphDef()
+        with tf.compat.v1.Session() as sess:
+            x = tf.compat.v1.placeholder(tf.float32, shape=(1,3,3,1), name='x')
+            y = tf.constant(np.random.random((2,2,1,1)).astype(np.float32), name='y')
+            z = tf.constant(np.random.random((1,1,1,1)).astype(np.float32), name='z')
+            op = tf.nn.conv2d(input=x, filters=y, strides=[1,1,1,1], padding='VALID', name='op_to_store')
+            op2 = tf.nn.conv2d(input=op, filters=z, strides=[1,1,1,1], padding='VALID')
+            last_identity = tf.identity(op2, name='op2_to_store')
+
+            sess.run(tf.compat.v1.global_variables_initializer())
+            constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op2_to_store'])
+
+        graph_def.ParseFromString(constant_graph.SerializeToString())
+        with graph.as_default():
+            tf.import_graph_def(graph_def, name='')
+    return graph
+
+class TestBasicTuningStrategy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        self.constant_graph = build_fake_model()
+        build_fake_yaml()
+        build_fake_yaml2()
+        build_fake_yaml3()
+        build_fake_yaml4()
+
+    @classmethod
+    def tearDownClass(self):
+        os.remove('fake_yaml.yaml')
+        os.remove('fake_yaml2.yaml')
+        os.remove('fake_yaml3.yaml')
+        os.remove('fake_yaml4.yaml')
+        shutil.rmtree('saved', ignore_errors=True)
+
+    def test_run_basic_one_trial(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization('fake_yaml.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.model = self.constant_graph
+        quantizer.fit()
+
+        # resume tuning history
+        quantizer.conf.usr_cfg.tuning.workspace.resume = 'saved/history.snapshot'
+        quantizer.fit()
+
+    def test_run_basic_max_trials(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization('fake_yaml2.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.model = self.constant_graph
+        quantizer.fit()
+
+    def test_run_basic_max_trials_multimetric(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization('fake_yaml3.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.model = self.constant_graph
+        quantizer.fit()
+
+    def test_run_basic_max_trials_multimetric_weight(self):
+        from neural_compressor.experimental import Quantization, common
+
+        quantizer = Quantization('fake_yaml4.yaml')
+        dataset = quantizer.dataset('dummy', (100, 3, 3, 1), label=True)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.model = self.constant_graph
+        quantizer.fit()
+
+if __name__ == "__main__":
+    unittest.main()
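For contrast with the restored 1.x tests above, the equivalent flow under the new API (mirroring test_run_basic_one_trial_new_api in test_basic.py; the Datasets/DATALOADERS import path is assumed from the 2.x package layout, and build_fake_model is the helper defined above) needs no YAML file at all:

    from neural_compressor.config import PostTrainingQuantConfig
    from neural_compressor.data import Datasets, DATALOADERS
    from neural_compressor.quantization import fit

    dataset = Datasets("tensorflow")["dummy"]((100, 3, 3, 1))
    dataloader = DATALOADERS["tensorflow"](dataset)

    def fake_eval(model):
        return 1  # constant "accuracy", so the first trial always meets the criterion

    conf = PostTrainingQuantConfig()
    q_model = fit(model=build_fake_model(), conf=conf,
                  calib_dataloader=dataloader, eval_func=fake_eval)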