
Commit b6c3b2b

fix: Repair EliminateExceptions lowering pass
- Add a new `EliminateExceptionsSafe` lowering pass, which performs the same task as `EliminateExceptions` but with a safer replacement scheme
- Update `EliminateExceptions` to use `replaceAllUsesDominatedByNodeWith` instead of `replaceAllUsesWith`, avoiding an issue where invalid IR caused the program to halt
- Add tests for the new lowering pass
1 parent 71d71c7 commit b6c3b2b
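
For context, here is a minimal sketch of the replacement-scheme difference described in the commit message. The surrounding function and variable names are illustrative, not taken from the patch; only the two replacement calls are real TorchScript IR APIs:

  #include <torch/csrc/jit/ir/ir.h>

  using namespace torch::jit;

  // Suppose `if_node` is a prim::If whose true-block certainly raises, so any
  // code the If dominates can assume its condition evaluated to false.
  void foldConditionSketch(Node* if_node) {
    Graph* graph = if_node->owningGraph();
    Value* cond = if_node->input(0);
    Value* false_const = graph->insertConstant(IValue(false));

    // Old scheme: rewrites *every* use of `cond`, including uses outside the
    // scope of `if_node` (e.g. an enclosing prim::If sharing the condition),
    // which the commit identifies as the source of invalid IR in cases such
    // as nested Ifs.
    // cond->replaceAllUsesWith(false_const);

    // Safer scheme adopted by this commit: only uses dominated by `if_node`
    // are rewritten, leaving unrelated scopes untouched.
    cond->replaceAllUsesDominatedByNodeWith(if_node, false_const);
  }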

File tree

4 files changed: +246 -0 lines changed


core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   torch::jit::InlineFunctionalGraphs(g);
   torch::jit::PeepholeOptimize(g, false);
   torch::jit::FuseLinear(g);
+  passes::EliminateExceptionsSafe(g);
   if (!lower_info.disable_cse) {
     torch::jit::EliminateCommonSubexpression(g);
   }

core/lowering/passes/exception_elimination.cpp

Lines changed: 54 additions & 0 deletions
@@ -1,3 +1,4 @@
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include "torch/csrc/jit/ir/alias_analysis.h"
 #include "torch/csrc/jit/jit_log.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
@@ -108,6 +109,59 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
   }
 }

+/*
+  Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
+  using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
+  so as not to invalidate the IR in challenging cases, such as nested Ifs.
+
+  Original source from which it was adapted:
+  https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
+*/
+
+bool certainlyThrows(Block* block) {
+  // A block certainly throws an exception if it contains
+  // the prim::RaiseException operation
+  for (Node* n : block->nodes()) {
+    if (n->kind() == prim::RaiseException) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void EliminateExceptionsSafe(Block* block) {
+  auto graph = block->owningGraph();
+  // Generate false and true constant placeholders
+  Value* false_const = graph->insertConstant(IValue(false));
+  Value* true_const = graph->insertConstant(IValue(true));
+
+  // For each prim::If node, if either block certainly throws an exception,
+  // replace all uses of the node input with the logical opposite
+  for (Node* n : block->nodes()) {
+    if (n->kind() == prim::If) {
+      Block* true_block = n->blocks()[0];
+      Block* false_block = n->blocks()[1];
+
+      if (certainlyThrows(true_block)) {
+        n->input(0)->replaceAllUsesDominatedByNodeWith(n, false_const);
+      } else if (certainlyThrows(false_block)) {
+        n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const);
+      }
+    }
+
+    // Inspect and replace all instances within subblocks of the current node
+    for (Block* subblock : n->blocks()) {
+      EliminateExceptionsSafe(subblock);
+    }
+  }
+}
+
+void EliminateExceptionsSafe(std::shared_ptr<Graph>& graph) {
+  EliminateExceptionsSafe(graph->block());
+  ConstantPropagation(graph);
+  ConstantPooling(graph);
+}
+
 } // namespace passes
 } // namespace lowering
 } // namespace core

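To make the forked pass's effect concrete, here is a hedged before/after sketch assembled by hand from the first new test below (the "before" IR is that test's commented-out source_graph; the "after" IR is its target_graph): once the dominated replacement plus the ConstantPropagation and ConstantPooling calls above run, the exception guard folds away and only the arithmetic survives.

  Before EliminateExceptionsSafe:
    graph(%x, %y):
      %false : bool = prim::Constant[value=0]()
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%false)
        block0():
           = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4)

  After the pass (and canonicalization, as the test applies):
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6)
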
core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
+void EliminateExceptionsSafe(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
 void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);

tests/core/lowering/test_exception_elimination_pass.cpp

Lines changed: 190 additions & 0 deletions
@@ -1,6 +1,10 @@
 #include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
+#include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/passes/canonicalize.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
+#include "torch/csrc/jit/passes/remove_exceptions.h"

 TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
   // parseIR does not support " = prim::If(%51)" with no return value
@@ -169,3 +173,189 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
   }
   EXPECT_EQ(1, if_count);
 }
+
+TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {
+  /*std::string source_graph = R"IR(
+    graph(%x, %y):
+      %false : bool = prim::Constant[value=0]()
+      %45 : str = prim::Constant[value="EXCEPTION"]()
+      %4 : Tensor = prim::If(%false)
+        block0():
+          = prim::RaiseException(%45)
+          -> (%x)
+        block1():
+          %res = aten::mul(%x, %y)
+          -> (%res)
+      return (%4))IR";*/
+
+  std::string target_graph = R"IR(
+    graph(%x : Tensor,
+          %y : Tensor):
+      %6 : Tensor = aten::mul(%x, %y)
+      return (%6))IR";
+
+  // Construct graph via manual commands, to avoid IR parsing issues with
+  // unassigned variables (such as prim::RaiseException)
+  auto g = std::make_shared<torch::jit::Graph>();
+  auto x = g->insertInput(0, "x");
+  auto y = g->insertInput(1, "y");
+  auto none_const_val = g->insertConstant(torch::jit::IValue());
+  auto false_const_val = g->insertConstant(torch::jit::IValue(false));
+  torch::jit::IValue except("EXCEPTION");
+  auto except_val = g->insertConstant(except);
+
+  auto if_node = g->create(torch::jit::prim::If, {false_const_val}, 1);
+  auto if_block0 = if_node->addBlock();
+  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
+  if_block0->appendNode(exception_node);
+  if_block0->registerOutput(x);
+
+  auto if_block1 = if_node->addBlock();
+  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
+  if_block1->appendNode(sum_node);
+  if_block1->registerOutput(sum_node->output());
+
+  g->insertNode(if_node);
+  g->registerOutput(if_node->output());
+
+  // Apply lowering pass and canonicalization to the graph
+  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
+  g = torch::jit::Canonicalize(g, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == g->toString()));
+}
+
+TEST(LoweringPasses, EliminateExceptionsSafeElseBlock) {
+  /*std::string source_graph = R"IR(
+    graph(%x, %y):
+      %true : bool = prim::Constant[value=1]()
+      %45 : str = prim::Constant[value="EXCEPTION"]()
+      %4 : Tensor = prim::If(%true)
+        block0():
+          %res = aten::matmul(%x, %y)
+          -> (%res)
+        block1():
+          = prim::RaiseException(%45)
+          -> (%x)
+      return (%4))IR";*/
+
+  std::string target_graph = R"IR(
+    graph(%x : Tensor,
+          %y : Tensor):
+      %6 : Tensor = aten::matmul(%x, %y)
+      return (%6))IR";
+
+  // Construct graph via manual commands, to avoid IR parsing issues with
+  // unassigned variables (such as prim::RaiseException)
+  auto g = std::make_shared<torch::jit::Graph>();
+  auto x = g->insertInput(0, "x");
+  auto y = g->insertInput(1, "y");
+  auto none_const_val = g->insertConstant(torch::jit::IValue());
+  auto true_const_val = g->insertConstant(torch::jit::IValue(true));
+  torch::jit::IValue except("EXCEPTION");
+  auto except_val = g->insertConstant(except);
+
+  auto if_node = g->create(torch::jit::prim::If, {true_const_val}, 1);
+  auto if_block0 = if_node->addBlock();
+  auto sum_node = g->create(torch::jit::aten::matmul, {x, y}, 1);
+  if_block0->appendNode(sum_node);
+  if_block0->registerOutput(sum_node->output());
+
+  auto if_block1 = if_node->addBlock();
+  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
+  if_block1->appendNode(exception_node);
+  if_block1->registerOutput(x);
+
+  g->insertNode(if_node);
+  g->registerOutput(if_node->output());
+
+  // Apply lowering pass and canonicalization to the graph
+  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
+  g = torch::jit::Canonicalize(g, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == g->toString()));
+}
+
+TEST(LoweringPasses, EliminateExceptionsSafeNestedIfBlock) {
+  /*std::string source_graph = R"IR(
+    graph(%x, %y):
+      %false : bool = prim::Constant[value=0]()
+      %4 : Tensor = prim::If(%false)
+        block0():
+          %45 : str = prim::Constant[value="EXCEPTION"]()
+          = prim::If(%false)
+            block0():
+              -> ()
+            block1():
+              = prim::RaiseException(%45)
+              -> ()
+          -> (%x)
+        block1():
+          %res = aten::mul(%x, %y)
+          -> (%res)
+      return (%4))IR";*/
+
+  std::string target_graph = R"IR(
+    graph(%x : Tensor,
+          %y : Tensor):
+      %6 : Tensor = aten::mul(%x, %y)
+      return (%6))IR";
+
+  // Construct graph via manual commands, to avoid IR parsing issues with
+  // unassigned variables (such as prim::RaiseException)
+  auto g = std::make_shared<torch::jit::Graph>();
+  auto x = g->insertInput(0, "x");
+  auto y = g->insertInput(1, "y");
+  auto none_const_val = g->insertConstant(torch::jit::IValue());
+  auto false_const_val = g->insertConstant(torch::jit::IValue(false));
+  torch::jit::IValue except("EXCEPTION");
+  auto except_val = g->insertConstant(except);
+
+  // Construct nested-If substructure in graph
+  auto if_node = g->create(torch::jit::prim::If, {false_const_val}, 1);
+  auto if_block0 = if_node->addBlock();
+  auto if_if_node = g->create(torch::jit::prim::If, {false_const_val}, 0);
+  if_block0->appendNode(if_if_node);
+  /* auto if_if_block0 = */ if_if_node->addBlock();
+  auto if_if_block1 = if_if_node->addBlock();
+  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
+  if_if_block1->appendNode(exception_node);
+  if_block0->registerOutput(x);
+
+  auto if_block1 = if_node->addBlock();
+  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
+  if_block1->appendNode(sum_node);
+  if_block1->registerOutput(sum_node->output());
+
+  g->insertNode(if_node);
+  g->registerOutput(if_node->output());
+
+  // Apply lowering pass and canonicalization to the graph
+  LOG_ERROR("BEFORE:\n" << *g);
+  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
+  g = torch::jit::Canonicalize(g, false);
+  LOG_ERROR("AFTER:\n" << *g);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == g->toString()));
+}
