Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s

if (unbind) {
axis = args[1].unwrapToInt();
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
axis = axis < 0 ? axis + maxDim : axis;
numOutputs = in->getDimensions().d[axis];
sizes.insert(sizes.end(), numOutputs, 1);
} else {
Expand Down
7 changes: 6 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ auto aten_registrations TORCHTRT_UNUSED =
.evaluator({c10::Symbol::fromQualString("aten::slice"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
int64_t start = args.at(n->input(1)).unwrapToInt();

int64_t start = 0;
auto startIVal = args.at(n->input(1)).IValue();
if(!startIVal->isNone()){
start = args.at(n->input(1)).unwrapToInt();
}
int64_t end = args.at(n->input(2)).unwrapToInt();
int64_t step = args.at(n->input(3)).unwrapToInt();

Expand Down
59 changes: 59 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,38 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenSliceListConvertsCorrectly) {
const auto graph = R"IR(
graph(%x : Tensor):
%1 : NoneType = prim::Constant()
%2 : int = prim::Constant[value=2]()
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=3]()
%list : Tensor[] = aten::unbind(%x, %4)
%slice : Tensor[] = aten::slice(%list, %1, %2, %3)
%out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
return (%out.1, %out.2))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});

auto jit_in_x = at::clone(in_x);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x});

auto trt_in_x = at::clone(in_x);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x});

for (size_t i = 0; i < jit_results.size(); i++) {
auto trt = trt_results[i].reshape(jit_results[i].sizes());
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down Expand Up @@ -796,3 +828,30 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}

TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=-1]()
%3 : Tensor[] = aten::unbind(%x.1, %2)
%o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
return (%o1.1, %o2.1))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto in = at::randint(1, 10, {5, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

for (size_t i = 0; i < jit_results.size(); i++) {
auto trt = trt_results[i].reshape(jit_results[i].sizes());
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
}
}