@@ -40,6 +40,41 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
4040 TestATenStackPureTensorConvertsCorrectly (graph2);
4141}
4242
43+ TEST (Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
44+ auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
45+ auto g = std::make_shared<torch::jit::Graph>();
46+ torch::jit::parseIR (graph, g.get ());
47+
48+ auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
49+ auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
50+
51+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
52+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
53+
54+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
55+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in1, in2});
56+
57+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
58+ };
59+ const auto graph = R"IR(
60+ graph(%0 : Tensor,
61+ %1 : Tensor):
62+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
63+ %3 : int = prim::Constant[value=1]()
64+ %4 : Tensor = aten::stack(%2, %3)
65+ return (%4))IR" ;
66+ const auto graph2 = R"IR(
67+ graph(%0 : Tensor,
68+ %1 : Tensor):
69+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
70+ %3 : int = prim::Constant[value=-1]()
71+ %4 : Tensor = aten::stack(%2, %3)
72+ return (%4))IR" ;
73+
74+ TestATenStackPureTensorConvertsCorrectly (graph);
75+ TestATenStackPureTensorConvertsCorrectly (graph2);
76+ }
77+
4378TEST (Converters, ATenStackDiffTensorConvertsCorrectly) {
4479 auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
4580 auto g = std::make_shared<torch::jit::Graph>();
0 commit comments