diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 0c0ed5ceef..669cd5ed3c 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -9,6 +9,7 @@ #include "torch/csrc/jit/passes/lower_graph.h" #include "torch/csrc/jit/passes/lower_tuples.h" #include "torch/csrc/jit/passes/peephole.h" +#include "torch/csrc/jit/passes/remove_exceptions.h" #include "torch/csrc/jit/passes/remove_mutation.h" #include "core/lowering/lowering.h" @@ -33,6 +34,7 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::InlineFunctionalGraphs(g); torch::jit::PeepholeOptimize(g, false); torch::jit::FuseLinear(g); + torch::jit::EliminateExceptions(g); if (!lower_info.disable_cse) { torch::jit::EliminateCommonSubexpression(g); } diff --git a/core/lowering/passes/exception_elimination.cpp b/core/lowering/passes/exception_elimination.cpp index 02fb773653..63581af1ca 100644 --- a/core/lowering/passes/exception_elimination.cpp +++ b/core/lowering/passes/exception_elimination.cpp @@ -4,7 +4,6 @@ #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/guard_elimination.h" #include "torch/csrc/jit/passes/peephole.h" -#include "torch/csrc/jit/passes/remove_exceptions.h" #include "torch/csrc/jit/runtime/graph_executor.h" #include "core/util/prelude.h" @@ -22,7 +21,6 @@ struct ExceptionOrPassPatternElimination { void run() { findExceptionOrPassNodes(graph_->block()); - torch::jit::EliminateExceptions(graph_); torch::jit::EliminateDeadCode(graph_); LOG_GRAPH("Post exeception or pass elimination: " << *graph_); }