Skip to content

Commit 0615d2d

Browse files
committed
Add test case and refactored code
1 parent e912f91 commit 0615d2d

File tree

4 files changed

+59
-14
lines changed

4 files changed

+59
-14
lines changed

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
7070
const auto to_compile_sym = c10::Symbol::attr("to_compile");
7171

7272
for (const auto n : nodes) {
73-
if (n->kind() == torch::jit::prim::Constant) {
73+
if (isConstantOrUninitialized(n)) {
7474
continue;
7575
}
7676

@@ -107,7 +107,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
107107
q.pop();
108108
// for every node that produces this fallback node's NonTensor input, they should fallback too
109109
for (auto input : cur_node->inputs()) {
110-
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
110+
if (!isTensor(input) && !isConstantOrUninitialized(input->node()) &&
111111
ctx->shouldNodeRunInTensorRT(input->node())) {
112112
ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
113113
q.push(input->node());
@@ -118,7 +118,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
118118
if (!isTensor(output)) {
119119
for (auto use : output->uses()) {
120120
auto node = use.user;
121-
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
121+
if (isConstantOrUninitialized(node) && ctx->shouldNodeRunInTensorRT(node)) {
122122
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
123123
q.push(node);
124124
}
@@ -128,11 +128,13 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
128128
}
129129
}
130130

131-
std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
132-
std::set<torch::jit::Node*> dependent_nodes;
131+
std::set<torch::jit::Node*> getUserNodes(torch::jit::Node* n) {
132+
std::set<torch::jit::Node*> user_nodes;
133133
for (auto val : n->outputs()) {
134134
for (auto use : val->uses()) {
135-
dependent_nodes.insert(use.user);
135+
if (use.user->owningBlock()->owningNode())
136+
user_nodes.insert(use.user->owningBlock()->owningNode());
137+
user_nodes.insert(use.user);
136138
}
137139
}
138140
if (const auto* schema = n->maybeSchema()) {
@@ -142,13 +144,13 @@ std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
142144
for (auto use : n->inputs()[i]->uses()) {
143145
torch::jit::Node* use_node = use.user;
144146
if (use_node->isAfter(n)) {
145-
dependent_nodes.insert(use_node);
147+
user_nodes.insert(use_node);
146148
}
147149
}
148150
}
149151
}
150152
}
151-
return dependent_nodes;
153+
return user_nodes;
152154
}
153155

154156
// Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
@@ -158,14 +160,14 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
158160
std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
159161
std::vector<torch::jit::Node*> min_block_fallback_nodes;
160162
for (const auto n : nodes) {
161-
if (n->kind() == torch::jit::prim::Constant) {
163+
if (isConstantOrUninitialized(n)) {
162164
continue;
163165
}
164166

165167
// check if current node fallback or not
166168
if (!ctx->shouldNodeRunInTorch(n)) {
167169
cur_trt_nodes.push_back(n);
168-
auto dependent_nodes = getDependentNodes(n);
170+
auto dependent_nodes = getUserNodes(n);
169171
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
170172
} else {
171173
if (cur_trt_nodes_uses.count(n)) {
@@ -250,7 +252,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
250252
auto cur_val = q.front();
251253
q.pop();
252254
auto node = cur_val->node();
253-
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
255+
if (!isConstantOrUninitialized(node) && !visited.count(node)) {
254256
visited.insert(node);
255257
auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
256258
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
@@ -454,10 +456,10 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
454456
std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
455457
for (const auto n : nodes) {
456458
// Skip constant nodes as they are resources for both kinds of modules
457-
if (n->kind() == torch::jit::prim::Constant) {
459+
if (isConstantOrUninitialized(n)) {
458460
continue;
459461
}
460-
auto dependent_nodes = getDependentNodes(n);
462+
auto dependent_nodes = getUserNodes(n);
461463
// the outputs of trt subgraph shouldn't be collections
462464
if (ctx->shouldNodeRunInTensorRT(n)) {
463465
in_prog_trt_blk_nodes.push_back(n);

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
2222
original_blocks.push_back(b);
2323

2424
for (const auto n : b->nodes()) {
25-
if (n->kind() == torch::jit::prim::Constant) {
25+
if (isConstantOrUninitialized(n)) {
2626
continue;
2727
}
2828
node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN;

core/partitioning/partitioningctx/PartitioningCtx.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ struct PartitioningCtx {
7171

7272
std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s);
7373

74+
inline bool isConstantOrUninitialized(torch::jit::Node* n) {
75+
return n->kind() == torch::jit::prim::Constant || n->kind() == torch::jit::prim::Uninitialized;
76+
}
77+
7478
} // namespace partitioning
7579
} // namespace core
7680
} // namespace torch_tensorrt

tests/core/partitioning/test_segmentation.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,45 @@ TEST(Partitioning, SegmentModelWithDependencyAwareness) {
331331
checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}}));
332332
}
333333

334+
TEST(Partitioning, ContainUninitializedValueCorrectly) {
335+
auto g = std::make_shared<torch::jit::Graph>();
336+
auto x = g->insertInput(0, "x");
337+
auto none_const_val = g->insertConstant(torch::jit::IValue());
338+
auto ivalue_1 = g->insertConstant(torch::jit::IValue(1));
339+
auto ivalue_2 = g->insertConstant(torch::jit::IValue(2));
340+
341+
auto uninitialized_node = g->createUninitialized(torch::jit::BoolType::get());
342+
g->appendNode(uninitialized_node);
343+
344+
auto x_dim = g->create(torch::jit::aten::dim, {x}, 1);
345+
g->appendNode(x_dim);
346+
x_dim->output()->setType(torch::jit::IntType::get());
347+
348+
349+
auto eq1 = g->create(torch::jit::aten::eq, {ivalue_1, x_dim->output()}, 1);
350+
g->appendNode(eq1);
351+
eq1->output()->setType(torch::jit::BoolType::get());
352+
353+
torch::jit::IValue except("EXCEPTION");
354+
auto exception_val = g->insertConstant(except);
355+
auto if_node = g->create(torch::jit::prim::If, {eq1->output()}, 1);
356+
auto if_block_0 = if_node->addBlock();
357+
auto exception_node = g->create(torch::jit::prim::RaiseException, {exception_val, none_const_val}, 0);
358+
if_block_0->appendNode(exception_node);
359+
if_block_0->registerOutput(uninitialized_node->output());
360+
361+
auto if_block_1 = if_node->addBlock();
362+
if_block_1->registerOutput(eq1->output());
363+
364+
g->insertNode(if_node);
365+
366+
PartitioningInfo partitioning_info;
367+
partitioning_info.enabled = true;
368+
PartitioningCtx ctx(g->block(), partitioning_info);
369+
segmentGraph(&ctx, g->block());
370+
ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
371+
}
372+
334373
} // namespace tests
335374
} // namespace partitioning
336375
} // namespace core

0 commit comments

Comments
 (0)