From 716f6e899b7f1c285edea5d0b2a7a95d94f638fb Mon Sep 17 00:00:00 2001 From: "Lv, Liang1" Date: Fri, 2 Dec 2022 17:55:10 +0800 Subject: [PATCH] Fix NTM-One-Shot failed with KeyError Signed-off-by: Lv, Liang1 --- .../tf_utils/graph_rewriter/generic/pre_optimize.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py index 4b59de5ecb7..d7c2e33ca83 100644 --- a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py +++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py @@ -146,16 +146,16 @@ def get_optimized_model(self, itex_mode=False): self._tmp_graph_def = ConvertPlaceholderToConst(self._tmp_graph_def).do_transformation() - self._tmp_graph_def = RemoveTrainingNodesOptimizer( - self._tmp_graph_def, protected_nodes=input_output_names).do_transformation() - self._tmp_graph_def = SwitchOptimizer(self._tmp_graph_def).do_transformation() + self._tmp_graph_def = GrapplerOptimizer( + self._tmp_graph_def, input_output_names, self.optimization).do_transformation() + self._tmp_graph_def = StripUnusedNodesOptimizer(self._tmp_graph_def, input_node_names, output_node_names).do_transformation() - self._tmp_graph_def = GrapplerOptimizer( - self._tmp_graph_def, input_output_names, self.optimization).do_transformation() + self._tmp_graph_def = RemoveTrainingNodesOptimizer( + self._tmp_graph_def, protected_nodes=input_output_names).do_transformation() self._tmp_graph_def = SplitSharedInputOptimizer(self._tmp_graph_def).do_transformation()