Commit ad6978a

PenghuiCheng authored and yiliu30 committed
Fixed UT error for bf16 op list for QAT mode (#200)
* Fixed UT error for bf16 op list for QAT mode

Signed-off-by: Cheng, Penghui <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
1 parent 64fa8aa commit ad6978a

File tree

1 file changed: +36 -9 lines changed


neural_compressor/adaptor/pytorch.py

Lines changed: 36 additions & 9 deletions
@@ -1029,7 +1029,7 @@ def _get_quantizable_ops(self, model):
 
 
         # get bf16 capability
-        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
             (self.version.release >= Version("1.11.0").release):
             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             bf16_ops = []
@@ -1308,19 +1308,34 @@ def _pre_hook_for_qat(self, dataloader=None):
                             qscheme=torch.per_tensor_affine,
                             reduce_range=REDUCE_RANGE),
             weight=torch.quantization.default_weight_fake_quant)
+        self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs)
+        quantizable_ops = []
+        self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
+            self._get_bf16_ops_recursively(self.model._model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         self.model.model.training = True
         torch.quantization.prepare_qat(self.model._model, inplace=True)
 
-    def _post_hook_for_qat(self):
-        torch.quantization.convert(self.model._model, inplace=True)
         # This is a flag for reloading
         self.model.q_config = {
             'is_oneshot': True,
             'framework': 'pytorch',
             'reduce_range': REDUCE_RANGE,
-            'approach': 'quant_aware_training'
+            'approach': 'quant_aware_training',
+            'bf16_ops_list': bf16_ops_list,
         }
 
+    def _post_hook_for_qat(self):
+        torch.quantization.convert(self.model._model, inplace=True)
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
+
     def _pre_hook_for_hvd(self, dataloader=None):
         # TODO: lazy init here
         hvd.init()
@@ -2220,7 +2235,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 self.model_calibration(q_model, dataloader, iterations, None,
                                        tune_cfg.get('calib_sampling_size', 1))
             q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-            if self.use_bf16:
+            if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+                (self.version.release >= Version("1.11.0").release):
                 with torch.no_grad():
                     with torch.cpu.amp.autocast():
                         q_model = ipex.quantization.convert(q_model)
@@ -2487,7 +2503,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             if isinstance(self.q_dataloader, BaseDataLoader):
                 self.q_dataloader.batch(batch_size)
                 logger.info('Recovery `calibration.dataloader.batchsize` {} according \
-                    to config.yaml'.format(batch_size))
+                    to config.yaml' .format(batch_size))
             del init_model
         with open(self.ipex_config_path, 'r') as f:
             self.cfgs = json.load(f)
@@ -2773,7 +2789,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                                     q_model._model, prefix='')
 
         if len(self.tune_cfg['bf16_ops_list']) > 0 and \
-            self.version.release >= Version("1.11.0").release and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
             q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
 
@@ -2843,6 +2859,12 @@ def _pre_hook_for_qat(self, dataloader=None):
         quantizable_ops = []
         tmp_model = self.fuse_fx_model(self.model, is_qat=True)
         self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
+            self._get_bf16_ops_recursively(tmp_model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         quantized_ops = OrderedDict()
         for op in quantizable_ops:
             if op[1] in [
@@ -2851,7 +2873,7 @@ def _pre_hook_for_qat(self, dataloader=None):
                 quantized_ops[op[0]] = torch.quantization.default_dynamic_qconfig
             else:
                 quantized_ops[op[0]] = q_cfgs
-        # build for fetching scale and zeropoint
+        # build for fetching scale and zeropoint
         op_config_dict = {}
         for op in quantizable_ops:
             op_config_dict[op] = {'weight': {'dtype': 'int8'}, 'activation': {'dtype': 'uint8'}}
@@ -2901,6 +2923,7 @@ def _pre_hook_for_qat(self, dataloader=None):
             'framework': 'pytorch_fx',
             'reduce_range': REDUCE_RANGE,
             'quantizable_ops': quantizable_ops,
+            'bf16_ops_list': bf16_ops_list,
             'op': op_config_dict,
             'sub_module_list': self.sub_module_list,
             'approach': 'quant_aware_training'
@@ -2926,6 +2949,10 @@ def _post_hook_for_qat(self):
 
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(self.model._model, self.model.q_config)
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
         self._dump_model_op_stats(self.model._model, self.model.q_config, self.approach)
         torch_utils.util.get_embedding_contiguous(self.model._model)
 
@@ -3102,7 +3129,7 @@ def _dump_model_op_stats(self, model, tune_cfg, approach):
         res = dict()
         self._get_sub_module_op_stats(model, tune_cfg, approach, res)
 
-        if (self.version.release >= Version("1.11.0").release) and \
+        if self.use_bf16 and (self.version.release >= Version("1.11.0").release) and \
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'): # pragma: no cover
            bf16_ops_list = tune_cfg['bf16_ops_list']
            if len(bf16_ops_list) > 0:
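
Every hunk in this commit guards the bf16 path with the same three-part condition: the user enabled bf16 (self.use_bf16), the CPU reports bf16 support or FORCE_BF16=1 is set, and the installed torch is at least 1.11.0; the QAT pre-hooks then record the bf16-only ops in q_config['bf16_ops_list'] so the post-hooks can run torch_utils.bf16_convert.Convert. Below is a minimal, hypothetical sketch of that gate and of the op-list filtering, not code from the repository: the helper names bf16_eligible and collect_bf16_ops_list are invented for illustration, and the CpuInfo import path is an assumption.

import os
import torch
from packaging.version import Version
from neural_compressor.utils.utility import CpuInfo  # assumed import path for the CpuInfo helper used in the diff

def bf16_eligible(use_bf16: bool) -> bool:
    # Mirrors the gate this commit adds in front of every bf16 code path:
    # user opt-in, hardware support (or FORCE_BF16=1), and torch >= 1.11.0.
    torch_version = Version(torch.__version__.split('+')[0])
    return use_bf16 and \
        (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
        torch_version.release >= Version("1.11.0").release

def collect_bf16_ops_list(bf16_ops, quantizable_ops):
    # The QAT pre-hooks keep only ops that support bf16 but are not
    # int8-quantizable; the post-hooks read this list back from q_config
    # and hand it to torch_utils.bf16_convert.Convert.
    return [op for op in bf16_ops if op not in quantizable_ops]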
