@@ -108,6 +108,57 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108108 }
109109}
110110
111+ /*
112+ Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
113+ using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
114+ so as to not invalidate the IR in challenging cases, such as nested Ifs
115+
116+ Original Source from which it was adapted:
117+ https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
118+ */
119+
120+ bool certainlyThrows (Block* block) {
121+ // A block certainly throws an exception if it contains
122+ // the prim::RaiseException operation
123+ for (Node* n : block->nodes ()) {
124+ if (n->kind () == prim::RaiseException) {
125+ return true ;
126+ }
127+ }
128+ return false ;
129+ }
130+
131+ void EliminateExceptionsNew (Block* block) {
132+ auto graph = block->owningGraph ();
133+ // Generate false and true constant placeholders
134+ Value* false_const = graph->insertConstant (IValue (false ));
135+ Value* true_const = graph->insertConstant (IValue (true ));
136+
137+ // For each prim::If node, if either block certainly throws an exception
138+ // Replace all uses of the node input with the logical opposite
139+ for (Node* n : block->nodes ()) {
140+ if (n->kind () == prim::If) {
141+ Block* true_block = n->blocks ()[0 ];
142+ Block* false_block = n->blocks ()[1 ];
143+
144+ if (certainlyThrows (true_block)) {
145+ n->input (0 )->replaceAllUsesDominatedByNodeWith (n, false_const);
146+ } else if (certainlyThrows (false_block)) {
147+ n->input (0 )->replaceAllUsesDominatedByNodeWith (n, true_const);
148+ }
149+ }
150+
151+ // Inspect and replace all instances within subblocks of the current node
152+ for (Block* subblock : n->blocks ()) {
153+ EliminateExceptionsNew (subblock);
154+ }
155+ }
156+ }
157+
158+ void EliminateExceptionsNew (std::shared_ptr<Graph>& graph) {
159+ EliminateExceptionsNew (graph->block ());
160+ }
161+
111162} // namespace passes
112163} // namespace lowering
113164} // namespace core
0 commit comments