| 29 | 29 | from ..utils import logger | 
| 30 | 30 | from .query import QueryBackendCapability | 
| 31 | 31 | from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader | 
| 32 |  | -try:  # pragma: no cover | 
| 33 |  | -    import intel_extension_for_pytorch as ipex | 
| 34 |  | -    IPEX = True | 
| 35 |  | -except:  # pragma: no cover | 
| 36 |  | -    IPEX = False | 
| 37 | 32 | 
| 38 | 33 | 
| 39 | 34 | torch = LazyImport("torch") | 
| 40 | 35 | json = LazyImport("json") | 
| 41 | 36 | hvd = LazyImport("horovod.torch") | 
| 42 | 37 | torch_utils = LazyImport("neural_compressor.adaptor.torch_utils") | 
|  | 38 | +ipex = LazyImport("intel_extension_for_pytorch") | 
| 43 | 39 | 
| 44 | 40 | REDUCE_RANGE = False if CpuInfo().vnni else True | 
| 45 | 41 | logger.debug("Reduce range is {}".format(str(REDUCE_RANGE))) | 
| @@ -1033,7 +1029,7 @@ def _get_quantizable_ops(self, model): | 
| 1033 | 1029 | 
| 1034 | 1030 | 
| 1035 | 1031 |         # get bf16 capability | 
| 1036 |  | -        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \ | 
|  | 1032 | +        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \ | 
| 1037 | 1033 |             (self.version.release >= Version("1.11.0").release): | 
| 1038 | 1034 |             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16") | 
| 1039 | 1035 |             bf16_ops = [] | 
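This hunk drops the `self.use_bf16` guard, so the bf16 op capability is now collected whenever the hardware reports BF16 support (or `FORCE_BF16=1` is set in the environment) and the installed torch release is at least 1.11. A rough standalone restatement of that gate is below; `has_bf16_isa` is a placeholder argument standing in for `CpuInfo().bf16`, and the version check assumes the `packaging` library used elsewhere in the diff.

```python
# Standalone illustration of the bf16 capability gate after this change.
import os

import torch
from packaging.version import Version


def bf16_capability_enabled(has_bf16_isa: bool) -> bool:
    # BF16 kernels need either hardware BF16 support or the FORCE_BF16=1
    # override, plus a torch release that ships bf16 op coverage (>= 1.11).
    torch_ok = Version(torch.__version__).release >= Version("1.11.0").release
    return (has_bf16_isa or os.getenv("FORCE_BF16") == "1") and torch_ok
```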
| @@ -2148,8 +2144,6 @@ class PyTorch_IPEXAdaptor(TemplateAdaptor):  # pragma: no cover | 
| 2148 | 2144 |     """ | 
| 2149 | 2145 |     def __init__(self, framework_specific_info): | 
| 2150 | 2146 |         super(PyTorch_IPEXAdaptor, self).__init__(framework_specific_info) | 
| 2151 |  | - | 
| 2152 |  | -        assert IPEX, "Please install intel-extension-for-pytorch." | 
| 2153 | 2147 |         self.version = get_torch_version() | 
| 2154 | 2148 |         query_config_file = "pytorch_ipex.yaml" | 
| 2155 | 2149 |         self.query_handler = PyTorchQuery( | 
| @@ -2226,20 +2220,30 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): | 
| 2226 | 2220 |                     self.model_calibration(q_model, dataloader, iterations, None, | 
| 2227 | 2221 |                                            tune_cfg.get('calib_sampling_size', 1)) | 
| 2228 | 2222 |                 q_model.save_qconf_summary(qconf_summary=self.ipex_config_path) | 
| 2229 |  | -                q_model = ipex.quantization.convert(q_model) | 
| 2230 |  | -            with torch.no_grad(): | 
| 2231 |  | -                try: | 
| 2232 |  | -                    q_model = torch.jit.trace(q_model, example_inputs) | 
| 2233 |  | -                    q_model = torch.jit.freeze(q_model.eval()) | 
| 2234 |  | -                except: | 
| 2235 |  | -                    q_model = torch.jit.trace(q_model, example_inputs, strict=False) | 
| 2236 |  | -                    q_model = torch.jit.freeze(q_model.eval()) | 
|  | 2223 | +                if self.use_bf16: | 
|  | 2224 | +                    with torch.no_grad(): | 
|  | 2225 | +                        with torch.cpu.amp.autocast(): | 
|  | 2226 | +                            q_model = ipex.quantization.convert(q_model) | 
|  | 2227 | +                            try: | 
|  | 2228 | +                                q_model = torch.jit.trace(q_model, example_inputs) | 
|  | 2229 | +                                q_model = torch.jit.freeze(q_model.eval()) | 
|  | 2230 | +                            except: | 
|  | 2231 | +                                q_model = torch.jit.trace(q_model, example_inputs, strict=False) | 
|  | 2232 | +                                q_model = torch.jit.freeze(q_model.eval()) | 
|  | 2233 | +                else: | 
|  | 2234 | +                    with torch.no_grad(): | 
|  | 2235 | +                        try: | 
|  | 2236 | +                            q_model = torch.jit.trace(q_model, example_inputs) | 
|  | 2237 | +                            q_model = torch.jit.freeze(q_model.eval()) | 
|  | 2238 | +                        except: | 
|  | 2239 | +                            q_model = torch.jit.trace(q_model, example_inputs, strict=False) | 
|  | 2240 | +                            q_model = torch.jit.freeze(q_model.eval()) | 
| 2237 | 2241 |                 # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile | 
| 2238 | 2242 |                 # At the 2nd run, the llga pass will be triggered and the model is turned into | 
| 2239 | 2243 |                 # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph | 
| 2240 | 2244 |                 q_model(*example_inputs) | 
| 2241 | 2245 |                 q_model(*example_inputs) | 
| 2242 |  | -             | 
|  | 2246 | + | 
| 2243 | 2247 |         assert self.approach != 'quant_aware_training', \ | 
| 2244 | 2248 |                 "Intel PyTorch Extension didn't support quantization aware training mode" | 
| 2245 | 2249 |         model_._model = q_model | 
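The main change in `quantize()` is that, when `self.use_bf16` is set, IPEX conversion and TorchScript tracing now run inside `torch.cpu.amp.autocast()`, so ops that remain in fp32 after int8 conversion are captured as bf16 in the frozen graph. The sketch below reproduces that flow end to end on a toy model; it assumes the `ipex.quantization.prepare`/`convert` static-quantization API (IPEX >= 1.12), and the adaptor's qconf-summary handling, calibration configuration, and model wrapper are all omitted.

```python
# Hedged, standalone sketch of the int8 + bf16 flow this hunk implements.
import torch
import intel_extension_for_pytorch as ipex


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


model = TinyNet().eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Static int8 quantization: insert observers, then calibrate.
qconfig = ipex.quantization.default_static_qconfig
prepared = ipex.quantization.prepare(
    model, qconfig, example_inputs=example_inputs, inplace=False)

with torch.no_grad():
    for _ in range(4):  # a few representative calibration batches
        prepared(*example_inputs)

use_bf16 = True  # mirrors self.use_bf16 in the adaptor
with torch.no_grad():
    if use_bf16:
        # Convert and trace under bf16 autocast so ops left in fp32 after
        # int8 conversion end up as bf16 in the frozen graph.
        with torch.cpu.amp.autocast():
            q_model = ipex.quantization.convert(prepared)
            q_model = torch.jit.trace(q_model, example_inputs)
            q_model = torch.jit.freeze(q_model.eval())
    else:
        q_model = ipex.quantization.convert(prepared)
        q_model = torch.jit.trace(q_model, example_inputs)
        q_model = torch.jit.freeze(q_model.eval())

    # Two warm-up runs, as in the diff: the first populates the profiling
    # graph executor, the second triggers the oneDNN Graph (LLGA) fusion
    # pass that produces the fused int8 kernels.
    q_model(*example_inputs)
    q_model(*example_inputs)
```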