Commit 2a361b8

add torch.amp bf16 support for ipex backend (#1497)
1 parent 773bb3c commit 2a361b8

File tree

3 files changed: 37 additions & 17 deletions

neural_compressor/adaptor/pytorch.py

Lines changed: 21 additions & 17 deletions
```diff
@@ -29,17 +29,13 @@
 from ..utils import logger
 from .query import QueryBackendCapability
 from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader
-try: # pragma: no cover
-    import intel_extension_for_pytorch as ipex
-    IPEX = True
-except: # pragma: no cover
-    IPEX = False


 torch = LazyImport("torch")
 json = LazyImport("json")
 hvd = LazyImport("horovod.torch")
 torch_utils = LazyImport("neural_compressor.adaptor.torch_utils")
+ipex = LazyImport("intel_extension_for_pytorch")

 REDUCE_RANGE = False if CpuInfo().vnni else True
 logger.debug("Reduce range is {}".format(str(REDUCE_RANGE)))
```
```diff
@@ -1033,7 +1029,7 @@ def _get_quantizable_ops(self, model):


         # get bf16 capability
-        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
                 (self.version.release >= Version("1.11.0").release):
             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             bf16_ops = []
```
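With `self.use_bf16` dropped from the condition, bf16 capability is now reported whenever the hardware and torch version allow it; whether bf16 is actually applied is deferred to the IPEX `quantize` path further down. A self-contained restatement of the gate, assuming `packaging` for version parsing (the function name and arguments are illustrative):

```python
import os

from packaging.version import Version


def bf16_capability_enabled(torch_version, cpu_has_bf16):
    """Restates the gate above: bf16 op types are queried when the CPU
    supports bf16 (or FORCE_BF16=1 overrides the hardware check) and
    torch is at least 1.11.0."""
    forced = os.getenv('FORCE_BF16') == '1'
    new_enough = Version(torch_version).release >= Version("1.11.0").release
    return (cpu_has_bf16 or forced) and new_enough


# On a machine without native bf16 support:
#   bf16_capability_enabled("1.12.1", cpu_has_bf16=False)  -> False
#   ...and True after exporting FORCE_BF16=1.
```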
```diff
@@ -2148,8 +2144,6 @@ class PyTorch_IPEXAdaptor(TemplateAdaptor): # pragma: no cover
     """
     def __init__(self, framework_specific_info):
         super(PyTorch_IPEXAdaptor, self).__init__(framework_specific_info)
-
-        assert IPEX, "Please install intel-extension-for-pytorch."
         self.version = get_torch_version()
         query_config_file = "pytorch_ipex.yaml"
         self.query_handler = PyTorchQuery(
```
```diff
@@ -2226,20 +2220,30 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 self.model_calibration(q_model, dataloader, iterations, None,
                                        tune_cfg.get('calib_sampling_size', 1))
             q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-            q_model = ipex.quantization.convert(q_model)
-            with torch.no_grad():
-                try:
-                    q_model = torch.jit.trace(q_model, example_inputs)
-                    q_model = torch.jit.freeze(q_model.eval())
-                except:
-                    q_model = torch.jit.trace(q_model, example_inputs, strict=False)
-                    q_model = torch.jit.freeze(q_model.eval())
+            if self.use_bf16:
+                with torch.no_grad():
+                    with torch.cpu.amp.autocast():
+                        q_model = ipex.quantization.convert(q_model)
+                        try:
+                            q_model = torch.jit.trace(q_model, example_inputs)
+                            q_model = torch.jit.freeze(q_model.eval())
+                        except:
+                            q_model = torch.jit.trace(q_model, example_inputs, strict=False)
+                            q_model = torch.jit.freeze(q_model.eval())
+            else:
+                with torch.no_grad():
+                    try:
+                        q_model = torch.jit.trace(q_model, example_inputs)
+                        q_model = torch.jit.freeze(q_model.eval())
+                    except:
+                        q_model = torch.jit.trace(q_model, example_inputs, strict=False)
+                        q_model = torch.jit.freeze(q_model.eval())
             # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
             # At the 2nd run, the llga pass will be triggered and the model is turned into
             # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
             q_model(*example_inputs)
             q_model(*example_inputs)
-
+

             assert self.approach != 'quant_aware_training', \
                 "Intel PyTorch Extension didn't support quantization aware training mode"
             model_._model = q_model
```
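The substance of the change: when `use_bf16` is enabled, both `ipex.quantization.convert` and the TorchScript trace/freeze run inside `torch.cpu.amp.autocast()`, so operators left in floating point are emitted as bf16 rather than fp32, while the two warm-up calls still trigger the LLGA pass that fuses the int8 subgraphs. A standalone sketch of the same pattern outside the adaptor, assuming intel-extension-for-pytorch >= 1.12 on CPU (the model, input shape, and single calibration batch are illustrative):

```python
import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = torch.rand(1, 3, 224, 224)

# Static int8 preparation and calibration (a single batch for brevity).
qconfig = ipex.quantization.default_static_qconfig
prepared = ipex.quantization.prepare(model, qconfig,
                                     example_inputs=example_inputs,
                                     inplace=False)
prepared(example_inputs)  # calibration forward pass

# Convert and trace under bf16 autocast: ops that stay in floating
# point are emitted as bf16 rather than fp32.
with torch.no_grad():
    with torch.cpu.amp.autocast():
        converted = ipex.quantization.convert(prepared)
        traced = torch.jit.trace(converted, example_inputs)
        traced = torch.jit.freeze(traced.eval())

# Two warm-up runs: the first inserts profiling nodes, the second
# triggers the LLGA pass that rewrites the graph into fused int8 kernels.
traced(example_inputs)
traced(example_inputs)
```

The adaptor's try/except retrace with `strict=False` handles models whose outputs are not plain tensors or tuples; the sketch omits that fallback for brevity.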

neural_compressor/adaptor/pytorch_ipex.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -90,6 +90,7 @@
     ops:
       int8: *ops_default_s8
       uint8: *ops_default_s8
+      bf16: []
       fp32: ['*'] # '*' means all op types

     capabilities: &1_10_capabilities
```

test/ipex/test_adaptor_ipex.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -135,6 +135,21 @@ def test_copy_prepared_model(self):
         copy_model = torch_utils.util.auto_copy(prepared_model)
         self.assertTrue(isinstance(copy_model, torch.nn.Module))

+
+    def test_bf16(self):
+        from neural_compressor.experimental import Quantization
+        model = M()
+        qconfig = ipex.quantization.default_static_qconfig
+        prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.ones(1, 3, 224, 224), inplace=False)
+        config.quantization.use_bf16 = True
+        config.quantization.performance_only = True
+        quantizer = Quantization(config)
+        dataset = quantizer.dataset('dummy', (100, 3, 224, 224), label=True)
+        dataloader = torch.utils.data.DataLoader(dataset)
+        quantizer.model = model
+        quantizer.calib_dataloader = dataloader
+        quantizer.eval_dataloader = dataloader
+        nc_model = quantizer.fit()

 if __name__ == "__main__":
     unittest.main()
```
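The test assumes `M` (a small test module) and `config` (neural_compressor's pythonic config object) are defined earlier in test_adaptor_ipex.py; neither appears in this hunk. A hypothetical minimal `M` compatible with the (100, 3, 224, 224) dummy dataset:

```python
import torch


class M(torch.nn.Module):
    # Hypothetical stand-in: the real M is defined earlier in the test
    # file. Any module that accepts (N, 3, 224, 224) input fits the
    # dummy dataset used by the test.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, 1)

    def forward(self, x):
        return self.conv(x)
```

With `use_bf16 = True` and `performance_only = True`, `quantizer.fit()` returns the converted model without an accuracy-driven tuning loop, exercising the new autocast path directly.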
