Commit 108c245

Support FP16/BF16 for onnxrt adaptor (#273)
Signed-off-by: Mengni Wang <[email protected]>
1 parent ba42d00 commit 108c245

File tree

30 files changed: +283 additions, -125 deletions


.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
Lines changed: 1 addition & 0 deletions

@@ -554,6 +554,7 @@ entrypoint
 enum
 env
 environ
+ep
 eq
 erf
 Erf

docs/source/mixed_precision.md
Lines changed: 39 additions & 14 deletions

@@ -20,38 +20,63 @@ The recently launched 3rd Gen Intel® Xeon® Scalable processor (codenamed Coope
 
 ## Mixed Precision Support Matrix
 
-|Framework |BF16 |
-|--------------|:-----------:|
-|TensorFlow |&#10004; |
-|PyTorch |&#10004; |
-|ONNX |plan to support in the future |
-|MXNet |&#10004; |
+|Framework |BF16 |FP16 |
+|--------------|:-----------:|:-----------:|
+|TensorFlow |&#10004; |:x: |
+|PyTorch |&#10004; |:x: |
+|ONNX Runtime |&#10004; |&#10004; |
+|MXNet |&#10004; |:x: |
 
-> **During quantization, BF16 conversion is default enabled. Please refer to this [document](./quantization_mixed_precision.md) for its workflow.**
+> **During quantization, BF16 conversion is enabled by default; FP16 conversion is executed if the 'device' of the config is 'gpu'. Please refer to this [document](./quantization_mixed_precision.md) for its workflow.**
 
 ## Get Started with Mixed Precision API
 
-To get a bf16 model, users can use the Mixed Precision API as follows.
+To get a bf16/fp16 model, users can use the Mixed Precision API as follows.
 
 
+Supported precisions for mixed precision are bf16 and fp16. To get a pure fp16 or bf16 model, add the other precision to excluded_precisions.
+
+- BF16:
+
 ```python
 from neural_compressor import mix_precision
 from neural_compressor.config import MixedPrecisionConfig
 
-conf = MixedPrecisionConfig()
+conf = MixedPrecisionConfig(excluded_precisions=['fp16'])
+converted_model = mix_precision.fit(model, config=conf)
+converted_model.save('./path/to/save/')
+```
+
+- FP16:
 
+```python
+from neural_compressor import mix_precision
+from neural_compressor.config import MixedPrecisionConfig
+
+conf = MixedPrecisionConfig(
+    backend='onnxrt_cuda_ep',
+    device='gpu',
+    excluded_precisions=['bf16'])
 converted_model = mix_precision.fit(model, config=conf)
 converted_model.save('./path/to/save/')
 ```
 
-> **BF16 conversion may lead to accuracy drop. Intel® Neural Compressor provides an accuracy-aware tuning function to reduce accuracy loss, which will fallback converted ops to FP32 automatically to get better accuracy. To enable this function, users only need to provide an evaluation function (or dataloader + metric).**
+> **BF16/FP16 conversion may lead to accuracy drop. Intel® Neural Compressor provides an accuracy-aware tuning function to reduce accuracy loss, which will fall back converted ops to FP32 automatically to get better accuracy. To enable this function, users only need to provide an evaluation function (or dataloader + metric).**
 
 
 ## Examples
 
-There are 2 pre-requirements to run BF16 mixed precision examples:
+- BF16:
+
+  There are two prerequisites for running the BF16 mixed precision examples:
+
+  1. Hardware: the CPU supports the `avx512_bf16` instruction set.
+  2. Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
+
+  If either prerequisite is not met, the program exits.
 
-- Hardware: CPU supports `avx512_bf16` instruction set.
-- Software: intel-tensorflow >= [2.3.0](https://pypi.org/project/intel-tensorflow/2.3.0/) or torch >= [1.11.0](https://download.pytorch.org/whl/torch_stable.html).
+- FP16:
 
-If either pre-requirement can't be met, the program would exit consequently.
+  Currently Intel® Neural Compressor only supports FP16 mixed precision for ONNX models.
+
+  To run the FP16 mixed precision examples, set the 'device' of the config to 'gpu' and the 'backend' to 'onnxrt_cuda_ep'.
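
For readers trying the new ONNX Runtime path, here is a minimal, hedged end-to-end sketch of the FP16 flow the updated doc describes. The model path `./model.onnx`, the output path, and passing a file path directly to `mix_precision.fit` are illustrative assumptions, not taken from this commit; it also assumes onnxruntime-gpu with CUDAExecutionProvider is available.

```python
# Illustrative sketch of the FP16 mixed-precision flow described above.
# Assumptions (not from the commit): the ONNX model lives at ./model.onnx,
# onnxruntime-gpu with CUDAExecutionProvider is installed, and fit() accepts
# a model path directly.
from neural_compressor import mix_precision
from neural_compressor.config import MixedPrecisionConfig

conf = MixedPrecisionConfig(
    backend='onnxrt_cuda_ep',      # routes the adaptor to onnxrt_cuda.yaml
    device='gpu',                  # fp16 is skipped by the adaptor when device == 'cpu'
    excluded_precisions=['bf16'])  # exclude bf16 to get a pure fp16 model

converted_model = mix_precision.fit('./model.onnx', config=conf)
converted_model.save('./model_fp16.onnx')
```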

neural_compressor/adaptor/onnxrt.py
Lines changed: 14 additions & 4 deletions

@@ -86,12 +86,13 @@ def __init__(self, framework_specific_info):
             logger.warning("Dynamic approach doesn't support QDQ format.")
 
         # get quantization config file according to backend
+        config_file = None
         if self.backend == 'CPUExecutionProvider':
             config_file = 'onnxrt.yaml'
         elif self.backend == 'TensorrtExecutionProvider':
             config_file = 'onnxrt_trt.yaml'
         elif self.backend == 'CUDAExecutionProvider':
-            config_file == 'onnxrt_cuda.yaml'
+            config_file = 'onnxrt_cuda.yaml'
         else: # pragma: no cover
             assert False, "{} provider is not supported in current environment, " \
                 "supported providers: {}".format(self.backend,

@@ -128,6 +129,8 @@ def __init__(self, framework_specific_info):
 
         for precision in self.query_handler.get_precisions():
             if precision != 'fp32':
+                if self.device == 'cpu' and precision == 'fp16':
+                    continue
                 self.quantizable_op_types += \
                     self.query_handler.get_op_types_by_precision(precision=precision)

@@ -930,6 +933,8 @@ def query_fw_capability(self, model):
         precisions = query.get_precisions()
 
         for precision in precisions:
+            if precision == 'fp16' and self.device == 'cpu':
+                continue
             # get supported optype for target precision
             optypes = query.get_op_types_by_precision(precision) if \
                 query.get_op_types_by_precision(precision) != ['*'] else \

@@ -1046,7 +1051,7 @@ def query_fw_capability(self, model):
             else: # pragma: no cover
                 op_wise.update(
                     {(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})
-
+
         return {'optypewise': optype_wise, 'opwise': op_wise}
 
     def _optypewise_filter_for_qdq(self, optype_wise):

@@ -1411,12 +1416,17 @@ def _compare(version1, version2):
             config['capabilities'] = {}
 
         # generate other config content including precisions and ops
-        precisions = [key for key in config['capabilities'].keys()]
+        precisions = list(version_config.keys() - {'version', 'recipes'})
         if 'fp32' not in precisions:
             precisions.append('fp32')
         config['precisions'] = {'names': ','.join(precisions)}
 
         op_types = {}
+        for precision in precisions:
+            if precision in config['capabilities']:
+                op_types[precision] = [op_type for op_type in config['capabilities'][precision].keys()]
+            elif precision in version_config:
+                op_types[precision] = version_config[precision]
         for precision, precision_config in config['capabilities'].items():
             op_types[precision] = [op_type for op_type in precision_config.keys()]
         if 'fp32' not in op_types:

@@ -1485,4 +1495,4 @@ def get_fallback_list(self):
 
     def get_specific_cfg_version(self):
         """Get version of the specific config."""
-        return self.config_version
+        return self.config_version
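
A small standalone sketch of what the rewritten precision discovery in the last hunk does: the per-version YAML entry is treated as a dict, and `dict.keys()` set arithmetic strips the non-precision keys. The `version_config` literal below is a hand-written stand-in for one parsed onnxrt_cuda.yaml entry, not data from the repo.

```python
# Toy reproduction of the new precision-derivation logic in onnxrt.py.
# version_config is a made-up stand-in for a parsed YAML version entry.
version_config = {
    'version': {'name': '1.12.0'},
    'int8': {'static': {'Conv': {}, 'MatMul': {}}},
    'fp16': ['Conv', 'MatMul', 'Gemm'],
    'bf16': ['MatMul', 'Gemm', 'Add'],
    'recipes': {'graph_optimization': {'level': ['ENABLE_ALL']}},
}

# dict.keys() behaves like a set, so 'version' and 'recipes' can be subtracted away.
precisions = list(version_config.keys() - {'version', 'recipes'})
if 'fp32' not in precisions:
    precisions.append('fp32')

print(sorted(precisions))  # ['bf16', 'fp16', 'fp32', 'int8']
```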

neural_compressor/adaptor/onnxrt_cuda.yaml
Lines changed: 20 additions & 1 deletion

@@ -97,6 +97,11 @@
           'LSTM': *default_dynamic,
         }
       }
+  fp16: &common_fp16 ['Concat', 'Gather', 'Reshape', 'Squeeze', 'Transpose', 'Unsqueeze',
+      'EmbedLayerNormalization', 'Attention', 'Split', 'Sigmoid', 'Relu', 'Mul', 'Pad', 'MaxPool',
+      'MatMul', 'LeakyRelu', 'GlobalAveragePool', 'Gemm', 'Conv', 'AveragePool', 'Add', 'Clip']
+  bf16: &common_bf16 ['Concat', 'Gather', 'Reshape', 'Squeeze', 'Transpose', 'Unsqueeze',
+      'Split', 'Sigmoid', 'Relu', 'Mul', 'MatMul', 'Gemm', 'Add']
   recipes: &default_optimization
     graph_optimization: # from onnxruntime graph_optimization_level
       level: ['DISABLE_ALL', 'ENABLE_BASIC', 'ENABLE_EXTENDED', 'ENABLE_ALL']

@@ -137,6 +142,8 @@
         },
         'dynamic': *ref_1_6_dynamic
       }
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization

@@ -204,6 +211,8 @@
           'LSTM': *default_dynamic,
         }
       }
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization

@@ -278,6 +287,8 @@
           'LSTM': *default_dynamic,
         }
       }
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization

@@ -332,6 +343,8 @@
         },
         'dynamic': *ref_1_9_dynamic
       }
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization

@@ -393,19 +406,25 @@
         },
         'dynamic': *ref_1_9_dynamic
       }
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization
 
 -
   version:
     name: '1.12.0'
     int8: *ref_1_11
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
     <<: *default_optimization
 
 -
   version:
     name: 'default'
     int8: *ref_1_6
+  fp16: *common_fp16
+  bf16: *common_bf16
   recipes:
-    <<: *default_optimization
+    <<: *default_optimization
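
For orientation, the fp16/bf16 keys above sit next to int8 inside each per-version entry of onnxrt_cuda.yaml. A hedged sketch of reading them back with PyYAML follows; the top-level list-of-versions layout is inferred from the diff, it assumes the file parses with plain yaml.safe_load, and the adaptor itself consumes this file through its query handler rather than this snippet.

```python
# Hedged sketch: inspect the fp16/bf16 op lists added to onnxrt_cuda.yaml.
# Assumes the file is a top-level list of version entries, as the diff suggests.
import yaml

with open('neural_compressor/adaptor/onnxrt_cuda.yaml') as f:
    version_entries = yaml.safe_load(f)

for entry in version_entries:
    name = entry['version']['name']
    fp16_ops = entry.get('fp16', [])
    bf16_ops = entry.get('bf16', [])
    print(f"{name}: {len(fp16_ops)} fp16 op types, {len(bf16_ops)} bf16 op types")
```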

neural_compressor/adaptor/ox_utils/calibration.py
Lines changed: 2 additions & 1 deletion

@@ -426,7 +426,8 @@ def calculate_quantization_params(self, q_config, quantization_thresholds):
                 qType = 2 # uint8
             if tensor_name in output_name_to_nodes:
                 parent = output_name_to_nodes[tensor_name]
-                if parent and parent.name in q_config and q_config[parent.name] not in ['fp32']:
+                if parent and parent.name in q_config and \
+                    q_config[parent.name] not in ['fp32', 'fp16']:
                     scheme = q_config[parent.name]['activation']['scheme']
                     qType = q_config[parent.name]['activation']['dtype']
             elif self.backend in ['TensorrtExecutionProvider']:

neural_compressor/adaptor/ox_utils/operators/direct_q8.py
Lines changed: 1 addition & 16 deletions

@@ -81,25 +81,10 @@ def cast(self): # pragma: no cover
             return
         self.quantizer.dtype_cast(self.node, self.dtype)
 
-@op_registry(op_types="Shape, Loop, Slice")
-class DirectCastOperator(Operator): # pragma: no cover
-    """Direct8bit Operator Cast."""
-
-    def __init__(self, onnx_quantizer, onnx_node):
-        """Initialization."""
-        super(DirectCastOperator, self).__init__(onnx_quantizer, onnx_node)
-
-    def cast(self):
-        """Cast node."""
-        node = self.node
-        if node.input[0] not in [i.tensor_name for i in self.quantizer.new_value_info.values()]:
-            return
-        self.quantizer.dtype_cast(self.node, self.dtype)
-
 @qop_registry(op_types="Reshape, Transpose, Squeeze, Unsqueeze")
 class QDirectOperator(QOperator):
     """QDirect Operator."""
 
     def __init__(self, onnx_node, children, initializers):
         """Initialization."""
-        super().__init__(onnx_node, children, initializers)
+        super().__init__(onnx_node, children, initializers)

neural_compressor/adaptor/ox_utils/operators/ops.py
Lines changed: 2 additions & 2 deletions

@@ -70,7 +70,7 @@ def __init__(self, onnx_quantizer, onnx_node):
         self.activation_dtype = None
         self.activation_scheme = 'asym'
         if self.node.name in self.quantizer.config:
-            if self.quantizer.config[self.node.name] != 'fp32':
+            if self.quantizer.config[self.node.name] not in self.quantizer.fallback_list:
                 if 'weight' in self.quantizer.config[self.node.name].keys():
                     self.per_channel = self.quantizer.config[self.node.name]\
                         ['weight']['granularity'] == 'per_channel'

@@ -162,4 +162,4 @@ def convert(self):
                 node.op_type, inputs,
                 outputs, node.name + '_convert', **kwargs)
             add_nodes.append(new_node)
-        return True, add_nodes, inits
+        return True, add_nodes, inits

neural_compressor/adaptor/ox_utils/quantizer.py
Lines changed: 3 additions & 7 deletions

@@ -320,7 +320,7 @@ def dfs(match_nodes, node, pattern):
                 if len(outs) > 0:
                     output_dtype = str(self.new_value_info[outs[0]].new_dtype)
                     break
-            if len(outs) == 0 or all([not self.should_convert(i) for i in children]):
+            if len(outs) == 0 or all([not self.should_cast(i) for i in children]):
                 return
             if input_dtype == str(match_nodes[1].attribute[0].i) and \
                 output_dtype == str(match_nodes[0].attribute[0].i) and \

@@ -355,17 +355,13 @@ def dfs(match_nodes, node, pattern):
 
     def dtype_cast(self, node, cfg, keep_io_types=True): # pragma: no cover
         """Cast node dtype."""
-        min_positive_val = 1e-7
-        max_finite_val = 1e4
         for idx, tensor_name in enumerate(node.input):
             initializer = find_by_name(tensor_name, self.model.initializer())
             if initializer is not None:
                 if initializer.data_type != onnx_proto.TensorProto.FLOAT:
                     continue
-                new_tensor = cast_tensor(initializer, cfg)
-                if new_tensor:
-                    self.model.remove_initializer(initializer)
-                    self.model.add_initializer(new_tensor)
+                do_cast = cast_tensor(initializer, cfg)
+                if do_cast:
                     self.new_value_info[tensor_name] = ValueInfo(tensor_name,
                         TensorProto.FLOAT, dtype_mapping[cfg])
             else:

neural_compressor/adaptor/ox_utils/util.py
Lines changed: 40 additions & 8 deletions

@@ -33,10 +33,16 @@
 ms_domain = "com.microsoft"
 
 support_pair = {
+    'float32 bfloat16': True,
+    '1 16': True,
+    'bfloat16 float32': True,
+    '16 1': True,
     'uint8 uint8': True,
     '2 2': True,
     'float16 float16': True,
     '10 10': True,
+    'bfloat16 bfloat16': True,
+    '16 16': True,
     'float32 float16': True,
     '1 10': True,
     'float16 float32': True,

@@ -59,6 +65,7 @@
     'uint64': 13,
     'complex64': 14,
     'complex128': 15,
+    'bf16': 16
 }
 
 PROVIDERS = {

@@ -135,6 +142,26 @@ def split_shared_bias(model):
                 node.input[2] = new_input_name
     return model
 
+def float_to_float16(tensor):
+    """Convert float to float16."""
+    min_val = 5.96e-08
+    max_val = 65504.0
+    tensor[(tensor > max_val) & (tensor < float('inf'))] = max_val
+    tensor[(tensor < min_val) & (tensor > 0)] = min_val
+    tensor[(tensor > -min_val) & (tensor < 0)] = -min_val
+    tensor[(tensor < -max_val) & (tensor > float('-inf'))] = -max_val
+    return np.float16(tensor)
+
+def float_to_bfloat16(tensor):
+    """Convert float to bfloat16."""
+    min_val = 9.2e-41
+    max_val = 3.38953139e38
+    tensor[(tensor > max_val) & (tensor < float('inf'))] = max_val
+    tensor[(tensor < min_val) & (tensor > 0)] = min_val
+    tensor[(tensor > -min_val) & (tensor < 0)] = -min_val
+    tensor[(tensor < -max_val) & (tensor > float('-inf'))] = -max_val
+    return tensor
+
 def cast_tensor(tensor, dtype): # pragma: no cover
     """Convert tensor float to target dtype.

@@ -146,14 +173,19 @@ def cast_tensor(tensor, dtype): # pragma: no cover
         raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor))
 
     if tensor.data_type == onnx_proto.TensorProto.FLOAT:
-        new_tensor = helper.make_tensor(
-            name=tensor.name,
-            data_type=dtype_mapping[dtype],
-            dims=numpy_helper.to_array(tensor).shape,
-            vals=numpy_helper.to_array(tensor)
-        )
-        return new_tensor
-    return None
+        val = numpy_helper.to_array(tensor).copy()
+        if dtype == 'fp16':
+            new_val = float_to_float16(val)
+        elif dtype == 'bf16':
+            new_val = float_to_bfloat16(val)
+        else:
+            raise ValueError('Expect fp16 or bf16 but get {}.'.format(dtype))
+        tensor.float_data[:] = []
+        tensor.int32_data[:] = []
+        tensor.raw_data = new_val.tostring()
+        tensor.data_type = dtype_mapping[dtype]
+        return True
+    return False
 
 def remove_init_from_model_input(model):
     """Remove initializer from model input."""
