@@ -464,3 +464,50 @@ def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **
464464 assert len (mismatch_log .unexpected_keys ) == 0 , "Loading state_dict failed: {}" .format (mismatch_log )
465465 util .get_embedding_contiguous (model )
466466 return model
467+
468+
def recover_model_from_json(model, json_file_path, example_inputs):
    """Recover an ipex quantized model from a saved JSON configuration.

    Prepares the fp32 model with a smooth-quant qconfig, loads the tuned
    qconf summary from ``json_file_path``, converts the model, then traces
    and freezes it with TorchScript. If the strict trace fails, falls back
    to a non-strict trace (``strict=False``).

    Args:
        model (object): fp32 model need to do quantization.
        json_file_path (json): configuration JSON file for ipex.
        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the ipex function.

    Returns:
        (object): quantized model
    """
    from ..utils.utility import LazyImport

    ipex = LazyImport("intel_extension_for_pytorch")
    from torch.ao.quantization.observer import MinMaxObserver

    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
    # Dict example inputs must be passed as keyword inputs to prepare/trace.
    if isinstance(example_inputs, dict):
        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
    else:
        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)
    model.load_qconf_summary(qconf_summary=json_file_path)
    model = ipex.quantization.convert(model, inplace=True)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        # Narrowed from a bare `except:` which would also swallow
        # KeyboardInterrupt/SystemExit; tracing failures raise Exception subclasses.
        except Exception:
            # Strict trace failed (e.g. dynamic control flow or dict outputs);
            # retry without strictness so tracing can still succeed.
            if isinstance(example_inputs, dict):
                # pylint: disable=E1120,E1123
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
        # Run the traced model twice: the first calls trigger TorchScript
        # profiling/fusion passes so the returned model is fully optimized.
        if isinstance(example_inputs, dict):
            model(**example_inputs)
            model(**example_inputs)
        else:
            model(example_inputs)
            model(example_inputs)
    return model
0 commit comments