diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 0a44fe2f5a3..2a5f62af196 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1029,7 +1029,7 @@ def _get_quantizable_ops(self, model):
         # get bf16 capability
-        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
                 (self.version.release >= Version("1.11.0").release):
             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             bf16_ops = []
@@ -1308,19 +1308,34 @@ def _pre_hook_for_qat(self, dataloader=None):
                     qscheme=torch.per_tensor_affine,
                     reduce_range=REDUCE_RANGE),
             weight=torch.quantization.default_weight_fake_quant)
+        self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs)
+        quantizable_ops = []
+        self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self._get_bf16_ops_recursively(self.model._model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         self.model.model.training = True
         torch.quantization.prepare_qat(self.model._model, inplace=True)
-    def _post_hook_for_qat(self):
-        torch.quantization.convert(self.model._model, inplace=True)
         # This is a flag for reloading
         self.model.q_config = {
             'is_oneshot': True,
             'framework': 'pytorch',
             'reduce_range': REDUCE_RANGE,
-            'approach': 'quant_aware_training'
+            'approach': 'quant_aware_training',
+            'bf16_ops_list': bf16_ops_list,
         }
 
+    def _post_hook_for_qat(self):
+        torch.quantization.convert(self.model._model, inplace=True)
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
+
     def _pre_hook_for_hvd(self, dataloader=None):
         # TODO: lazy init here
         hvd.init()
@@ -2220,7 +2235,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 self.model_calibration(q_model, dataloader, iterations, None,
                                        tune_cfg.get('calib_sampling_size', 1))
             q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-            if self.use_bf16:
+            if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+                (self.version.release >= Version("1.11.0").release):
                 with torch.no_grad():
                     with torch.cpu.amp.autocast():
                         q_model = ipex.quantization.convert(q_model)
@@ -2487,7 +2503,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             if isinstance(self.q_dataloader, BaseDataLoader):
                 self.q_dataloader.batch(batch_size)
                 logger.info('Recovery `calibration.dataloader.batchsize` {} according \
-                            to config.yaml'.format(batch_size))
+                            to config.yaml' .format(batch_size))
             del init_model
         with open(self.ipex_config_path, 'r') as f:
             self.cfgs = json.load(f)
@@ -2773,7 +2789,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                                               q_model._model,
                                               prefix='')
 
         if len(self.tune_cfg['bf16_ops_list']) > 0 and \
-            self.version.release >= Version("1.11.0").release and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
@@ -2843,6 +2859,12 @@ def _pre_hook_for_qat(self, dataloader=None):
         quantizable_ops = []
         tmp_model = self.fuse_fx_model(self.model, is_qat=True)
         self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self._get_bf16_ops_recursively(tmp_model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         quantized_ops = OrderedDict()
         for op in quantizable_ops:
             if op[1] in [
@@ -2851,7 +2873,7 @@ def _pre_hook_for_qat(self, dataloader=None):
                 quantized_ops[op[0]] = torch.quantization.default_dynamic_qconfig
             else:
                 quantized_ops[op[0]] = q_cfgs
-        # build for fetching scale and zeropoint
+        # build for fetching scale and zeropoint
         op_config_dict = {}
         for op in quantizable_ops:
             op_config_dict[op] = {'weight': {'dtype': 'int8'}, 'activation': {'dtype': 'uint8'}}
@@ -2901,6 +2923,7 @@ def _pre_hook_for_qat(self, dataloader=None):
             'framework': 'pytorch_fx',
             'reduce_range': REDUCE_RANGE,
             'quantizable_ops': quantizable_ops,
+            'bf16_ops_list': bf16_ops_list,
             'op': op_config_dict,
             'sub_module_list': self.sub_module_list,
             'approach': 'quant_aware_training'
@@ -2926,6 +2949,10 @@ def _post_hook_for_qat(self):
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(self.model._model, self.model.q_config)
 
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
         self._dump_model_op_stats(self.model._model, self.model.q_config, self.approach)
         torch_utils.util.get_embedding_contiguous(self.model._model)
@@ -3102,7 +3129,7 @@ def _dump_model_op_stats(self, model, tune_cfg, approach):
         res = dict()
         self._get_sub_module_op_stats(model, tune_cfg, approach, res)
 
-        if (self.version.release >= Version("1.11.0").release) and \
+        if self.use_bf16 and (self.version.release >= Version("1.11.0").release) and \
            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
             bf16_ops_list = tune_cfg['bf16_ops_list']
             if len(bf16_ops_list) > 0:
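Note: the same three-part gate recurs throughout this patch: BF16 op collection and BF16 conversion only run when the adaptor's `use_bf16` flag is set, the CPU reports native BF16 support via `CpuInfo().bf16` (or `FORCE_BF16=1` overrides the check), and the installed PyTorch release is at least 1.11.0. Below is a minimal standalone sketch of that condition, assuming `packaging.version.Version` and plain booleans standing in for `self.use_bf16` and `CpuInfo().bf16`; the `bf16_enabled` helper is illustrative only and not part of the patch.

import os

from packaging.version import Version


def bf16_enabled(use_bf16: bool, torch_version: str, cpu_has_bf16: bool) -> bool:
    """Illustrative helper mirroring the gate repeated in the patch.

    BF16 paths are attempted only when:
      * BF16 was not disabled by the user/config (`use_bf16`),
      * the CPU supports BF16 natively or FORCE_BF16=1 forces it on,
      * the installed PyTorch release is >= 1.11.0.
    """
    forced = os.getenv('FORCE_BF16') == '1'
    new_enough = Version(torch_version).release >= Version("1.11.0").release
    return use_bf16 and (cpu_has_bf16 or forced) and new_enough


# Disabling use_bf16 wins even on BF16-capable hardware with a new PyTorch,
# which is the behavior the added `self.use_bf16 and ...` checks guarantee.
assert not bf16_enabled(False, "1.13.1", cpu_has_bf16=True)
assert bf16_enabled(True, "1.13.1", cpu_has_bf16=True)
assert not bf16_enabled(True, "1.10.2", cpu_has_bf16=True)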