47 changes: 47 additions & 0 deletions neural_compressor/utils/pytorch.py
@@ -464,3 +464,50 @@ def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **
assert len(mismatch_log.unexpected_keys) == 0, "Loading state_dict failed: {}".format(mismatch_log)
util.get_embedding_contiguous(model)
return model


def recover_model_from_json(model, json_file_path, example_inputs):
    """Recover an IPEX quantized model from a JSON configuration file.

    Args:
        model (object): FP32 model to be quantized.
        json_file_path (str): path to the IPEX quantization configuration JSON file.
        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the IPEX quantization functions.

    Returns:
        (object): quantized model.
    """
    from ..utils.utility import LazyImport

    # Import IPEX lazily so the dependency is only required when this API is used.
    ipex = LazyImport("intel_extension_for_pytorch")
    from torch.ao.quantization.observer import MinMaxObserver

    # SmoothQuant recipe: alpha=0.5 with a MinMaxObserver for activations.
    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
    if isinstance(example_inputs, dict):
        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
    else:
        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)
    # Restore the tuned scales and zero points recorded in the JSON file, then convert.
    model.load_qconf_summary(qconf_summary=json_file_path)
    model = ipex.quantization.convert(model, inplace=True)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        except Exception:
            # Strict tracing may fail, e.g. for models that return dicts;
            # retry in non-strict mode without trace checking.
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
    # Run the traced model twice so TorchScript can complete its profiling
    # and operator-fusion passes before the model is returned.
    if isinstance(example_inputs, dict):
        model(**example_inputs)
        model(**example_inputs)
    elif isinstance(example_inputs, (tuple, list)):
        # Tuples/lists hold positional arguments and must be unpacked.
        model(*example_inputs)
        model(*example_inputs)
    else:
        model(example_inputs)
        model(example_inputs)
    return model
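
For reference, a minimal usage sketch of the new API, mirroring the test below; the toy model is an illustrative assumption, and the JSON path assumes a configuration previously produced by q_model.save("saved") as in the test:

import torch
from neural_compressor.utils.pytorch import recover_model_from_json

# FP32 model and matching example inputs (both illustrative).
fp32_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
example_inputs = torch.randn(1, 4)

# Rebuild the IPEX SmoothQuant model from the saved configuration.
int8_model = recover_model_from_json(fp32_model, "./saved/best_configure.json", example_inputs)
output = int8_model(example_inputs)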
20 changes: 20 additions & 0 deletions test/algorithm/test_smooth_quant.py
@@ -880,6 +880,26 @@ def calib_func(model):
calib_func=calib_func,
)
q_model.save("saved")
# test recover_model_from_json
from neural_compressor.utils.pytorch import recover_model_from_json

tmp_model = copy.deepcopy(fp32_model)

ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=input_ids)
inc_output = q_model.model(input_ids)
ipex_output = ipex_model(input_ids)
self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))

example_tuple = (input_ids,)
ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_tuple)
ipex_output = ipex_model(input_ids)
self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))

example_dict = {"x": input_ids}
ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_dict)
ipex_output = ipex_model(input_ids)
self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))

# compare ipex and inc quantization
with open("saved/best_configure.json", "r") as f:
inc_config_json = json.load(f)