Merged
6 changes: 4 additions & 2 deletions neural_compressor/adaptor/tensorflow.py
@@ -722,6 +722,7 @@ def _query_quantizable_ops(self, matched_nodes):
fp32_common_config = {'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}
uint8_type = self.query_handler.get_op_types_by_precision(precision='uint8')
int8_type = self.query_handler.get_op_types_by_precision(precision='int8')
bf16_type = self.query_handler.get_op_types_by_precision(precision='bf16')
tf_quantizable_op_type = list(set(uint8_type).union(set(int8_type)))

valid_precision = self.query_handler.get_mixed_precision_combination()
@@ -792,7 +793,8 @@ def _query_quantizable_ops(self, matched_nodes):
self.quantizable_op_details[(
node_name, self.unify_op_type_mapping[node_op]
)] = [copy.deepcopy(other_config), fp32_common_config]
if ('bf16' in valid_precision and CpuInfo().bf16) or os.getenv('FORCE_BF16') == '1':
if node_op in bf16_type and (('bf16' in valid_precision and CpuInfo().bf16) \
or os.getenv('FORCE_BF16') == '1'):
self.quantizable_op_details[(
node_name, self.unify_op_type_mapping[node_op]
)].insert(1, bf16_common_config)
@@ -2228,7 +2230,7 @@ def get_op_types_by_precision(self, precision):
return self.cur_config[precision]
if version1_gte_version2(tf.version.VERSION, '2.1.0') or \
version1_eq_version2(tf.version.VERSION, '1.15.0-up3'):
return ['Conv2D']
return self.cur_config[precision]
return []

def get_mixed_precision_combination(self):
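Taken together, the hunks above make bf16 eligibility data-driven: get_op_types_by_precision('bf16') now returns the per-version list from tensorflow.yaml instead of a hard-coded ['Conv2D'], and _query_quantizable_ops only offers a bf16 config for an op whose type is in that list while bf16 is actually usable. Below is a minimal sketch of that gate; the function name and arguments are illustrative, not the adaptor's API.

import os

def should_offer_bf16(node_op, bf16_ops, valid_precision, cpu_has_bf16):
    # Sketch of the new condition: the op type must appear in the YAML-driven
    # bf16 list AND bf16 must be supported by the host or forced via FORCE_BF16.
    bf16_allowed = ('bf16' in valid_precision and cpu_has_bf16) \
        or os.getenv('FORCE_BF16') == '1'
    return node_op in bf16_ops and bf16_allowed

# Example: MatMul now qualifies because the expanded yaml bf16 list includes it.
print(should_offer_bf16('MatMul', ['Conv2D', 'MatMul', 'BatchMatMul'],
                        ['bf16', 'fp32'], cpu_has_bf16=True))  # True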
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tensorflow.yaml
@@ -153,7 +153,7 @@
version:
name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.6.1', '2.6.2', '2.7.0', '2.8.0', '2.9.0', '2.9.1', '2.10.0', '2.11.0', '1.15.0-up1', '1.15.0-up2', '1.15.0-up3']

bf16: ['Conv2D']
bf16: ['Conv2D', 'Conv3D', 'MatMul', 'BatchMatMul', 'MaxPool', 'MaxPool3D', 'AvgPool', 'AvgPool3D', 'DepthwiseConv2dNative']
fp32: ['*'] # '*' means all op types

int8: {
@@ -34,6 +34,8 @@
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from ..generic.graph_cse_optimizer import GraphCseOptimizer
from ..generic.dequantize_cast_optimizer import DequantizeCastOptimizer
import tensorflow as tf
from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS

DT_FLOAT32 = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)
DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
@@ -179,7 +181,8 @@ def _bf16_convert(self, bf16_node_name):
tensor=tensor_util.make_tensor_proto(
fp32_value, dtypes.bfloat16, fp32_value.shape)))
elif 'Dequantize' == input_node.op and len(input_node_outputs) == 1 \
and input_node.attr['mode'].s != b'MIN_FIRST':
and input_node.attr['mode'].s != b'MIN_FIRST' \
and tf.version.VERSION in TF_SPR_BASE_VERSIONS:
# Dequantize with mode MIN_FIRST does not support bf16 in either the Eigen or the MKL kernels
_, outputs_dt_input_node = self._dtype(input_node)
allowed_input_node_dt_val = self._allowed_dtype_val(input_node)
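The tightened elif above means a Dequantize input node is only retyped to feed bf16 when it has a single consumer, its mode is not MIN_FIRST (bf16 is unsupported there in both the Eigen and MKL kernels), and the running TensorFlow is an spr-base build. A hedged, stand-alone restatement of that condition; argument names are illustrative, and TF_SPR_BASE_VERSIONS itself lives in neural_compressor.adaptor.tf_utils.util and is not reproduced here.

import tensorflow as tf

def dequantize_can_emit_bf16(op, mode_attr, num_consumers, spr_base_versions):
    # Mirrors the guard in _bf16_convert: single consumer, not MIN_FIRST,
    # and an spr-base TensorFlow build.
    return (op == 'Dequantize'
            and num_consumers == 1
            and mode_attr != b'MIN_FIRST'
            and tf.version.VERSION in spr_base_versions)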
@@ -21,8 +21,10 @@

from ..graph_base import GraphRewriterBase
from neural_compressor.adaptor.tf_utils.graph_util import GraphAnalyzer
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from neural_compressor.utils.utility import dump_elapsed_time
import tensorflow as tf
from neural_compressor.adaptor.tf_utils.util import TF_SPR_BASE_VERSIONS

class DequantizeCastOptimizer(GraphRewriterBase):
"""Remove the Cast OP and set Dequantize output to B16 if the Cast OP output is BF16."""

@@ -36,6 +38,11 @@ def do_transformation(self):
Returns:
[graphdef]: optimized graph
"""
# Stock TF _MklDequantize does not support BF16 yet.
# TODO: remove this guard once spr-base is upstreamed to stock TF.
if tf.version.VERSION not in TF_SPR_BASE_VERSIONS:
return self.model

DT_BFLOAT16 = attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum)
cur_graph = GraphAnalyzer()
cur_graph.graph = self.model
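For context, the guarded pass rewrites a Dequantize -> Cast(DstT=bfloat16) pair so that the Dequantize emits bfloat16 directly and the Cast is dropped; stock TF's _MklDequantize kernel cannot do that yet, hence the early return. A minimal sketch of the attribute flip on a single NodeDef follows; the helper name is illustrative, not the optimizer's code.

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import dtypes

def set_dequantize_output_bf16(dequantize_node_def):
    # Retype the Dequantize output to bfloat16, as the optimizer does before
    # rerouting the Cast's consumers and removing the Cast node.
    dequantize_node_def.attr['dtype'].CopyFrom(
        attr_value_pb2.AttrValue(type=dtypes.bfloat16.as_datatype_enum))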
@@ -67,7 +67,10 @@ def test_dequantize_cast_normal(self):
graph_def = build_fake_graphdef()
converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
for i in converted_graph_def.node:
self.assertNotEqual(i.op, 'Cast')
if i.op == 'Cast':
hasCast = True
break
self.assertEqual(hasCast, True)

@disable_random()
def test_dequantize_cast_min_first(self):
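The flipped expectation in the test above follows from the new early return: on stock TensorFlow (version not in TF_SPR_BASE_VERSIONS) the optimizer hands the graph back unchanged, so a Cast op should still be present after do_transformation(). A hedged restatement of the check is below; note that hasCast must be initialized to False before the loop, which the hunk above does not show.

def graph_keeps_cast(graph_def):
    # True if any node in the GraphDef is still a Cast op.
    return any(node.op == 'Cast' for node in graph_def.node)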
@@ -52,7 +52,9 @@ def build_fake_framework_yaml():
---
-
version:
name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']
name: ['2.1.0', '2.2.0', '2.3.0', '2.4.0', '2.5.0', '2.6.0', '2.7.0']

bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']

int8: {
'static': {
@@ -93,6 +95,8 @@ def build_fake_framework_yaml():
version:
name: ['default']

bf16: ['Conv2D', 'MatMul', 'ConcatV2', 'MaxPool', 'AvgPool', 'DepthwiseConv2dNative']

int8: {
'static': {
'Conv2D': {
@@ -0,0 +1,96 @@
import unittest
import os
import yaml
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from neural_compressor.adaptor.tf_utils.util import disable_random
from neural_compressor.adaptor.tf_utils.graph_util import GraphRewriterHelper as Helper
from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.dequantize_cast_optimizer import DequantizeCastOptimizer

def build_fake_graphdef(set_min_first=False, dq_multi_outputs=False):
tf.compat.v1.disable_eager_execution()

input = tf.compat.v1.placeholder(tf.float32, shape=(32, 224, 224, 3), name='input')
graph_def = tf.compat.v1.get_default_graph().as_graph_def(add_shapes=True)

min_input = Helper.create_constant_node(
'test_min',
value=0.,
dtype=dtypes.float32)

max_input = Helper.create_constant_node(
'test_max',
value=[1],
dtype=dtypes.float32)

quant_v2_node = Helper.create_node("QuantizeV2", 'test_quantize',
[input.name, min_input.name, max_input.name])

dequantize_node = Helper.create_node(
"Dequantize", 'test_dequantize',
[quant_v2_node.name, quant_v2_node.name + ':1', quant_v2_node.name + ':2'])
if set_min_first:
Helper.set_attr_string(dequantize_node, "mode", b'MIN_FIRST')

cast_node = Helper.create_node(
"Cast", 'test_cast', [dequantize_node.name])
Helper.set_attr_dtype(cast_node, "DstT", dtypes.bfloat16)
Helper.set_attr_dtype(cast_node, "SrcT", dtypes.float32)
Helper.set_attr_bool(cast_node, "Truncate", False)

dentity_node = Helper.create_node(
"Identity", 'output', [cast_node.name])
Helper.set_attr_dtype(dentity_node, "T", dtypes.bfloat16)

graph_def.node.extend([
min_input,
max_input,
quant_v2_node,
dequantize_node,
cast_node,
dentity_node,
])

if dq_multi_outputs:
dentity_node_2 = Helper.create_node(
"Identity", 'id_1', [dequantize_node.name])
Helper.set_attr_dtype(dentity_node_2, "T", dtypes.float32)
graph_def.node.extend([dentity_node_2])

return graph_def

class TestDequantizeCastOptimizer(unittest.TestCase):

@disable_random()
def test_dequantize_cast_normal(self):
graph_def = build_fake_graphdef()
converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
for i in converted_graph_def.node:
self.assertNotEqual(i.op, 'Cast')

@disable_random()
def test_dequantize_cast_min_first(self):
graph_def = build_fake_graphdef(set_min_first=True)
converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
hasCast = False
for i in converted_graph_def.node:
if i.op == 'Cast':
hasCast = True
break
self.assertEqual(hasCast, True)

@disable_random()
def test_dequantize_cast_multiple_outputs(self):
graph_def = build_fake_graphdef(dq_multi_outputs=True)
converted_graph_def = DequantizeCastOptimizer(graph_def).do_transformation()
hasCast = False
for i in converted_graph_def.node:
if i.op == 'Cast':
hasCast = True
break
self.assertEqual(hasCast, True)


if __name__ == "__main__":
unittest.main()