diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 935a9013cb..eaa8b1ad1f 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -70,7 +70,7 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
   const auto to_compile_sym = c10::Symbol::attr("to_compile");
 
   for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant) {
+    if (isConstantOrUninitialized(n)) {
       continue;
     }
 
@@ -107,7 +107,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
     for (auto input : cur_node->inputs()) {
-      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
+      if (!isTensor(input) && !isConstantOrUninitialized(input->node()) &&
          ctx->shouldNodeRunInTensorRT(input->node())) {
        ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR);
        q.push(input->node());
@@ -118,7 +118,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
        for (auto use : output->uses()) {
          auto node = use.user;
-         if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) {
+         if (!isConstantOrUninitialized(node) && ctx->shouldNodeRunInTensorRT(node)) {
            ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR);
            q.push(node);
          }
@@ -128,11 +128,13 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
   }
 }
 
-std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
-  std::set<torch::jit::Node*> dependent_nodes;
+std::set<torch::jit::Node*> getUserNodes(torch::jit::Node* n) {
+  std::set<torch::jit::Node*> user_nodes;
   for (auto val : n->outputs()) {
     for (auto use : val->uses()) {
-      dependent_nodes.insert(use.user);
+      if (use.user->owningBlock()->owningNode())
+        user_nodes.insert(use.user->owningBlock()->owningNode());
+      user_nodes.insert(use.user);
     }
   }
   if (const auto* schema = n->maybeSchema()) {
@@ -142,13 +144,13 @@ std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
         for (auto use : n->inputs()[i]->uses()) {
           torch::jit::Node* use_node = use.user;
           if (use_node->isAfter(n)) {
-            dependent_nodes.insert(use_node);
+            user_nodes.insert(use_node);
           }
         }
       }
     }
   }
-  return dependent_nodes;
+  return user_nodes;
 }
 
 // Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
@@ -158,14 +160,14 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
   std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
   std::vector<torch::jit::Node*> min_block_fallback_nodes;
   for (const auto n : nodes) {
-    if (n->kind() == torch::jit::prim::Constant) {
+    if (isConstantOrUninitialized(n)) {
      continue;
    }
 
    // check if current node fallback or not
    if (!ctx->shouldNodeRunInTorch(n)) {
      cur_trt_nodes.push_back(n);
-      auto dependent_nodes = getDependentNodes(n);
+      auto dependent_nodes = getUserNodes(n);
      cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
    } else {
      if (cur_trt_nodes_uses.count(n)) {
@@ -250,7 +252,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
      auto cur_val = q.front();
      q.pop();
      auto node = cur_val->node();
-      if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      if (!isConstantOrUninitialized(node) && !visited.count(node)) {
        visited.insert(node);
        auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
        stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
@@ -454,10 +456,10 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
   for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
-    if (n->kind() == torch::jit::prim::Constant) {
+    if (isConstantOrUninitialized(n)) {
      continue;
    }
-    auto dependent_nodes = getDependentNodes(n);
+    auto dependent_nodes = getUserNodes(n);
    // the outputs of trt subgraph shouldn't be collections
    if (ctx->shouldNodeRunInTensorRT(n)) {
      in_prog_trt_blk_nodes.push_back(n);
diff --git a/core/partitioning/partitioningctx/PartitioningCtx.cpp b/core/partitioning/partitioningctx/PartitioningCtx.cpp
index 73fae28e91..6f10c5159f 100644
--- a/core/partitioning/partitioningctx/PartitioningCtx.cpp
+++ b/core/partitioning/partitioningctx/PartitioningCtx.cpp
@@ -15,13 +15,15 @@ PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info)
 }
 
 void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) {
-  if (b->owningNode() && b->owningNode()->kind() == torch::jit::prim::Loop)
+  // Don't load nodes from blocks owned by a prim::Loop or from blocks nested more than one level deep
+  if (b->owningNode() &&
+      (b->owningNode()->kind() == torch::jit::prim::Loop || b->owningNode()->owningBlock()->owningNode()))
     return;
 
   original_blocks.push_back(b);
 
   for (const auto n : b->nodes()) {
-    if (n->kind() == torch::jit::prim::Constant) {
+    if (isConstantOrUninitialized(n)) {
      continue;
    }
    node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN;
diff --git a/core/partitioning/partitioningctx/PartitioningCtx.h b/core/partitioning/partitioningctx/PartitioningCtx.h
index 91e376eab3..f3290d2aa3 100644
--- a/core/partitioning/partitioningctx/PartitioningCtx.h
+++ b/core/partitioning/partitioningctx/PartitioningCtx.h
@@ -71,6 +71,10 @@ struct PartitioningCtx {
 
 std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s);
 
+inline bool isConstantOrUninitialized(torch::jit::Node* n) {
+  return n->kind() == torch::jit::prim::Constant || n->kind() == torch::jit::prim::Uninitialized;
+}
+
 } // namespace partitioning
 } // namespace core
 } // namespace torch_tensorrt
diff --git a/core/partitioning/segmentedblock/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp
index 249a293bc3..c55127713b 100644
--- a/core/partitioning/segmentedblock/SegmentedBlock.cpp
+++ b/core/partitioning/segmentedblock/SegmentedBlock.cpp
@@ -46,6 +46,12 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
     old_to_new_[old_value] = new_const->output();
     return new_const->output();
   }
+  if (node->kind() == torch::jit::prim::Uninitialized) {
+    auto new_uninitialized = g_->createUninitialized(old_value->type());
+    g_->block()->prependNode(new_uninitialized);
+    old_to_new_[old_value] = new_uninitialized->output();
+    return new_uninitialized->output();
+  }
   auto new_value = g_->block()->addInput();
   // every time when we addInput, we push back the corresponding lowering graph torch::jit::Value to our raw_inputs
   inputs_.push_back(old_value);
diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp
index b365a95ad9..95dfd772c3 100644
--- a/tests/core/partitioning/test_segmentation.cpp
+++ b/tests/core/partitioning/test_segmentation.cpp
@@ -331,6 +331,44 @@ TEST(Partitioning, SegmentModelWithDependencyAwareness) {
      checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}}));
 }
 
+TEST(Partitioning, ContainUninitializedValueCorrectly) {
+  auto g = std::make_shared<torch::jit::Graph>();
+  auto x = g->insertInput(0, "x");
+  auto none_const_val = g->insertConstant(torch::jit::IValue());
+  auto ivalue_1 = g->insertConstant(torch::jit::IValue(1));
+  auto ivalue_2 = g->insertConstant(torch::jit::IValue(2));
+
+  auto uninitialized_node = g->createUninitialized(torch::jit::BoolType::get());
+  g->appendNode(uninitialized_node);
+
+  auto x_dim = g->create(torch::jit::aten::dim, {x}, 1);
+  g->appendNode(x_dim);
+  x_dim->output()->setType(torch::jit::IntType::get());
+
+  auto eq1 = g->create(torch::jit::aten::eq, {ivalue_1, x_dim->output()}, 1);
+  g->appendNode(eq1);
+  eq1->output()->setType(torch::jit::BoolType::get());
+
+  torch::jit::IValue except("EXCEPTION");
+  auto exception_val = g->insertConstant(except);
+  auto if_node = g->create(torch::jit::prim::If, {eq1->output()}, 1);
+  auto if_block_0 = if_node->addBlock();
+  auto exception_node = g->create(torch::jit::prim::RaiseException, {exception_val, none_const_val}, 0);
+  if_block_0->appendNode(exception_node);
+  if_block_0->registerOutput(uninitialized_node->output());
+
+  auto if_block_1 = if_node->addBlock();
+  if_block_1->registerOutput(eq1->output());
+
+  g->insertNode(if_node);
+
+  PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  PartitioningCtx ctx(g->block(), partitioning_info);
+  segmentGraph(&ctx, g->block());
+  ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2));
+}
+
 } // namespace tests
 } // namespace partitioning
 } // namespace core
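Note (not part of the patch): a minimal standalone sketch of the predicate this diff introduces, showing that both prim::Constant and prim::Uninitialized nodes are treated as resource nodes and skipped by the partitioner. The include path, the local static copy of the helper, main(), and the printed text are illustrative assumptions only; the actual helper lives in PartitioningCtx.h as shown above.

// sketch_is_constant_or_uninitialized.cpp -- hypothetical example file, assumes a libtorch build
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>

// Same predicate the patch adds as an inline helper in PartitioningCtx.h.
static bool isConstantOrUninitialized(torch::jit::Node* n) {
  return n->kind() == torch::jit::prim::Constant || n->kind() == torch::jit::prim::Uninitialized;
}

int main() {
  auto g = std::make_shared<torch::jit::Graph>();

  // One constant and one uninitialized placeholder, like the nodes the new test graph builds.
  g->insertConstant(torch::jit::IValue(1));
  g->appendNode(g->createUninitialized(torch::jit::BoolType::get()));

  for (auto n : g->block()->nodes()) {
    std::cout << n->kind().toQualString() << " treated as resource node: " << std::boolalpha
              << isConstantOrUninitialized(n) << std::endl;
  }
  return 0;
}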