diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py index 4b59de5ecb7..d7c2e33ca83 100644 --- a/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py +++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py @@ -146,16 +146,16 @@ def get_optimized_model(self, itex_mode=False): self._tmp_graph_def = ConvertPlaceholderToConst(self._tmp_graph_def).do_transformation() - self._tmp_graph_def = RemoveTrainingNodesOptimizer( - self._tmp_graph_def, protected_nodes=input_output_names).do_transformation() - self._tmp_graph_def = SwitchOptimizer(self._tmp_graph_def).do_transformation() + self._tmp_graph_def = GrapplerOptimizer( + self._tmp_graph_def, input_output_names, self.optimization).do_transformation() + self._tmp_graph_def = StripUnusedNodesOptimizer(self._tmp_graph_def, input_node_names, output_node_names).do_transformation() - self._tmp_graph_def = GrapplerOptimizer( - self._tmp_graph_def, input_output_names, self.optimization).do_transformation() + self._tmp_graph_def = RemoveTrainingNodesOptimizer( + self._tmp_graph_def, protected_nodes=input_output_names).do_transformation() self._tmp_graph_def = SplitSharedInputOptimizer(self._tmp_graph_def).do_transformation()