Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto graph_and_mapping =
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
LOG_INFO("Segmented Graph: " << *new_g);
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
new_g->inputs()[i]->setDebugName(std::string("input_") + std::to_string(i));
}
LOG_INFO(*new_g << "(GraphAfterFallback)");

// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
// module
Expand Down
18 changes: 17 additions & 1 deletion tests/core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ partitioning_test(
name = "test_resolve_nontensor_inputs",
)

cc_test(
name = "test_loading_model",
srcs = ["test_loading_model.cpp"],
deps = [
"//tests/util",
"@googletest//:gtest_main",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
data = [
":jit_models"
]
)

cc_test(
name = "test_fallback_graph_output",
srcs = ["test_fallback_graph_output.cpp"],
Expand Down Expand Up @@ -92,6 +107,7 @@ test_suite(
":test_fallback_graph_output",
":test_loop_fallback",
":test_conditionals",
":test_resolve_nontensor_inputs"
":test_resolve_nontensor_inputs",
":test_loading_model"
]
)
39 changes: 39 additions & 0 deletions tests/core/partitioning/test_loading_model.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <string>
#include <unordered_set>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/script.h"

#ifndef DISABLE_TEST_IN_CI

TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) {
torch::jit::script::Module mod;
try {
mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return;
}

const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}};
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randint(5, in_shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}

std::vector<torch_tensorrt::core::ir::Input> input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};

torch_tensorrt::core::CompileSpec cfg(input_ranges);
cfg.partition_info.enabled = true;

auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
trt_mod.save("loading_model.ts");
auto loaded_model = torch::jit::load("loading_model.ts");
}

#endif