@@ -833,6 +833,38 @@ TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index0 : Tensor,
+            %index1 : Tensor,
+            %index2 : Tensor):
+        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
+        %19 : Tensor = aten::index(%x.1, %18)
+        %20 : Tensor = aten::index(%x.1, %18)
+        return (%19, %20))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
+  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
+  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
+  auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
+  auto index0_trt = index0.to(torch::kInt32);
+  auto index1_trt = index1.to(torch::kInt32);
+  auto index2_trt = index2.to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
+}
+
 TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor,