diff --git a/neural_compressor/utils/pytorch.py b/neural_compressor/utils/pytorch.py
index cda525f1cc6..f2267b1ca1a 100644
--- a/neural_compressor/utils/pytorch.py
+++ b/neural_compressor/utils/pytorch.py
@@ -464,3 +464,50 @@ def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **kwargs):
     assert len(mismatch_log.unexpected_keys) == 0, "Loading state_dict failed: {}".format(mismatch_log)
     util.get_embedding_contiguous(model)
     return model
+
+
+def recover_model_from_json(model, json_file_path, example_inputs):
+    """Recover an IPEX quantized model from a JSON configuration file.
+
+    Args:
+        model (object): fp32 model that needs quantization.
+        json_file_path (str): path to the IPEX configuration JSON file.
+        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function.
+
+    Returns:
+        (object): quantized model.
+    """
+    from ..utils.utility import LazyImport
+
+    ipex = LazyImport("intel_extension_for_pytorch")
+    from torch.ao.quantization.observer import MinMaxObserver
+
+    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
+    if isinstance(example_inputs, dict):
+        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
+    else:
+        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)
+    model.load_qconf_summary(qconf_summary=json_file_path)
+    model = ipex.quantization.convert(model, inplace=True)
+    with torch.no_grad():
+        try:
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
+            else:
+                model = torch.jit.trace(model, example_inputs)
+            model = torch.jit.freeze(model.eval())
+        except Exception:  # fall back to a non-strict trace if strict tracing fails
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
+            else:
+                model = torch.jit.trace(model, example_inputs, strict=False)
+            model = torch.jit.freeze(model.eval())
+        if isinstance(example_inputs, dict):
+            model(**example_inputs)  # run twice to trigger JIT fusion passes
+            model(**example_inputs)
+        else:
+            model(example_inputs)  # run twice to trigger JIT fusion passes
+            model(example_inputs)
+    return model
diff --git a/test/algorithm/test_smooth_quant.py b/test/algorithm/test_smooth_quant.py
index a6c99e66a69..45ef03af420 100644
--- a/test/algorithm/test_smooth_quant.py
+++ b/test/algorithm/test_smooth_quant.py
@@ -880,6 +880,26 @@ def calib_func(model):
             calib_func=calib_func,
         )
         q_model.save("saved")
+        # test recover_model_from_json
+        from neural_compressor.utils.pytorch import recover_model_from_json
+
+        tmp_model = copy.deepcopy(fp32_model)
+
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=input_ids)
+        inc_output = q_model.model(input_ids)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
+        example_tuple = (input_ids,)
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_tuple)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
+        example_dict = {"x": input_ids}
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_dict)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
         # compare ipex and inc quantization
         with open("saved/best_configure.json", "r") as f:
             inc_config_json = json.load(f)
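
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new recover_model_from_json helper is meant to be called after a SmoothQuant run has saved its configuration with q_model.save("saved"), mirroring the test above. TinyModel and example_input are hypothetical placeholders, and the "./saved/best_configure.json" path is assumed to exist from that prior save.

import copy

import torch

from neural_compressor.utils.pytorch import recover_model_from_json


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for the fp32 model used in the test."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


fp32_model = TinyModel().eval()
example_input = torch.randn(1, 4)

# Recover the quantized IPEX model from the saved JSON configuration.
# A deepcopy keeps the original fp32 model intact, as in the test.
ipex_model = recover_model_from_json(
    copy.deepcopy(fp32_model), "./saved/best_configure.json", example_inputs=example_input
)
with torch.no_grad():
    output = ipex_model(example_input)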