Skip to content

Commit 4f6ddbe

Browse files
committed
Fix: fix the bug that uninitialized tensors cannot be found
This PR fixes the bug that uninitialized tensors cannot be found. Previously, an uninitialized value was copied between subgraphs, which introduced unnecessary complexity and could cause bugs. In this PR, uninitialized values are handled the same way as constants: a fresh uninitialized value is created in each subgraph that needs one.
1 parent 5dab7ca commit 4f6ddbe

File tree

5 files changed

+67
-15
lines changed

5 files changed

+67
-15
lines changed

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
7070
const auto to_compile_sym = c10::Symbol::attr("to_compile");
7171

7272
for (const auto n : nodes) {
73-
if (n->kind() == torch::jit::prim::Constant) {
73+
if (isConstantOrUninitialized(n)) {
7474
continue;
7575
}
7676

@@ -107,7 +107,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
107107
q.pop();
108108
// every node that produces a NonTensor input of this fallback node should fall back too
109109
for (auto input : cur_node->inputs()) {
110-
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
110+
if (!isTensor(input) && !isConstantOrUninitialized(input->node()) &&
111111
ctx->shouldNodeRunInTensorRT(input->node())) {
112112
ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
113113
q.push(input->node());
@@ -118,7 +118,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
118118
if (!isTensor(output)) {
119119
for (auto use : output->uses()) {
120120
auto node = use.user;
121-
if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
121+
if (!isConstantOrUninitialized(node) && ctx->shouldNodeRunInTensorRT(node)) {
122122
ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
123123
q.push(node);
124124
}
@@ -128,11 +128,13 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
128128
}
129129
}
130130

131-
std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
132-
std::set<torch::jit::Node*> dependent_nodes;
131+
std::set<torch::jit::Node*> getUserNodes(torch::jit::Node* n) {
132+
std::set<torch::jit::Node*> user_nodes;
133133
for (auto val : n->outputs()) {
134134
for (auto use : val->uses()) {
135-
dependent_nodes.insert(use.user);
135+
if (use.user->owningBlock()->owningNode())
136+
user_nodes.insert(use.user->owningBlock()->owningNode());
137+
user_nodes.insert(use.user);
136138
}
137139
}
138140
if (const auto* schema = n->maybeSchema()) {
@@ -142,13 +144,13 @@ std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
142144
for (auto use : n->inputs()[i]->uses()) {
143145
torch::jit::Node* use_node = use.user;
144146
if (use_node->isAfter(n)) {
145-
dependent_nodes.insert(use_node);
147+
user_nodes.insert(use_node);
146148
}
147149
}
148150
}
149151
}
150152
}
151-
return dependent_nodes;
153+
return user_nodes;
152154
}
153155

154156
// Sub-function that traverses the entire block and checks if the TensorRT node sequence satisfies min_block_size
@@ -158,14 +160,14 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
158160
std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
159161
std::vector<torch::jit::Node*> min_block_fallback_nodes;
160162
for (const auto n : nodes) {
161-
if (n->kind() == torch::jit::prim::Constant) {
163+
if (isConstantOrUninitialized(n)) {
162164
continue;
163165
}
164166

165167
// check if current node fallback or not
166168
if (!ctx->shouldNodeRunInTorch(n)) {
167169
cur_trt_nodes.push_back(n);
168-
auto dependent_nodes = getDependentNodes(n);
170+
auto dependent_nodes = getUserNodes(n);
169171
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
170172
} else {
171173
if (cur_trt_nodes_uses.count(n)) {
@@ -250,7 +252,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
250252
auto cur_val = q.front();
251253
q.pop();
252254
auto node = cur_val->node();
253-
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
255+
if (!isConstantOrUninitialized(node) && !visited.count(node)) {
254256
visited.insert(node);
255257
auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
256258
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
@@ -454,10 +456,10 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
454456
std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
455457
for (const auto n : nodes) {
456458
// Skip constant nodes as they are resources for both kinds of modules
457-
if (n->kind() == torch::jit::prim::Constant) {
459+
if (isConstantOrUninitialized(n)) {
458460
continue;
459461
}
460-
auto dependent_nodes = getDependentNodes(n);
462+
auto dependent_nodes = getUserNodes(n);
461463
// the outputs of trt subgraph shouldn't be collections
462464
if (ctx->shouldNodeRunInTensorRT(n)) {
463465
in_prog_trt_blk_nodes.push_back(n);

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
1515
}
1616

1717
void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
18-
if (b->owningNode() && b->owningNode()->kind() == torch::jit::prim::Loop)
18+
// skip loading nodes that are inside a prim::Loop block or nested two (or more) block levels deep
19+
if (b->owningNode() &&
20+
(b->owningNode()->kind() == torch::jit::prim::Loop || b->owningNode()->owningBlock()->owningNode()))
1921
return;
2022

2123
original_blocks.push_back(b);
2224

2325
for (const auto n : b->nodes()) {
24-
if (n->kind() == torch::jit::prim::Constant) {
26+
if (isConstantOrUninitialized(n)) {
2527
continue;
2628
}
2729
node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN;

core/partitioning/partitioningctx/PartitioningCtx.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ struct PartitioningCtx {
7171

7272
std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s);
7373

74+
inline bool isConstantOrUninitialized(torch::jit::Node* n) {
75+
return n->kind() == torch::jit::prim::Constant || n->kind() == torch::jit::prim::Uninitialized;
76+
}
77+
7478
} // namespace partitioning
7579
} // namespace core
7680
} // namespace torch_tensorrt

core/partitioning/segmentedblock/SegmentedBlock.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
4646
old_to_new_[old_value] = new_const->output();
4747
return new_const->output();
4848
}
49+
if (node->kind() == torch::jit::prim::Uninitialized) {
50+
auto new_uninitialized = g_->createUninitialized(old_value->type());
51+
g_->block()->prependNode(new_uninitialized);
52+
old_to_new_[old_value] = new_uninitialized->output();
53+
return new_uninitialized->output();
54+
}
4955
auto new_value = g_->block()->addInput();
5056
// every time when we addInput, we push back the corresponding lowering graph torch::jit::Value to our raw_inputs
5157
inputs_.push_back(old_value);

tests/core/partitioning/test_segmentation.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,44 @@ TEST(Partitioning, SegmentModelWithDependencyAwareness) {
331331
checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}}));
332332
}
333333

334+
TEST(Partitioning, ContainUninitializedValueCorrectly) {
335+
auto g = std::make_shared<torch::jit::Graph>();
336+
auto x = g->insertInput(0, "x");
337+
auto none_const_val = g->insertConstant(torch::jit::IValue());
338+
auto ivalue_1 = g->insertConstant(torch::jit::IValue(1));
339+
auto ivalue_2 = g->insertConstant(torch::jit::IValue(2));
340+
341+
auto uninitialized_node = g->createUninitialized(torch::jit::BoolType::get());
342+
g->appendNode(uninitialized_node);
343+
344+
auto x_dim = g->create(torch::jit::aten::dim, {x}, 1);
345+
g->appendNode(x_dim);
346+
x_dim->output()->setType(torch::jit::IntType::get());
347+
348+
auto eq1 = g->create(torch::jit::aten::eq, {ivalue_1, x_dim->output()}, 1);
349+
g->appendNode(eq1);
350+
eq1->output()->setType(torch::jit::BoolType::get());
351+
352+
torch::jit::IValue except("EXCEPTION");
353+
auto exception_val = g->insertConstant(except);
354+
auto if_node = g->create(torch::jit::prim::If, {eq1->output()}, 1);
355+
auto if_block_0 = if_node->addBlock();
356+
auto exception_node = g->create(torch::jit::prim::RaiseException, {exception_val, none_const_val}, 0);
357+
if_block_0->appendNode(exception_node);
358+
if_block_0->registerOutput(uninitialized_node->output());
359+
360+
auto if_block_1 = if_node->addBlock();
361+
if_block_1->registerOutput(eq1->output());
362+
363+
g->insertNode(if_node);
364+
365+
PartitioningInfo partitioning_info;
366+
partitioning_info.enabled = true;
367+
PartitioningCtx ctx(g->block(), partitioning_info);
368+
segmentGraph(&ctx, g->block());
369+
ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
370+
}
371+
334372
} // namespace tests
335373
} // namespace partitioning
336374
} // namespace core

0 commit comments

Comments
 (0)