From dc62880a931768484d84cd0c7dc8d1621dab8458 Mon Sep 17 00:00:00 2001
From: "Lv, Liang1"
Date: Tue, 6 Dec 2022 18:04:48 +0800
Subject: [PATCH] Support 'Square', 'Sum', 'SparseSegmentSqrtN' BF16

Signed-off-by: Lv, Liang1
---
 neural_compressor/adaptor/tensorflow.yaml                  | 2 +-
 .../adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py | 6 ------
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/neural_compressor/adaptor/tensorflow.yaml b/neural_compressor/adaptor/tensorflow.yaml
index 62524f544db..256eb4a17bb 100644
--- a/neural_compressor/adaptor/tensorflow.yaml
+++ b/neural_compressor/adaptor/tensorflow.yaml
@@ -35,7 +35,7 @@
            "Erf", "FusedBatchNormV2", "FusedBatchNormGradV2", "FusedBatchNormV3", "FusedBatchNormGradV3", "LeakyRelu", "LeakyReluGrad",
            "Mean", "Mul", "Sub", "Elu", "EluGrad", "FloorDiv", "_FusedBatchNormEx", "Log", "Log1p", "LogSoftmax", "Prod", "RealDiv",
            "Reciprocal", "Rsqrt", "Selu", "SeluGrad", "Sigmoid", "SigmoidGrad", "Softmax", "Softplus", "SoftplusGrad", "Softsign",
-           "SoftsignGrad", "Sqrt", "SquaredDifference", "Tanh", "TanhGrad", #infer_list
+           "SoftsignGrad", "Sqrt", "Square", "SquaredDifference", "Sum", "Tanh", "TanhGrad", "SparseSegmentSqrtN", # infer_list
            "Abs", "ArgMax","ArgMin","BatchToSpace","BatchToSpaceND","BroadcastTo","Ceil","CheckNumerics","ClipByValue","Concat","ConcatV2",
            "DepthToSpace","DynamicPartition","DynamicStitch","EnsureShape","Enter","Equal","Exit","ExpandDims","Fill","Floor","Gather",
            "GatherNd","GatherV2","Greater","GreaterEqual","Identity","IsFinite","IsInf","IsNan","Less","LessEqual","Max","Maximum","MaxPool",
diff --git a/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py b/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py
index 1b95f743fc5..40183e427d2 100644
--- a/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py
+++ b/neural_compressor/adaptor/tf_utils/quantize_graph/qdq/fuse_qdq_matmul.py
@@ -963,12 +963,6 @@ def _is_match_matmul(self, patterns, qdq_inserted=False):
                 self.exclude_matmul_nodes.append(cur_node.name)
                 continue
 
-            for i in self.node_name_mapping:
-                if weight_node.input and not weight_node.input[0].startswith('^') \
-                        and weight_node.name in self.node_name_mapping[i].output:
-                    self.exclude_matmul_nodes.append(cur_node.name)
-                    continue
-
             for sub_rule in patterns:
                 if sub_rule[0] != "Dequantize":
                     self.exclude_matmul_nodes.append(cur_node.name)
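
Note on the tensorflow.yaml change: the three ops are added to the BF16 "infer list" rather than the allow list. In mixed-precision conversion, allow-list ops are always converted to bfloat16, while infer-list ops are converted only when they adjoin ops that already run in bfloat16, so they extend existing BF16 regions rather than opening new cast boundaries on their own. Below is a minimal sketch of that propagation rule, assuming a toy graph representation of {node_name: (op_type, input_names)}; it is an illustration, not neural_compressor's actual conversion pass.

# Toy allow-list / infer-list BF16 propagation (illustration only).
ALLOW_LIST = {"MatMul", "Conv2D"}                      # always convert
INFER_LIST = {"Square", "Sum", "SparseSegmentSqrtN"}   # convert only next to BF16

def propagate_bf16(graph):
    """graph: {node_name: (op_type, [input_node_names])} -> set of BF16 node names."""
    bf16 = {name for name, (op, _) in graph.items() if op in ALLOW_LIST}
    changed = True
    while changed:                      # grow BF16 regions until a fixed point
        changed = False
        for name, (op, inputs) in graph.items():
            if name not in bf16 and op in INFER_LIST \
                    and any(i in bf16 for i in inputs):
                bf16.add(name)
                changed = True
    return bf16

# A Square fed by a BF16 MatMul joins the BF16 region; a Sum whose only
# producer stays FP32 is left in FP32.
graph = {
    "const": ("Const", []),
    "matmul": ("MatMul", ["const"]),
    "square": ("Square", ["matmul"]),
    "reduce": ("Sum", ["const"]),
}
assert propagate_bf16(graph) == {"matmul", "square"}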
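
Note on the fuse_qdq_matmul.py change: the deleted loop excluded a MatMul from QDQ fusion whenever its weight node had a real (non-control) input and appeared in another mapped node's output list, i.e. whenever the weight is fed by an upstream op rather than being a plain constant. Removing it keeps such MatMuls eligible for Dequantize + MatMul pattern matching. For reference, the deleted guard boils down to the predicate sketched below; NodeRecord and the function name are simplified stand-ins, not the adaptor's real classes.

class NodeRecord:
    """Hypothetical stand-in for node_name_mapping values: `output` lists
    the names of the nodes that this record's node feeds."""
    def __init__(self, name, inputs=(), output=()):
        self.name = name
        self.input = list(inputs)
        self.output = list(output)

def weight_is_computed(weight_node, node_name_mapping):
    """Restatement of the removed check: True when the weight has a
    non-control input and some mapped node feeds it, meaning the 'weight'
    is computed at runtime instead of read from a constant."""
    return bool(weight_node.input) \
        and not weight_node.input[0].startswith('^') \
        and any(weight_node.name in rec.output
                for rec in node_name_mapping.values())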