
Commit 369b9d0

Add more BF16 ops support on stock tensorflow (#792)
Signed-off-by: Lv, Liang1 <[email protected]>
Parent: 35a2cd2

7 files changed (+122 / -7 lines)

neural_compressor/adaptor/tensorflow.py

Lines changed: 4 additions & 2 deletions
@@ -722,6 +722,7 @@ def _query_quantizable_ops(self, matched_nodes):
         fp32_common_config = {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}
         uint8_type = self.query_handler.get_op_types_by_precision(precision='uint8')
         int8_type = self.query_handler.get_op_types_by_precision(precision='int8')
+        bf16_type = self.query_handler.get_op_types_by_precision(precision='bf16')
         tf_quantizable_op_type = list(set(uint8_type).union(set(int8_type)))

         valid_precision = self.query_handler.get_mixed_precision_combination()
@@ -792,7 +793,8 @@ def _query_quantizable_ops(self, matched_nodes):
             self.quantizable_op_details[(
                 node_name, self.unify_op_type_mapping[node_op]
             )] = [copy.deepcopy(other_config), fp32_common_config]
-            if ('bf16' in valid_precision and CpuInfo().bf16) or os.getenv('FORCE_BF16') == '1':
+            if node_op in bf16_type and (('bf16' in valid_precision and CpuInfo().bf16) \
+                or os.getenv('FORCE_BF16') == '1'):
                 self.quantizable_op_details[(
                     node_name, self.unify_op_type_mapping[node_op]
                 )].insert(1, bf16_common_config)
@@ -2228,7 +2230,7 @@ def get_op_types_by_precision(self, precision):
             return self.cur_config[precision]
         if version1_gte_version2(tf.version.VERSION, '2.1.0') or \
             version1_eq_version2(tf.version.VERSION, '1.15.0-up3'):
-            return ['Conv2D']
+            return self.cur_config[precision]
         return []

     def get_mixed_precision_combination(self):
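The net effect of the first two hunks is that bf16 eligibility becomes op-type aware: previously every quantizable op got the bf16 config whenever the CPU supported bf16 or FORCE_BF16=1, whereas now the op type must also appear in the bf16 capability list queried from the framework YAML. A minimal sketch of that gating rule (allow_bf16_for is a hypothetical helper, and its arguments stand in for the adaptor's internal state):

import os

def allow_bf16_for(node_op, bf16_type, valid_precision, cpu_has_bf16):
    # bf16_type mirrors get_op_types_by_precision(precision='bf16'); the
    # other arguments stand in for the values the diff reads from
    # query_handler, CpuInfo() and the environment.
    forced = os.getenv('FORCE_BF16') == '1'
    capable = 'bf16' in valid_precision and cpu_has_bf16
    return node_op in bf16_type and (capable or forced)

# e.g. with the updated tensorflow.yaml list:
bf16_ops = ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool',
            'MaxPool3D', 'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
assert allow_bf16_for('MatMul', bf16_ops, ['bf16', 'fp32'], cpu_has_bf16=True)
assert not allow_bf16_for('ConcatV2', bf16_ops, ['bf16', 'fp32'], cpu_has_bf16=True)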

neural_compressor/adaptor/tensorflow.yaml

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@
   version:
     name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.6.1', '2.6.2', '2.7.0', '2.8.0', '2.9.0', '2.9.1', '2.10.0', '2.11.0', '1.15.0-up1', '1.15.0-up2', '1.15.0-up3']

-  bf16: ['Conv2D']
+  bf16: ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool', 'MaxPool3D', 'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
   fp32: ['*'] # '*' means all op types

   int8: {
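This expanded list is what get_op_types_by_precision(precision='bf16') now returns (see the third tensorflow.py hunk above). A hedged sketch of reading such a capability entry with PyYAML; the real query handler first selects the version block matching the installed TensorFlow, which is elided here:

import yaml

snippet = """
bf16: ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool',
       'MaxPool3D', 'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
fp32: ['*']  # '*' means all op types
"""

cfg = yaml.safe_load(snippet)
print(cfg['bf16'])  # op types eligible for the bf16 config on this version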

neural_compressor/adaptor/tf_utils/graph_rewriter/bf16/bf16_convert.py

Lines changed: 4 additions & 1 deletion
@@ -34,6 +34,8 @@
 from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
 from ..generic.graph_cse_optimizer import GraphCseOptimizer
 from ..generic.dequantize_cast_optimizer import DequantizeCastOptimizer
+import tensorflow as tf
+from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS

 DT_FLOAT32 = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)
 DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
@@ -179,7 +181,8 @@ def _bf16_convert(self, bf16_node_name):
                 tensor=tensor_util.make_tensor_proto(
                     fp32_value, dtypes.bfloat16, fp32_value.shape)))
         elif 'Dequantize' == input_node.op and len(input_node_outputs) == 1 \
-                and input_node.attr['mode'].s != b'MIN_FIRST':
+                and input_node.attr['mode'].s != b'MIN_FIRST' \
+                and tf.version.VERSION in TF_SPR_BASE_VERSIONS:
             # Dequantize with mode MIN_FIRST does not support bf16 in both eigen and mkl
             _, outputs_dt_input_node = self._dtype(input_node)
             allowed_input_node_dt_val = self._allowed_dtype_val(input_node)
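The added condition narrows the Dequantize-to-bf16 rewrite: even when a Dequantize node has a single output and does not use MIN_FIRST mode, it is now converted only on spr-base TensorFlow builds (tf.version.VERSION in TF_SPR_BASE_VERSIONS), because stock TensorFlow's _MklDequantize kernel does not yet support BF16 output.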

neural_compressor/adaptor/tf_utils/graph_rewriter/generic/dequantize_cast_optimizer.py

Lines changed: 8 additions & 1 deletion
@@ -21,8 +21,10 @@

 from ..graph_base import GraphRewriterBase
 from neural_compressor.adaptor.tf_utils.graph_util import GraphAnalyzer
-from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
 from neural_compressor.utils.utility import dump_elapsed_time
+import tensorflow as tf
+from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS
+
 class DequantizeCastOptimizer(GraphRewriterBase):
     """Remove the Cast OP and set Dequantize output to BF16 if the Cast OP output is BF16."""

@@ -36,6 +38,11 @@ def do_transformation(self):
         Returns:
             [graphdef]: optimized graph
         """
+        # Stock TF _MklDequantize doesn't support BF16 currently.
+        # TODO: remove this when spr-base is upstreamed to stock TF.
+        if not tf.version.VERSION in TF_SPR_BASE_VERSIONS:
+            return self.model
+
         DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
         cur_graph = GraphAnalyzer()
         cur_graph.graph = self.model
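With this guard the optimizer is a no-op on stock TensorFlow and only folds Dequantize -> Cast(bfloat16) on spr-base builds. A hedged usage sketch, assuming graph_def is a GraphDef containing such a chain, e.g. the one returned by build_fake_graphdef() in the new test file below:

from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.dequantize_cast_optimizer import (
    DequantizeCastOptimizer)

converted = DequantizeCastOptimizer(graph_def).do_transformation()
# On stock TensorFlow the early return leaves the graph untouched, so the
# Cast survives; on spr-base builds the Cast is folded into Dequantize.
has_cast = any(node.op == 'Cast' for node in converted.node)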

test/adaptor/tensorflow_adaptor/test_tensorflow_graph_dequantize_cast_optimizer.py

Lines changed: 4 additions & 1 deletion
@@ -67,7 +67,10 @@ def test_dequantize_cast_normal(self):
         graph_def = build_fake_graphdef()
         converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
         for i in converted_graph_def.node:
-            self.assertNotEqual(i.op, 'Cast')
+            if i.op == 'Cast':
+                hasCast = True
+                break
+        self.assertEqual(hasCast, True)

     @disable_random()
     def test_dequantize_cast_min_first(self):
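One caution about the rewritten loop above: hasCast is read after the loop, but no initialization is visible in the hunk. For the check to be well defined when no Cast node remains, a hasCast = False assignment must precede the loop, as the equivalent tests in the new file below do:

hasCast = False
for i in converted_graph_def.node:
    if i.op == 'Cast':
        hasCast = True
        break
self.assertEqual(hasCast, True)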

test/adaptor/tensorflow_adaptor/test_tensorflow_query_yaml.py

Lines changed: 5 additions & 1 deletion
@@ -52,7 +52,9 @@ def build_fake_framework_yaml():
 ---
 -
   version:
-    name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']
+    name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']
+
+  bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']

   int8: {
     'static': {
@@ -93,6 +95,8 @@ def build_fake_framework_yaml():
   version:
     name: ['default']

+  bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']
+
   int8: {
     'static': {
       'Conv2D': {
Lines changed: 96 additions & 0 deletions
new file
@@ -0,0 +1,96 @@
+import unittest
+import os
+import yaml
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework import dtypes
+from neural_compressor.adaptor.tf_utils.util import disable_random
+from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
+from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.dequantize_cast_optimizer import DequantizeCastOptimizer
+
+def build_fake_graphdef(set_min_first=False, dq_multi_outputs=False):
+    tf.compat.v1.disable_eager_execution()
+
+    input = tf.compat.v1.placeholder(tf.float32, shape=(32, 224, 224, 3), name='input')
+    graph_def = tf.compat.v1.get_default_graph().as_graph_def(add_shapes=True)
+
+    min_input = Helper.create_constant_node(
+        'test_min',
+        value=0.,
+        dtype=dtypes.float32)
+
+    max_input = Helper.create_constant_node(
+        'test_max',
+        value=[1],
+        dtype=dtypes.float32)
+
+    quant_v2_node = Helper.create_node("QuantizeV2", 'test_quantize',
+                                       [input.name, min_input.name, max_input.name])
+
+    dequantize_node = Helper.create_node(
+        "Dequantize", 'test_dequantize',
+        [quant_v2_node.name, quant_v2_node.name + ':1', quant_v2_node.name + ':2'])
+    if set_min_first:
+        Helper.set_attr_string(dequantize_node, "mode", b'MIN_FIRST')
+
+    cast_node = Helper.create_node(
+        "Cast", 'test_cast', [dequantize_node.name])
+    Helper.set_attr_dtype(cast_node, "DstT", dtypes.bfloat16)
+    Helper.set_attr_dtype(cast_node, "SrcT", dtypes.float32)
+    Helper.set_attr_bool(cast_node, "Truncate", False)
+
+    dentity_node = Helper.create_node(
+        "Identity", 'output', [cast_node.name])
+    Helper.set_attr_dtype(dentity_node, "T", dtypes.bfloat16)
+
+    graph_def.node.extend([
+        min_input,
+        max_input,
+        quant_v2_node,
+        dequantize_node,
+        cast_node,
+        dentity_node,
+    ])
+
+    if dq_multi_outputs:
+        dentity_node_2 = Helper.create_node(
+            "Identity", 'id_1', [dequantize_node.name])
+        Helper.set_attr_dtype(dentity_node_2, "T", dtypes.float32)
+        graph_def.node.extend([dentity_node_2])
+
+    return graph_def
+
+class TestDequantizeCastOptimizer(unittest.TestCase):
+
+    @disable_random()
+    def test_dequantize_cast_normal(self):
+        graph_def = build_fake_graphdef()
+        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
+        for i in converted_graph_def.node:
+            self.assertNotEqual(i.op, 'Cast')
+
+    @disable_random()
+    def test_dequantize_cast_min_first(self):
+        graph_def = build_fake_graphdef(set_min_first=True)
+        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
+        hasCast = False
+        for i in converted_graph_def.node:
+            if i.op == 'Cast':
+                hasCast = True
+                break
+        self.assertEqual(hasCast, True)
+
+    @disable_random()
+    def test_dequantize_cast_multiple_outputs(self):
+        graph_def = build_fake_graphdef(dq_multi_outputs=True)
+        converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
+        hasCast = False
+        for i in converted_graph_def.node:
+            if i.op == 'Cast':
+                hasCast = True
+                break
+        self.assertEqual(hasCast, True)
+
+
+if __name__ == "__main__":
+    unittest.main()
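Taken together, the three tests pin down the optimizer's contract: test_dequantize_cast_normal expects the Cast to be folded away in the supported case, while test_dequantize_cast_min_first and test_dequantize_cast_multiple_outputs expect the Cast to survive, because a MIN_FIRST-mode Dequantize cannot emit bf16 and a Dequantize with more than one consumer must keep its fp32 output for the other consumers.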
