From 285913da3c48edeb142d27f567444e3f2b13669a Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Mon, 3 Apr 2023 17:06:43 -0700 Subject: [PATCH] fix aten::stack with dynamic inputs --- core/conversion/converters/impl/stack.cpp | 3 +- .../core/conversion/converters/test_stack.cpp | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/core/conversion/converters/impl/stack.cpp b/core/conversion/converters/impl/stack.cpp index 0f5b9da273..2a4241ecf3 100644 --- a/core/conversion/converters/impl/stack.cpp +++ b/core/conversion/converters/impl/stack.cpp @@ -43,10 +43,9 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt auto cont = t.toCustomClass(); itensor = cont->tensor(); } - auto shuffle_layer = ctx->net->addShuffle(*itensor); TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim)); + shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim, 1, false)); tensors.push_back(shuffle_layer->getOutput(0)); } diff --git a/tests/core/conversion/converters/test_stack.cpp b/tests/core/conversion/converters/test_stack.cpp index 72bc8832bf..6d0fdbec44 100644 --- a/tests/core/conversion/converters/test_stack.cpp +++ b/tests/core/conversion/converters/test_stack.cpp @@ -40,6 +40,41 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) { TestATenStackPureTensorConvertsCorrectly(graph2); } +TEST(Converters, ATenStackPureTensorDynamicConvertsCorrectly) { + auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) { + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); + auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5)); + }; + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor): + %2 : Tensor[] = prim::ListConstruct(%0, %1) + %3 : int = prim::Constant[value=1]() + %4 : Tensor = aten::stack(%2, %3) + return (%4))IR"; + const auto graph2 = R"IR( + graph(%0 : Tensor, + %1 : Tensor): + %2 : Tensor[] = prim::ListConstruct(%0, %1) + %3 : int = prim::Constant[value=-1]() + %4 : Tensor = aten::stack(%2, %3) + return (%4))IR"; + + TestATenStackPureTensorConvertsCorrectly(graph); + TestATenStackPureTensorConvertsCorrectly(graph2); +} + TEST(Converters, ATenStackDiffTensorConvertsCorrectly) { auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) { auto g = std::make_shared();