Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions docs/source/mixed_precision.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Supported precisions for mix precision include bf16 and fp16. If users want to g
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

conf = MixedPrecisionConfig(excluded_precisions=['fp16'])
conf = MixedPrecisionConfig(precision='bf16')
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```
Expand All @@ -56,7 +56,7 @@ from neural_compressor.config import MixedPrecisionConfig
conf = MixedPrecisionConfig(
backend='onnxrt_cuda_ep',
device='gpu',
excluded_precisions=['bf16'])
precision='fp16')
converted_model = mix_precision.fit(model, config=conf)
converted_model.save('./path/to/save/')
```
Expand All @@ -66,17 +66,29 @@ converted_model.save('./path/to/save/')

## Examples

There are some pre-requirements to run mixed precision examples for each framework. If the hardware requirements cannot be met, the program would exit consequently.

- BF16:

There are 2 pre-requirements to run BF16 mixed precision examples:

### TensorFlow

1. Hardware: CPU supports `avx512_bf16` instruction set.
2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/).

If either pre-requirement can't be met, the program would exit consequently.
### PyTorch

1. Hardware: CPU supports `avx512_bf16` instruction set.
2. Software: torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).

### ONNX Runtime

1. Hardware: GPU, set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
2. Software: onnxruntime-gpu.

- FP16

Currently Intel® Neural Compressor only support FP16 mixed precision for ONNX models.

To run FP16 mixed precision examples, users need to set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
### ONNX Runtime

1. Hardware: GPU, set 'device' of config to 'gpu' and 'backend' to 'onnxrt_cuda_ep'.
2. Software: onnxruntime-gpu.
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ def query_fw_capability(self, model):
precisions = query.get_precisions()

for precision in precisions:
if precision == 'fp16' and self.device == 'cpu':
if precision in ['fp16', 'bf16'] and (self.device == 'cpu' or self.backend != 'CUDAExecutionProvider'):
continue
elif precision == 'bf16' and 'CUDAExecutionProvider' not in ort.get_available_providers():
continue
# get supported optype for target precision
optypes = query.get_op_types_by_precision(precision) if \
Expand Down
17 changes: 17 additions & 0 deletions neural_compressor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ class MixedPrecisionConfig(PostTrainingQuantConfig):
def __init__(self,
device="cpu",
backend="default",
precision="bf16",
inputs=[],
outputs=[],
tuning_criterion=tuning_criterion,
Expand All @@ -1214,7 +1215,23 @@ def __init__(self,
accuracy_criterion=accuracy_criterion,
excluded_precisions=excluded_precisions,
)
self.precision = precision

@property
def precision(self):
"""Get precision."""
return self._precision

@precision.setter
def precision(self, precision):
"""Set precision."""
if isinstance(precision, str):
assert precision in ["fp16", "bf16"], "Only support 'fp16' and 'bf16' for mix precision."
self._precision = [precision]
elif isinstance(precision, list):
assert all([i in ["fp16", "bf16"] for i in precision]), "Only " \
"support 'fp16' and 'bf16' for mix precision."
self._precision = precision

class ExportConfig:
"""Config Class for Export."""
Expand Down
22 changes: 18 additions & 4 deletions neural_compressor/mix_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,18 +370,32 @@ def fit(model,
converted_model = mix_precision.fit(model, config=conf)
"""
converter = MixedPrecision(config)
precisions = ["bf16", "fp16", "fp32"]
precisions = list(set(precisions) - set(config.excluded_precisions))
if config.precision in config.excluded_precisions:
logger.warning("Target precision is in excluded_precisions, "\
"please modify precision or excluded_precisions to make it understandable.")
sys.exit(0)
precisions = list(set(config.precision) - set(config.excluded_precisions))
converter.precisions = precisions
if 'bf16' in precisions and not CpuInfo().bf16:
converter.model = model

if ('bf16' in precisions or 'fp16' in precisions) and converter.model.framework() == "onnxruntime":
if config.device == "cpu":
logger.warning("Mix precision exits due to device isn't gpu for onnx models.")
sys.exit(0)
elif config.backend != "onnxrt_cuda_ep":
logger.warning("Mix precision exits due to backend isn't onnxrt_cuda_ep for onnx models.")
sys.exit(0)
elif 'bf16' in precisions and not CpuInfo().bf16 and converter.model.framework() != "onnxruntime":
if os.getenv('FORCE_BF16') == '1':
logger.warning("Mix precision will generate bf16 graph although " \
"the hardware doesn't support bf16 instruction.")
else:
logger.warning("Mix precision exits due to the hardware " \
"doesn't support bf16 instruction.")
sys.exit(0)
converter.model = model
elif 'fp16' in precisions and converter.model.framework() != "onnxruntime":
logger.warning("Currently mix precision only supports fp16 for onnx models.")
sys.exit(0)
if eval_func is not None:
converter.eval_func = eval_func
if eval_dataloader is not None:
Expand Down
4 changes: 2 additions & 2 deletions test/mixed_precision/test_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_on_non_enabled_dtype(self):
output_model = mix_precision.fit(self.onnx_model, conf)
self.assertEqual(cm.exception.code, 0)

conf = MixedPrecisionConfig(excluded_precisions=["fp16"])
conf = MixedPrecisionConfig(precision="fp16")
with self.assertRaises(SystemExit) as cm:
output_model = mix_precision.fit(self.tf_model, conf)
self.assertEqual(cm.exception.code, 0)
Expand Down Expand Up @@ -309,7 +309,7 @@ def test_mixed_precision_with_evaluation(self):
#self.assertTrue(any([i.op_type == 'Cast' for i in output_model.nodes()]))

tuning_criterion = TuningCriterion(max_trials=3, timeout=1000000)
conf = MixedPrecisionConfig(device='gpu', tuning_criterion=tuning_criterion, backend='onnxrt_cuda_ep', excluded_precisions=['bf16'])
conf = MixedPrecisionConfig(device='gpu', tuning_criterion=tuning_criterion, backend='onnxrt_cuda_ep', precision="fp16")
output_model = mix_precision.fit(self.onnx_model,
conf,
eval_dataloader=self.matmul_dataloader,
Expand Down