Skip to content

Commit 444617b

Browse files
committed
Fix: fix the bug that uninitialized tensor cannot be found
This PR fixes the bug that uninitialized tensor cannot be found. Specifically, uninitialized value is copied between subgraphs previously, this introduces unnecessary complexity and might incur bugs. In this PR, this uninitialized value is handled like constants: Create one uninitialized value for each subgraph in need.
1 parent 98376b3 commit 444617b

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

core/partitioning/partitioningctx/PartitioningCtx.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
1515
}
1616

1717
void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
18-
if (b->owningNode() && b->owningNode()->kind() == torch::jit::prim::Loop)
18+
// won't load nodes if these nodes are in prim::loop or if these nodes are 2-level nested
19+
if (b->owningNode() && (b->owningNode()->kind() == torch::jit::prim::Loop || b->owningNode()->owningBlock()->owningNode()))
1920
return;
2021

2122
original_blocks.push_back(b);

core/partitioning/segmentedblock/SegmentedBlock.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
4646
old_to_new_[old_value] = new_const->output();
4747
return new_const->output();
4848
}
49+
if (node->kind() == torch::jit::prim::Uninitialized) {
50+
auto new_uninitialized = g_->createUninitialized(old_value->type());
51+
g_->block()->prependNode(new_uninitialized);
52+
old_to_new_[old_value] = new_uninitialized->output();
53+
return new_uninitialized->output();
54+
}
4955
auto new_value = g_->block()->addInput();
5056
// every time when we addInput, we push back the corresponding lowering graph torch::jit::Value to our raw_inputs
5157
inputs_.push_back(old_value);

0 commit comments

Comments
 (0)