@@ -1029,7 +1029,7 @@ def _get_quantizable_ops(self, model):
 
 
         # get bf16 capability
-        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
             (self.version.release >= Version("1.11.0").release):
             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             bf16_ops = []
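The same bf16 gate recurs in every hunk below. As a minimal sketch (the helper name `_bf16_enabled` and its arguments are invented for illustration, not part of the patch), the condition combines the user-facing `use_bf16` switch, hardware capability (or the `FORCE_BF16=1` override), and a torch >= 1.11.0 requirement:

import os
from packaging.version import Version

def _bf16_enabled(use_bf16, torch_version, cpu_has_bf16):
    # Mirrors the patched condition: the user must opt in via use_bf16,
    # the CPU must report bf16 support unless FORCE_BF16=1 overrides it,
    # and the installed torch must be at least 1.11.0.
    forced = os.getenv('FORCE_BF16') == '1'
    return use_bf16 and (cpu_has_bf16 or forced) and \
        Version(torch_version).release >= Version("1.11.0").release

print(_bf16_enabled(True, "1.12.1", cpu_has_bf16=True))   # True
print(_bf16_enabled(False, "1.12.1", cpu_has_bf16=True))  # False: user did not opt in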
@@ -1308,19 +1308,34 @@ def _pre_hook_for_qat(self, dataloader=None):
                             qscheme=torch.per_tensor_affine,
                             reduce_range=REDUCE_RANGE),
             weight=torch.quantization.default_weight_fake_quant)
+        self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs)
+        quantizable_ops = []
+        self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self._get_bf16_ops_recursively(self.model._model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         self.model.model.training = True
         torch.quantization.prepare_qat(self.model._model, inplace=True)
 
-    def _post_hook_for_qat(self):
-        torch.quantization.convert(self.model._model, inplace=True)
         # This is a flag for reloading
         self.model.q_config = {
             'is_oneshot': True,
             'framework': 'pytorch',
             'reduce_range': REDUCE_RANGE,
-            'approach': 'quant_aware_training'
+            'approach': 'quant_aware_training',
+            'bf16_ops_list': bf16_ops_list,
         }
 
+    def _post_hook_for_qat(self):
+        torch.quantization.convert(self.model._model, inplace=True)
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
+
     def _pre_hook_for_hvd(self, dataloader=None):
         # TODO: lazy init here
         hvd.init()
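The pre-hook now computes `bf16_ops_list` up front and stashes it in `q_config`, so the post-hook can lower those ops after int8 conversion. A toy, self-contained illustration of the filter (the op tuples below are made up; only the comprehension mirrors the diff):

# Ops reported as bf16-capable are kept only when they are not already
# covered by int8 quantization.
quantizable_ops = [('conv1', 'Conv2d'), ('fc', 'Linear')]
bf16_ops = [('conv1', 'Conv2d'), ('gelu', 'GELU'), ('norm', 'LayerNorm')]

bf16_ops_list = [op for op in bf16_ops if op not in quantizable_ops]
print(bf16_ops_list)  # [('gelu', 'GELU'), ('norm', 'LayerNorm')]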
@@ -2220,7 +2235,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 self.model_calibration(q_model, dataloader, iterations, None,
                                        tune_cfg.get('calib_sampling_size', 1))
             q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-            if self.use_bf16:
+            if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+                (self.version.release >= Version("1.11.0").release):
                 with torch.no_grad():
                     with torch.cpu.amp.autocast():
                         q_model = ipex.quantization.convert(q_model)
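For the IPEX path, the stricter gate protects a conversion that runs under CPU autocast. A rough sketch of that pattern with a plain fp32 module standing in for the calibrated model (assumes torch >= 1.11 on CPU; `ipex.quantization.convert` itself is not reproduced here):

import torch

model = torch.nn.Linear(4, 4)           # placeholder for the calibrated q_model
with torch.no_grad():
    with torch.cpu.amp.autocast():      # eligible fp32 ops run in bf16
        out = model(torch.randn(1, 4))
print(out.dtype)  # torch.bfloat16 on builds where CPU autocast lowers Linear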
@@ -2487,7 +2503,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                 if isinstance(self.q_dataloader, BaseDataLoader):
                     self.q_dataloader.batch(batch_size)
                     logger.info('Recovery `calibration.dataloader.batchsize` {} according \
-                        to config.yaml'.format(batch_size))
+                        to config.yaml'.format(batch_size))
             del init_model
         with open(self.ipex_config_path, 'r') as f:
             self.cfgs = json.load(f)
@@ -2773,7 +2789,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                                                  q_model._model, prefix='')
 
         if len(self.tune_cfg['bf16_ops_list']) > 0 and \
-            self.version.release >= Version("1.11.0").release and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
             q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
 
@@ -2843,6 +2859,12 @@ def _pre_hook_for_qat(self, dataloader=None):
         quantizable_ops = []
         tmp_model = self.fuse_fx_model(self.model, is_qat=True)
         self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
+        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
+        bf16_ops = []
+        if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self._get_bf16_ops_recursively(tmp_model, '', bf16_ops)
+        bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         quantized_ops = OrderedDict()
         for op in quantizable_ops:
             if op[1] in [
@@ -2851,7 +2873,7 @@ def _pre_hook_for_qat(self, dataloader=None):
                 quantized_ops[op[0]] = torch.quantization.default_dynamic_qconfig
             else:
                 quantized_ops[op[0]] = q_cfgs
-        # build for fetching scale and zeropoint
+        # build for fetching scale and zeropoint
         op_config_dict = {}
         for op in quantizable_ops:
             op_config_dict[op] = {'weight': {'dtype': 'int8'}, 'activation': {'dtype': 'uint8'}}
@@ -2901,6 +2923,7 @@ def _pre_hook_for_qat(self, dataloader=None):
             'framework': 'pytorch_fx',
             'reduce_range': REDUCE_RANGE,
             'quantizable_ops': quantizable_ops,
+            'bf16_ops_list': bf16_ops_list,
             'op': op_config_dict,
             'sub_module_list': self.sub_module_list,
             'approach': 'quant_aware_training'
@@ -2926,6 +2949,10 @@ def _post_hook_for_qat(self):
 
         if self.approach != 'post_training_dynamic_quant':
             self._get_scale_zeropoint(self.model._model, self.model.q_config)
+        if len(self.model.q_config['bf16_ops_list']) > 0 and \
+            self.version.release >= Version("1.11.0").release and self.use_bf16 and \
+            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.model._model = torch_utils.bf16_convert.Convert(self.model._model, self.model.q_config)
         self._dump_model_op_stats(self.model._model, self.model.q_config, self.approach)
         torch_utils.util.get_embedding_contiguous(self.model._model)
 
@@ -3102,7 +3129,7 @@ def _dump_model_op_stats(self, model, tune_cfg, approach):
         res = dict()
         self._get_sub_module_op_stats(model, tune_cfg, approach, res)
 
-        if (self.version.release >= Version("1.11.0").release) and \
+        if self.use_bf16 and (self.version.release >= Version("1.11.0").release) and \
             (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
             bf16_ops_list = tune_cfg['bf16_ops_list']
             if len(bf16_ops_list) > 0:
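With the extra `self.use_bf16` guard, `_dump_model_op_stats` only consults `tune_cfg['bf16_ops_list']` when bf16 was actually requested and available. A toy sketch of the kind of per-op-type counting such a stats dump implies (the tune_cfg content and the counting are invented for illustration, not the adaptor's actual implementation):

from collections import Counter

tune_cfg = {'bf16_ops_list': [('gelu', 'GELU'), ('norm', 'LayerNorm'), ('norm2', 'LayerNorm')]}
bf16_counts = Counter(op_type for _, op_type in tune_cfg['bf16_ops_list'])
print(dict(bf16_counts))  # {'GELU': 1, 'LayerNorm': 2}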