From 3a33b6ec3a8f7b13fe381d96a5913bddc7a4172c Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Fri, 12 Aug 2022 17:27:31 -0700 Subject: [PATCH 1/5] [feat] Add dependency awareness to torch-trt partitioning (#40) Adds a heuristic to torch-trt partitioning's segmentation to avoid materializing segments until we hit a dependency of that segment. This can significantly reduce the number of segments/engines in cases where the linear traversal of torchscipt nodes would otherwise produce alternating torch and TRT segments which are not dependent on each-other Fixes # (issue) Please delete options that are not relevant and/or add your own. - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - This change requires a documentation update - [ ] My code follows the style guidelines of this project (You can use the linters) - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas and hacks - [ ] I have made corresponding changes to the documentation - [ ] I have added tests to verify my fix or my feature - [ ] New and existing unit tests pass locally with my changes - [ ] I have added the relevant labels to my PR in so that relevant reviewers are notified --- core/partitioning/partitioning.cpp | 136 +++- .../segmentedblock/SegmentedBlock.h | 9 + tests/core/partitioning/test_conditionals.cpp | 2 +- .../test_resolve_nontensor_inputs.cpp | 2 +- tests/core/partitioning/test_segmentation.cpp | 646 +++++++++--------- 5 files changed, 466 insertions(+), 329 deletions(-) diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index eb8c86de50..e5e16b9029 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -111,10 +111,34 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector getDependentNodes(torch::jit::Node* n) { + std::set dependent_nodes; + for (auto val : n->outputs()) { + for (auto use : val->uses()) { + dependent_nodes.insert(use.user); + } + } + if (const auto* schema = n->maybeSchema()) { + for (size_t i = 0; i < n->inputs().size(); ++i) { + const at::AliasInfo* formal = schema->arguments()[i].alias_info(); + if (formal && formal->isWrite()) { + for (auto use : n->inputs()[i]->uses()) { + torch::jit::Node* use_node = use.user; + if (use_node->isAfter(n)) { + dependent_nodes.insert(use_node); + } + } + } + } + } + return dependent_nodes; +} + // Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size std::vector traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) { auto nodes = block->nodes(); std::vector cur_trt_nodes; + std::unordered_set cur_trt_nodes_uses; std::vector min_block_fallback_nodes; for (const auto n : nodes) { if (n->kind() == torch::jit::prim::Constant) { @@ -124,11 +148,16 @@ std::vector traverseNodesForMinBlockSize(PartitioningCtx* ctx // check if current node fallback or not if (!ctx->shouldNodeRunInTorch(n)) { cur_trt_nodes.push_back(n); + auto dependent_nodes = getDependentNodes(n); + cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end()); } else { - if (cur_trt_nodes.size() < ctx->settings.min_block_size) { - min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); + if (cur_trt_nodes_uses.count(n)) { + if (cur_trt_nodes.size() < ctx->settings.min_block_size) { + min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); + } + cur_trt_nodes.clear(); + cur_trt_nodes_uses.clear(); } - cur_trt_nodes.clear(); } } if (cur_trt_nodes.size() < ctx->settings.min_block_size) { @@ -355,6 +384,59 @@ void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) { setMinBlockFallbackNodes(ctx, block); } +void merge_adjacent_segments_list_in_new_partition( + PartitionedGraph& original_partition, + PartitionedGraph& new_partition, + SegmentedBlock::SegmentedBlockTarget& segment_kind, + std::vector& same_type_segment_idx) { + TORCHTRT_CHECK(!same_type_segment_idx.empty(), "Unable to merge empty segment list"); + if (same_type_segment_idx.size() == 1) { + new_partition.push_back(original_partition[same_type_segment_idx[0]]); + } else { + auto first_idx = same_type_segment_idx[0]; + for (size_t i = 1; i < same_type_segment_idx.size(); ++i) { + TORCHTRT_CHECK( + same_type_segment_idx[i] == (first_idx + i), + "Unable to merge non-sequential segments: " << same_type_segment_idx); + } + LOG_DEBUG( + "Merging adjacent " << SegmentedBlock::target_to_str(segment_kind) << " segments: " << same_type_segment_idx); + std::vector nodes; + for (auto segment_to_merge : same_type_segment_idx) { + const auto& merge_nodes = original_partition[segment_to_merge].raw_nodes(); + nodes.insert(nodes.end(), merge_nodes.begin(), merge_nodes.end()); + } + new_partition.emplace_back(segment_kind, nodes); + } +} + +PartitionedGraph merge_adjacent_segments_of_same_type(PartitionedGraph& original_partition) { + PartitionedGraph new_partition; + SegmentedBlock::SegmentedBlockTarget segment_kind = SegmentedBlock::SegmentedBlockTarget::kTorch; + std::vector same_type_segment_idx; + for (size_t i = 0UL; i < original_partition.size(); ++i) { + auto& segment = original_partition[i]; + if (same_type_segment_idx.empty()) { + segment_kind = segment.target(); + } else if (segment_kind != segment.target() || segment.do_not_merge()) { + merge_adjacent_segments_list_in_new_partition( + original_partition, new_partition, segment_kind, same_type_segment_idx); + same_type_segment_idx.clear(); + segment_kind = segment.target(); + } + if (segment.do_not_merge()) { + new_partition.push_back(segment); + } else { + same_type_segment_idx.push_back(i); + } + } + if (!same_type_segment_idx.empty()) { + merge_adjacent_segments_list_in_new_partition( + original_partition, new_partition, segment_kind, same_type_segment_idx); + } + return new_partition; +} + void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { // Find all the fallback nodes and build execution decision LUT for all nodes setNodeExecutorLUT(ctx, block); @@ -365,34 +447,45 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { PartitionedGraph segmented_blocks; std::vector in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes; + std::unordered_set cur_trt_nodes_uses; + std::unordered_set cur_pyt_nodes_uses; for (const auto n : nodes) { // Skip constant nodes as they are resources for both kinds of modules if (n->kind() == torch::jit::prim::Constant) { continue; } + auto dependent_nodes = getDependentNodes(n); // the outputs of trt subgraph shouldn't be collections if (ctx->shouldNodeRunInTensorRT(n)) { in_prog_trt_blk_nodes.push_back(n); + cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end()); - // If there is an active PyTorch block and we have passed the threshold for a valid TRT - // block then segment and reset the active PyTorch block - if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) { + // If we hit a TRT node that is dependent on nodes in the active PyTorch block, finalize the block to materialize + // those dependencies in the graph + if (cur_pyt_nodes_uses.count(n)) { finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + cur_pyt_nodes_uses.clear(); } } else { - // If there is an active TRT block that is valid segment and reset the active TRT block - // otherwise add it to the active PyTorch block and reset - if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) { - finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); - } else { - LOG_DEBUG( - "In progress TRT block does not meet minimum block size requirements (" - << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size - << "), therefore folding into in progress PyTorch block"); - in_prog_pyt_blk_nodes.insert( - in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); + // The current node is dependent on the active TRT block, finalize it to materialize those dependencies in the + // graph or add them to the active PyTorch block + if (cur_trt_nodes_uses.count(n)) { + // If there is an active TRT block that is valid segment and reset the active TRT block + // otherwise add it to the active PyTorch block and reset + if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) { + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); + } else { + LOG_DEBUG( + "In progress TRT block does not meet minimum block size requirements (" + << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size + << "), therefore folding into in progress PyTorch block"); + in_prog_pyt_blk_nodes.insert( + in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); + cur_pyt_nodes_uses.insert(cur_trt_nodes_uses.begin(), cur_trt_nodes_uses.end()); + } + in_prog_trt_blk_nodes.clear(); + cur_trt_nodes_uses.clear(); } - in_prog_trt_blk_nodes.clear(); // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock // we shouldn't inject node for this block in dependency analysis process if (n->kind() == torch::jit::prim::If) { @@ -400,23 +493,29 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional"); if (!in_prog_pyt_blk_nodes.empty()) { finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + cur_pyt_nodes_uses.clear(); } auto cond_node = std::vector{n}; finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node); + segmented_blocks.back().do_not_merge(true); continue; } else if (n->kind() == torch::jit::prim::Loop) { if (!in_prog_pyt_blk_nodes.empty()) { finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + cur_pyt_nodes_uses.clear(); } if (checkLoopEvaluatable(n)) { in_prog_trt_blk_nodes.push_back(n); + cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end()); } else { auto loop_node = std::vector{n}; finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node); + segmented_blocks.back().do_not_merge(true); } continue; } in_prog_pyt_blk_nodes.push_back(n); + cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end()); } } @@ -432,6 +531,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } + segmented_blocks = merge_adjacent_segments_of_same_type(segmented_blocks); ctx->partitioned_blocks.insert({block, segmented_blocks}); return; } diff --git a/core/partitioning/segmentedblock/SegmentedBlock.h b/core/partitioning/segmentedblock/SegmentedBlock.h index 0e04237f63..0cea11e99d 100644 --- a/core/partitioning/segmentedblock/SegmentedBlock.h +++ b/core/partitioning/segmentedblock/SegmentedBlock.h @@ -94,6 +94,14 @@ struct SegmentedBlock { return target_; } + bool do_not_merge(void) const { + return do_not_merge_; + } + + void do_not_merge(bool x) { + do_not_merge_ = x; + } + friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b); private: @@ -106,6 +114,7 @@ struct SegmentedBlock { std::vector nodes_; std::shared_ptr g_; std::unordered_map old_to_new_; + bool do_not_merge_ = false; }; std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t); diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index ba336db663..e8dfeacdb1 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -40,7 +40,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { auto conditional_engines_count = count_trt_engines_in_conditionals(new_g); - ASSERT_TRUE(conditional_engines_count == 2); + ASSERT_TRUE(conditional_engines_count == 1); } TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index 950859e524..d0745b0a2c 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -204,7 +204,7 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { })); } } - ASSERT_TRUE(trt_block_cnt == 2 && torch_block_cnt == 2); + ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1); } TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp index 8d47af553e..c17f1c9172 100644 --- a/tests/core/partitioning/test_segmentation.cpp +++ b/tests/core/partitioning/test_segmentation.cpp @@ -1,309 +1,337 @@ -#include -#include "core/partitioning/partitioning.h" -#include "gtest/gtest.h" -#include "tests/util/util.h" -#include "torch/csrc/jit/ir/irparser.h" -#include "torch/script.h" -#include "torch_tensorrt/torch_tensorrt.h" - -namespace torch_tensorrt { -namespace core { -namespace partitioning { -namespace tests { - -bool checkSegmentedBlockNumber( - PartitionedGraph& segmented_blocks, - SegmentedBlock::SegmentedBlockTarget target, - int target_count) { - int64_t cnt = 0; - for (auto& seg_block : segmented_blocks) { - if (seg_block.target() == target) { - cnt++; - } - } - std::cout << "Found count of " << cnt << " " << target << " blocks (looking for " << target_count << " blocks)" - << std::endl; - - if (target_count != cnt) { - std::cout << segmented_blocks << std::endl; - } - - return target_count == cnt; -} - -bool checkSegmentedBlockNodesMapping( - std::vector& segmented_blocks, - std::shared_ptr g, - std::vector> nodes_index) { - std::vector graph_nodes; - for (const auto n : g->nodes()) { - if (n->kind() != torch::jit::prim::Constant) { - graph_nodes.push_back(n); - } - } - for (size_t i = 0; i < nodes_index.size(); ++i) { - size_t seg_block_node_id = 0; - for (int j : nodes_index[i]) { - if (segmented_blocks[i].raw_nodes()[seg_block_node_id++] != graph_nodes[j]) { - return false; - } - } - if (seg_block_node_id != segmented_blocks[i].raw_nodes().size()) - return false; - } - return true; -} - -TEST(Partitioning, SegmentSequentialModelCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}})); -} - -TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}})); -} - -TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Tensor, - %2 : Tensor): - %3 : int[] = prim::Constant[value=[-1, 5]]() - %4 : int[] = prim::Constant[value=[-1]]() - %5 : int = prim::Constant[value=2]() - %6 : int = prim::Constant[value=4]() - %7 : int = prim::Constant[value=5]() - %8 : int = prim::Constant[value=0]() - %9 : bool = prim::Constant[value=0]() - %10 : NoneType = prim::Constant() - %11 : int = prim::Constant[value=1]() - %12: Tensor = aten::reshape(%1, %4) - %13: Tensor = aten::reshape(%2, %3) - %14: Tensor = aten::reshape(%1, %3) - %15 : Tensor = aten::to(%12, %6, %9, %9, %10) - %16 : int = aten::size(%1, %8) - %17 : int[] = prim::ListConstruct(%16, %6, %5, %7) - %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11) - %20 : Tensor = aten::reshape(%18, %17) - return (%20))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); -} - -TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.forced_fallback_operators.push_back("aten::relu"); - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}})); -} - -TEST(Partitioning, SegmentBranchModelCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %12: Tensor = aten::log_sigmoid(%10) - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %14 : Tensor = aten::relu(%11) - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}})); -} - -TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %12: Tensor = aten::log_sigmoid(%10) - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %14 : Tensor = aten::relu(%11) - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}})); -} - -TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - - %12: Tensor = aten::log_sigmoid(%10) - - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - - %14 : Tensor = aten::relu(%11) - - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.forced_fallback_operators.push_back("aten::relu"); - PartitioningCtx ctx(g->block(), partitioning_info); - - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); - ASSERT_TRUE( - checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3}, {4}, {5, 6}})); -} - -} // namespace tests -} // namespace partitioning -} // namespace core -} // namespace torch_tensorrt +#include +#include "core/partitioning/partitioning.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/script.h" +#include "torch_tensorrt/torch_tensorrt.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { +namespace tests { + +bool checkSegmentedBlockNumber( + PartitionedGraph& segmented_blocks, + SegmentedBlock::SegmentedBlockTarget target, + int target_count) { + int64_t cnt = 0; + for (auto& seg_block : segmented_blocks) { + if (seg_block.target() == target) { + cnt++; + } + } + std::cout << "Found count of " << cnt << " " << target << " blocks (looking for " << target_count << " blocks)" + << std::endl; + + if (target_count != cnt) { + std::cout << segmented_blocks << std::endl; + } + + return target_count == cnt; +} + +bool checkSegmentedBlockNodesMapping( + std::vector& segmented_blocks, + std::shared_ptr g, + std::vector> nodes_index) { + std::vector graph_nodes; + for (const auto n : g->nodes()) { + if (n->kind() != torch::jit::prim::Constant) { + graph_nodes.push_back(n); + } + } + for (size_t i = 0; i < nodes_index.size(); ++i) { + size_t seg_block_node_id = 0; + for (int j : nodes_index[i]) { + if (segmented_blocks[i].raw_nodes()[seg_block_node_id++] != graph_nodes[j]) { + return false; + } + } + if (seg_block_node_id != segmented_blocks[i].raw_nodes().size()) + return false; + } + return true; +} + +TEST(Partitioning, SegmentSequentialModelCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}})); +} + +TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}})); +} + +TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor, + %2 : Tensor): + %3 : int[] = prim::Constant[value=[-1, 5]]() + %4 : int[] = prim::Constant[value=[-1]]() + %5 : int = prim::Constant[value=2]() + %6 : int = prim::Constant[value=4]() + %7 : int = prim::Constant[value=5]() + %8 : int = prim::Constant[value=0]() + %9 : bool = prim::Constant[value=0]() + %10 : NoneType = prim::Constant() + %11 : int = prim::Constant[value=1]() + %12: Tensor = aten::reshape(%1, %4) + %13: Tensor = aten::reshape(%2, %3) + %14: Tensor = aten::reshape(%1, %3) + %15 : Tensor = aten::to(%12, %6, %9, %9, %10) + %16 : int = aten::size(%1, %8) + %17 : int[] = prim::ListConstruct(%16, %6, %5, %7) + %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11) + %20 : Tensor = aten::reshape(%18, %17) + return (%20))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); +} + +TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}})); +} + +TEST(Partitioning, SegmentBranchModelCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %12: Tensor = aten::log_sigmoid(%10) + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %14 : Tensor = aten::relu(%11) + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}})); +} + +TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %12: Tensor = aten::log_sigmoid(%10) + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %14 : Tensor = aten::relu(%11) + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}})); +} + +TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + + %12: Tensor = aten::log_sigmoid(%10) + + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + + %14 : Tensor = aten::relu(%11) + + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx ctx(g->block(), partitioning_info); + + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE( + checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2, 4}, {3, 5, 6}})); +} + +TEST(Partitioning, SegmentModelWithDependencyAwareness) { + const auto graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %3 : int = prim::Constant[value=0]() + %20 : int = prim::Constant[value=1]() + %add : Tensor = aten::add(%x, %y, %20) + %x_lgamma : Tensor = aten::lgamma(%x) + %mul : Tensor = aten::mul(%x, %y) + %y_lgamma : Tensor = aten::lgamma(%y) + %div : Tensor = aten::div(%x, %y) + %div_lgamma : Tensor = aten::lgamma(%div) + %27 : Tensor[] = prim::ListConstruct(%x_lgamma, %y_lgamma, %div_lgamma, %add, %mul) + %12 : Tensor = aten::cat(%27, %3) + return (%12))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}})); +} + +} // namespace tests +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt From 119fd0a1fa2f8b2530b28618e448d837ee92e503 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Thu, 6 Oct 2022 15:53:26 -0700 Subject: [PATCH 2/5] lint --- tests/core/partitioning/test_segmentation.cpp | 674 +++++++++--------- 1 file changed, 337 insertions(+), 337 deletions(-) diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp index c17f1c9172..b365a95ad9 100644 --- a/tests/core/partitioning/test_segmentation.cpp +++ b/tests/core/partitioning/test_segmentation.cpp @@ -1,337 +1,337 @@ -#include -#include "core/partitioning/partitioning.h" -#include "gtest/gtest.h" -#include "tests/util/util.h" -#include "torch/csrc/jit/ir/irparser.h" -#include "torch/script.h" -#include "torch_tensorrt/torch_tensorrt.h" - -namespace torch_tensorrt { -namespace core { -namespace partitioning { -namespace tests { - -bool checkSegmentedBlockNumber( - PartitionedGraph& segmented_blocks, - SegmentedBlock::SegmentedBlockTarget target, - int target_count) { - int64_t cnt = 0; - for (auto& seg_block : segmented_blocks) { - if (seg_block.target() == target) { - cnt++; - } - } - std::cout << "Found count of " << cnt << " " << target << " blocks (looking for " << target_count << " blocks)" - << std::endl; - - if (target_count != cnt) { - std::cout << segmented_blocks << std::endl; - } - - return target_count == cnt; -} - -bool checkSegmentedBlockNodesMapping( - std::vector& segmented_blocks, - std::shared_ptr g, - std::vector> nodes_index) { - std::vector graph_nodes; - for (const auto n : g->nodes()) { - if (n->kind() != torch::jit::prim::Constant) { - graph_nodes.push_back(n); - } - } - for (size_t i = 0; i < nodes_index.size(); ++i) { - size_t seg_block_node_id = 0; - for (int j : nodes_index[i]) { - if (segmented_blocks[i].raw_nodes()[seg_block_node_id++] != graph_nodes[j]) { - return false; - } - } - if (seg_block_node_id != segmented_blocks[i].raw_nodes().size()) - return false; - } - return true; -} - -TEST(Partitioning, SegmentSequentialModelCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}})); -} - -TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}})); -} - -TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Tensor, - %2 : Tensor): - %3 : int[] = prim::Constant[value=[-1, 5]]() - %4 : int[] = prim::Constant[value=[-1]]() - %5 : int = prim::Constant[value=2]() - %6 : int = prim::Constant[value=4]() - %7 : int = prim::Constant[value=5]() - %8 : int = prim::Constant[value=0]() - %9 : bool = prim::Constant[value=0]() - %10 : NoneType = prim::Constant() - %11 : int = prim::Constant[value=1]() - %12: Tensor = aten::reshape(%1, %4) - %13: Tensor = aten::reshape(%2, %3) - %14: Tensor = aten::reshape(%1, %3) - %15 : Tensor = aten::to(%12, %6, %9, %9, %10) - %16 : int = aten::size(%1, %8) - %17 : int[] = prim::ListConstruct(%16, %6, %5, %7) - %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11) - %20 : Tensor = aten::reshape(%18, %17) - return (%20))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); -} - -TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %b1 : Float(32), - %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %b2 : Float(16), - %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), - %b3 : Float(8)): - %2 : int[] = prim::Constant[value=[1, 1]]() - %3 : int = prim::Constant[value=1]() - %10 : bool = prim::Constant[value=0]() - %11 : int[] = prim::Constant[value=[0, 0]]() - %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %13 : Tensor = aten::relu(%12) - %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - %15 : Tensor = aten::log_sigmoid(%14) - %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.forced_fallback_operators.push_back("aten::relu"); - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}})); -} - -TEST(Partitioning, SegmentBranchModelCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %12: Tensor = aten::log_sigmoid(%10) - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %14 : Tensor = aten::relu(%11) - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}})); -} - -TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %12: Tensor = aten::log_sigmoid(%10) - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %14 : Tensor = aten::relu(%11) - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.min_block_size = 3; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}})); -} - -TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { - const auto graph = R"IR( - graph(%0 : Tensor, - %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), - %2 : Float(32), - %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), - %4 : Float(16)): - %5 : int[] = prim::Constant[value=[0, 0]]() - %6 : int[] = prim::Constant[value=[2, 2]]() - %7 : bool = prim::Constant[value=0]() - %8 : int[] = prim::Constant[value=[1, 1]]() - %9 : int = prim::Constant[value=1]() - %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - - %12: Tensor = aten::log_sigmoid(%10) - - %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) - - %14 : Tensor = aten::relu(%11) - - %15 : Tensor = aten::add(%13, %14, %9) - %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) - return (%16))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - LOG_GRAPH(*g); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - partitioning_info.forced_fallback_operators.push_back("aten::relu"); - PartitioningCtx ctx(g->block(), partitioning_info); - - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE( - checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2, 4}, {3, 5, 6}})); -} - -TEST(Partitioning, SegmentModelWithDependencyAwareness) { - const auto graph = R"IR( - graph(%x : Tensor, - %y : Tensor): - %3 : int = prim::Constant[value=0]() - %20 : int = prim::Constant[value=1]() - %add : Tensor = aten::add(%x, %y, %20) - %x_lgamma : Tensor = aten::lgamma(%x) - %mul : Tensor = aten::mul(%x, %y) - %y_lgamma : Tensor = aten::lgamma(%y) - %div : Tensor = aten::div(%x, %y) - %div_lgamma : Tensor = aten::lgamma(%div) - %27 : Tensor[] = prim::ListConstruct(%x_lgamma, %y_lgamma, %div_lgamma, %add, %mul) - %12 : Tensor = aten::cat(%27, %3) - return (%12))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - PartitioningInfo partitioning_info; - partitioning_info.enabled = true; - PartitioningCtx ctx(g->block(), partitioning_info); - segmentGraph(&ctx, g->block()); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}})); -} - -} // namespace tests -} // namespace partitioning -} // namespace core -} // namespace torch_tensorrt +#include +#include "core/partitioning/partitioning.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/script.h" +#include "torch_tensorrt/torch_tensorrt.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { +namespace tests { + +bool checkSegmentedBlockNumber( + PartitionedGraph& segmented_blocks, + SegmentedBlock::SegmentedBlockTarget target, + int target_count) { + int64_t cnt = 0; + for (auto& seg_block : segmented_blocks) { + if (seg_block.target() == target) { + cnt++; + } + } + std::cout << "Found count of " << cnt << " " << target << " blocks (looking for " << target_count << " blocks)" + << std::endl; + + if (target_count != cnt) { + std::cout << segmented_blocks << std::endl; + } + + return target_count == cnt; +} + +bool checkSegmentedBlockNodesMapping( + std::vector& segmented_blocks, + std::shared_ptr g, + std::vector> nodes_index) { + std::vector graph_nodes; + for (const auto n : g->nodes()) { + if (n->kind() != torch::jit::prim::Constant) { + graph_nodes.push_back(n); + } + } + for (size_t i = 0; i < nodes_index.size(); ++i) { + size_t seg_block_node_id = 0; + for (int j : nodes_index[i]) { + if (segmented_blocks[i].raw_nodes()[seg_block_node_id++] != graph_nodes[j]) { + return false; + } + } + if (seg_block_node_id != segmented_blocks[i].raw_nodes().size()) + return false; + } + return true; +} + +TEST(Partitioning, SegmentSequentialModelCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}})); +} + +TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}})); +} + +TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor, + %2 : Tensor): + %3 : int[] = prim::Constant[value=[-1, 5]]() + %4 : int[] = prim::Constant[value=[-1]]() + %5 : int = prim::Constant[value=2]() + %6 : int = prim::Constant[value=4]() + %7 : int = prim::Constant[value=5]() + %8 : int = prim::Constant[value=0]() + %9 : bool = prim::Constant[value=0]() + %10 : NoneType = prim::Constant() + %11 : int = prim::Constant[value=1]() + %12: Tensor = aten::reshape(%1, %4) + %13: Tensor = aten::reshape(%2, %3) + %14: Tensor = aten::reshape(%1, %3) + %15 : Tensor = aten::to(%12, %6, %9, %9, %10) + %16 : int = aten::size(%1, %8) + %17 : int[] = prim::ListConstruct(%16, %6, %5, %7) + %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11) + %20 : Tensor = aten::reshape(%18, %17) + return (%20))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); +} + +TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %b1 : Float(32), + %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %b2 : Float(16), + %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]), + %b3 : Float(8)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int = prim::Constant[value=1]() + %10 : bool = prim::Constant[value=0]() + %11 : int[] = prim::Constant[value=[0, 0]]() + %12: Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %13 : Tensor = aten::relu(%12) + %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + %15 : Tensor = aten::log_sigmoid(%14) + %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}})); +} + +TEST(Partitioning, SegmentBranchModelCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %12: Tensor = aten::log_sigmoid(%10) + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %14 : Tensor = aten::relu(%11) + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}})); +} + +TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %12: Tensor = aten::log_sigmoid(%10) + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %14 : Tensor = aten::relu(%11) + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}})); +} + +TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]), + %2 : Float(32), + %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]), + %4 : Float(16)): + %5 : int[] = prim::Constant[value=[0, 0]]() + %6 : int[] = prim::Constant[value=[2, 2]]() + %7 : bool = prim::Constant[value=0]() + %8 : int[] = prim::Constant[value=[1, 1]]() + %9 : int = prim::Constant[value=1]() + %10: Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + + %12: Tensor = aten::log_sigmoid(%10) + + %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7) + + %14 : Tensor = aten::relu(%11) + + %15 : Tensor = aten::add(%13, %14, %9) + %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7) + return (%16))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx ctx(g->block(), partitioning_info); + + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2, 4}, {3, 5, 6}})); +} + +TEST(Partitioning, SegmentModelWithDependencyAwareness) { + const auto graph = R"IR( + graph(%x : Tensor, + %y : Tensor): + %3 : int = prim::Constant[value=0]() + %20 : int = prim::Constant[value=1]() + %add : Tensor = aten::add(%x, %y, %20) + %x_lgamma : Tensor = aten::lgamma(%x) + %mul : Tensor = aten::mul(%x, %y) + %y_lgamma : Tensor = aten::lgamma(%y) + %div : Tensor = aten::div(%x, %y) + %div_lgamma : Tensor = aten::lgamma(%div) + %27 : Tensor[] = prim::ListConstruct(%x_lgamma, %y_lgamma, %div_lgamma, %add, %mul) + %12 : Tensor = aten::cat(%27, %3) + return (%12))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE( + checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 2, 4}, {1, 3, 5}, {6, 7}})); +} + +} // namespace tests +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt From 86d99248af6559c6e2ffcc28f696767c8dd49e7a Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 11 Oct 2022 13:23:50 -0700 Subject: [PATCH 3/5] Add documentation for contributors in partitioning.rst --- docsrc/contributors/partitioning.rst | 146 +++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/docsrc/contributors/partitioning.rst b/docsrc/contributors/partitioning.rst index 2d027c57f4..ae4c2edc73 100644 --- a/docsrc/contributors/partitioning.rst +++ b/docsrc/contributors/partitioning.rst @@ -91,3 +91,149 @@ To enable automatic fallback feature, you can set following attributes in Python cfg.torch_executed_ops.push_back("aten::relu"); auto trt_mod = torchtrt::ts::compile(mod, cfg); auto out = trt_mod.forward({in}); + +Dependency Aware Partitioning +==================== +During segmentation, Torch-TensorRT uses a dependency graph of the input TorchScript nodes to reduce the number of segments created. Consider this example from test Partitioning.SegmentModelWithDependencyAwareness in `tests/core/partitioning/test_segmentation.cpp `_ + +.. code-block:: none + + graph(%x : Tensor, %y : Tensor): + %3 : int = prim::Constant[value=0]() + %20 : int = prim::Constant[value=1]() + %add : Tensor = aten::add(%x, %y, %20) + %x_lgamma : Tensor = aten::lgamma(%x) + %mul : Tensor = aten::mul(%x, %y) + %y_lgamma : Tensor = aten::lgamma(%y) + %div : Tensor = aten::div(%x, %y) + %div_lgamma : Tensor = aten::lgamma(%div) + %27 : Tensor[] = prim::ListConstruct(%x_lgamma, %y_lgamma, %div_lgamma, %add, %mul) + %12 : Tensor = aten::cat(%27, %3) + return (%12) + +In this graph `aten::lgamma` is not supported by conversion and must be partitioned in a Torch fallback segment. If Torch-TensorRT uses a greedy segmentation strategy that traverses nodes in the input graph in order and gathers ops with the same target (TensorRT or Torch) into a segment until it encounters an op with a different target, the resulting partition includes 7 segments, many with just a single op. + +.. code-block:: none + + Segment Block @0: + Target: TensorRT + + Graph: graph(%x : Tensor, + %y : Tensor): + %3 : int = prim::Constant[value=1]() + %0 : Tensor = aten::add(%x, %y, %3) + return () + + Segment Block @1: + Target: Torch + + Graph: graph(%x : Tensor): + %0 : Tensor = aten::lgamma(%x) + return () + + Segment Block @2: + Target: TensorRT + + Graph: graph(%x : Tensor, + %y : Tensor): + %0 : Tensor = aten::mul(%x, %y) + return () + + Segment Block @3: + Target: Torch + + Graph: graph(%y : Tensor): + %0 : Tensor = aten::lgamma(%y) + return () + + Segment Block @4: + Target: TensorRT + + Graph: graph(%x : Tensor, + %y : Tensor): + %0 : Tensor = aten::div(%x, %y) + return () + + Segment Block @5: + Target: Torch + + Graph: graph(%1 : Tensor): + %0 : Tensor = aten::lgamma(%1) + return () + + Segment Block @6: + Target: TensorRT + + Graph: graph(%1 : Tensor, + %2 : Tensor, + %3 : Tensor, + %4 : Tensor, + %5 : Tensor): + %7 : int = prim::Constant[value=0]() + %0 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %5) + %6 : Tensor = aten::cat(%0, %7) + return () + +This partition is valid, but the segmentation is suboptimal. These arithmetic ops and `aten::lgamma` ops are each split into their own segment as we alternate between Torch and TensorRT targets in the linear traversal of the graph. + +.. code-block:: none + + %add : Tensor = aten::add(%x, %y, %20) + %x_lgamma : Tensor = aten::lgamma(%x) + %mul : Tensor = aten::mul(%x, %y) + %y_lgamma : Tensor = aten::lgamma(%y) + %div : Tensor = aten::div(%x, %y) + %div_lgamma : Tensor = aten::lgamma(%div) + +Each of the arithmetic ops in this segment is only dependent on constants and the inputs `%x` and `%y`. The `aten::lgamma` ops are dependent on the inputs `%x`, `%y` and the output of the `aten::div`. This means that we could rewrite this portion of the input graph as below without changing the behavior of the graph. This reordered series of ops could be cleanly partitioned into just 2 segments using the greedy segmentation approach described above. + +.. code-block:: none + + %add : Tensor = aten::add(%x, %y, %20) + %mul : Tensor = aten::mul(%x, %y) + %div : Tensor = aten::div(%x, %y) + %x_lgamma : Tensor = aten::lgamma(%x) + %y_lgamma : Tensor = aten::lgamma(%y) + %div_lgamma : Tensor = aten::lgamma(%div) + +By adding awareness of the dependencies between ops to the basic greedy segmentation approach we can achieve the same partition without rewriting the graph. Now we will maintain both Torch and TensorRT targeted segments at the same time as we traverse the graph. We will only finalize a segment once we hit an op that is both dependent on an op in the segment and has a different target. This will allow the partition to create larger segments by reordering nodes across the segment boundary while guaranteeing that we will not modify the behavior of the graph by reordering nodes relative to their dependencies. +In this example we will collect the arithmetic ops in a TensorRT segment and the `aten::lgamma` ops in a Torch segment. When we encounter the `%div_lgamma : Tensor = aten::lgamma(%div)` op we can see it is dependent on `%div : Tensor = aten::div(%x, %y)` in the current TensorRT segment. This triggers finalization of the TensorRT segment containing the `aten::div` op to guarantee it will appear before its dependency in the final partition. The Torch segment containing the `aten::lgamma` op is finalized when we encounter the `prim::ListConstruct` op which targets TensorRT and is dependent on the results of the `aten::lgamma` ops. + +.. code-block:: none + + Segment Block @0: + Target: TensorRT + + Graph: graph(%x : Tensor, + %y : Tensor): + %3 : int = prim::Constant[value=1]() + %0 : Tensor = aten::add(%x, %y, %3) + %4 : Tensor = aten::mul(%x, %y) + %5 : Tensor = aten::div(%x, %y) + return () + + Segment Block @1: + Target: Torch + + Graph: graph(%x : Tensor, + %y : Tensor, + %5 : Tensor): + %0 : Tensor = aten::lgamma(%x) + %2 : Tensor = aten::lgamma(%y) + %4 : Tensor = aten::lgamma(%5) + return () + + Segment Block @2: + Target: TensorRT + + Graph: graph(%1 : Tensor, + %2 : Tensor, + %3 : Tensor, + %4 : Tensor, + %5 : Tensor): + %7 : int = prim::Constant[value=0]() + %0 : Tensor[] = prim::ListConstruct(%1, %2, %3, %4, %5) + %6 : Tensor = aten::cat(%0, %7) + return () + +In some cases this approach may create adjacent sements in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition. From 7c8a1af811fcc6522e7744873046aa3c99e36fc6 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Tue, 11 Oct 2022 13:28:05 -0700 Subject: [PATCH 4/5] fix typo --- docsrc/contributors/partitioning.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docsrc/contributors/partitioning.rst b/docsrc/contributors/partitioning.rst index ae4c2edc73..eb964fd7c4 100644 --- a/docsrc/contributors/partitioning.rst +++ b/docsrc/contributors/partitioning.rst @@ -236,4 +236,4 @@ In this example we will collect the arithmetic ops in a TensorRT segment and the %6 : Tensor = aten::cat(%0, %7) return () -In some cases this approach may create adjacent sements in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition. +In some cases this approach may create adjacent segments in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition. From 56ae9f67aadd5b346f6653cc23ac74806daa15fa Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Wed, 12 Oct 2022 14:42:54 -0700 Subject: [PATCH 5/5] Add description of merge segments to docs --- docsrc/contributors/partitioning.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docsrc/contributors/partitioning.rst b/docsrc/contributors/partitioning.rst index eb964fd7c4..7ac5a5231f 100644 --- a/docsrc/contributors/partitioning.rst +++ b/docsrc/contributors/partitioning.rst @@ -237,3 +237,5 @@ In this example we will collect the arithmetic ops in a TensorRT segment and the return () In some cases this approach may create adjacent segments in the partition which have the same target. As a clean-up step we can consolidate these adjacent segments to further reduce the number of segments in the final partition. +The merge segments step identifies a list of segments that are adjacent in the graph, have the same target, and are not marked as `do_not_merge`. The nodes from these segments will be combined into a single new segment that will replace the merged segments in the partition. +The `do_not_merge` marking is used to prevent merging of segments created for conditional nodes and loops that are handled as special cases in graph stitching and should not be merged with adjacent segments of the same type.