@@ -609,7 +609,10 @@ def output_hook(self, input, output):
609609 from torch .quantization .quantize_fx import prepare_fx ,convert_fx
610610 # do quantization
611611 if adaptor .sub_module_list is None :
612- tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
612+ if _torch_version_greater_than ("1.13.0" ):
613+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs , example_inputs = None )
614+ else :
615+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
613616 else :
614617 PyTorch_FXAdaptor .prepare_sub_graph (adaptor .sub_module_list , fx_op_cfgs , \
615618 tmp_model , prefix = '' )
@@ -658,7 +661,10 @@ def output_hook(self, input, output):
658661 from torch .quantization .quantize_fx import prepare_fx ,convert_fx
659662 # do quantization
660663 if adaptor .sub_module_list is None :
661- tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
664+ if _torch_version_greater_than ("1.13.0" ):
665+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs , example_inputs = None )
666+ else :
667+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
662668 else :
663669 PyTorch_FXAdaptor .prepare_sub_graph (adaptor .sub_module_list , fx_op_cfgs , \
664670 tmp_model , prefix = '' )
@@ -722,7 +728,10 @@ def output_hook(self, input, output):
722728 from torch .quantization .quantize_fx import prepare_fx ,convert_fx
723729 # do quantization
724730 if adaptor .sub_module_list is None :
725- tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
731+ if _torch_version_greater_than ("1.13.0" ):
732+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,example_inputs = None )
733+ else :
734+ tmp_model = prepare_fx (tmp_model , fx_op_cfgs ,)
726735 else :
727736 PyTorch_FXAdaptor .prepare_sub_graph (adaptor .sub_module_list , fx_op_cfgs , \
728737 tmp_model , prefix = '' )
@@ -750,3 +759,17 @@ def output_hook(self, input, output):
750759 ordered_ops = sorted (fallback_order .keys (), key = lambda key : fallback_order [key ], \
751760 reverse = False )
752761 return ordered_ops
762+
763+ def get_torch_version ():
764+ from packaging .version import Version
765+ try :
766+ torch_version = torch .__version__ .split ('+' )[0 ]
767+ except ValueError as e : # pragma: no cover
768+ assert False , 'Got an unknown version of torch: {}' .format (e )
769+ version = Version (torch_version )
770+ return version
771+
772+ def _torch_version_greater_than (version_number : str ):
773+ from packaging .version import Version
774+ torch_version = get_torch_version ()
775+ return torch_version .release >= Version (version_number ).release # pragma: no cover
0 commit comments