
Commit 162611b

yiliu30 committed

fixed some bugs
Signed-off-by: yiliu30 <[email protected]>
1 parent 83018ef commit 162611b

File tree

  • neural_compressor/adaptor/torch_utils/util.py

1 file changed: +26 −3 lines changed

neural_compressor/adaptor/torch_utils/util.py

Lines changed: 26 additions & 3 deletions
@@ -609,7 +609,10 @@ def output_hook(self, input, output):
     from torch.quantization.quantize_fx import prepare_fx,convert_fx
     # do quantization
     if adaptor.sub_module_list is None:
-        tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
+        if _torch_version_greater_than("1.13.0"):
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inputs=None)
+        else:
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
     else:
         PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \
                                             tmp_model, prefix='')
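
Note: torch 1.13 changed the FX graph-mode quantization API so that
prepare_fx takes an example_inputs argument, which is what the version gate
above accommodates. A minimal sketch of the new-style call, assuming
torch >= 1.13 (ToyModel and the qconfig mapping below are illustrative, not
part of this commit):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # Toy module standing in for the real model being quantized.
    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    model = ToyModel().eval()
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_inputs = (torch.randn(1, 4),)   # required on torch >= 1.13
    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)               # calibration pass
    quantized = convert_fx(prepared)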
@@ -658,7 +661,10 @@ def output_hook(self, input, output):
     from torch.quantization.quantize_fx import prepare_fx,convert_fx
     # do quantization
     if adaptor.sub_module_list is None:
-        tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
+        if _torch_version_greater_than("1.13.0"):
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inputs=None)
+        else:
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
     else:
         PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \
                                             tmp_model, prefix='')
@@ -722,7 +728,10 @@ def output_hook(self, input, output):
     from torch.quantization.quantize_fx import prepare_fx,convert_fx
     # do quantization
     if adaptor.sub_module_list is None:
-        tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
+        if _torch_version_greater_than("1.13.0"):
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inputs=None)
+        else:
+            tmp_model = prepare_fx(tmp_model, fx_op_cfgs,)
     else:
         PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, \
                                             tmp_model, prefix='')
@@ -750,3 +759,17 @@ def output_hook(self, input, output):
     ordered_ops = sorted(fallback_order.keys(), key=lambda key: fallback_order[key], \
                          reverse=False)
     return ordered_ops
+
+def get_torch_version():
+    from packaging.version import Version
+    try:
+        torch_version = torch.__version__.split('+')[0]
+    except ValueError as e:  # pragma: no cover
+        assert False, 'Got an unknown version of torch: {}'.format(e)
+    version = Version(torch_version)
+    return version
+
+def _torch_version_greater_than(version_number : str):
+    from packaging.version import Version
+    torch_version = get_torch_version()
+    return torch_version.release >= Version(version_number).release  # pragma: no cover
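
Note on the new helpers: Version(...).release keeps only the numeric release
segment, so a build string like "1.13.0+cu117" compares as (1, 13, 0) once
the local "+cu117" suffix is split off, and the comparison is inclusive
(>=), meaning torch 1.13.0 itself takes the example_inputs branch despite
the "greater_than" name. A quick illustrative sketch of those semantics
(not part of the commit):

    from packaging.version import Version

    print(Version("1.13.0+cu117".split('+')[0]).release)           # (1, 13, 0)
    print(Version("1.13.0").release >= Version("1.13.0").release)  # True
    print(Version("1.12.1").release >= Version("1.13.0").release)  # False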
