From 7c478f04dfbacf79384d5b69615fc0c50fedd4b7 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Wed, 10 May 2023 17:26:18 -0700 Subject: [PATCH] support type promotion in aten::cat converter --- core/conversion/converters/converter_util.h | 2 + core/conversion/converters/impl/concat.cpp | 12 +++ .../conversion/converters/test_concat.cpp | 79 +++++++++++++++++++ 3 files changed, 93 insertions(+) diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h index c5e4e4eebc..3342302431 100644 --- a/core/conversion/converters/converter_util.h +++ b/core/conversion/converters/converter_util.h @@ -96,6 +96,8 @@ nvinfer1::ITensor* get_slice_size( nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s); +nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b); + } // namespace converters } // namespace conversion } // namespace core diff --git a/core/conversion/converters/impl/concat.cpp b/core/conversion/converters/impl/concat.cpp index 2f82fccc80..a2e91b55f6 100644 --- a/core/conversion/converters/impl/concat.cpp +++ b/core/conversion/converters/impl/concat.cpp @@ -1,3 +1,4 @@ +#include "core/conversion/converters/converter_util.h" #include "core/conversion/converters/converters.h" #include "core/conversion/tensorcontainer/TensorContainer.h" #include "core/util/prelude.h" @@ -27,6 +28,17 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() } } + auto promo_dtype = tensors[0]->getType(); + for(size_t idx = 1UL; idx < tensors.size(); ++idx){ + promo_dtype = promote_types(promo_dtype, tensors[idx]->getType()); + } + + for(size_t idx = 0UL; idx < tensors.size(); ++idx){ + if(tensors[idx]->getType() != promo_dtype){ + tensors[idx] = castITensor(ctx, tensors[idx], promo_dtype, util::node_info(n) + "_cast_" + std::to_string(idx)); + } + } + if (dim < 0) { dim = tensors[0]->getDimensions().nbDims + dim; } diff --git a/tests/core/conversion/converters/test_concat.cpp 
b/tests/core/conversion/converters/test_concat.cpp index 7c7e2d7a93..a311df1c7a 100644 --- a/tests/core/conversion/converters/test_concat.cpp +++ b/tests/core/conversion/converters/test_concat.cpp @@ -29,6 +29,85 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenCatFloatIntConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor): + %2 : Tensor[] = prim::ListConstruct(%0, %1) + %3 : int = prim::Constant[value=0]() + %4 : Tensor = aten::cat(%2, %3) + return (%4))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat); + auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenCatIntHalfIntHalfConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor, + %2 : Tensor, + %3 : Tensor): + %4 : Tensor[] = prim::ListConstruct(%0, %1, %2, %3) + %5 : int = prim::Constant[value=0]() + %6 : Tensor = aten::cat(%4, %5) + return (%6))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt); + auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf); + auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt); + auto in4 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto 
jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3, in4}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = + torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3, in4}, nvinfer1::DataType::kHALF); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenCatHalfIntFloatConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Tensor, + %2 : Tensor): + %3 : Tensor[] = prim::ListConstruct(%0, %1, %2) + %4 : int = prim::Constant[value=0]() + %5 : Tensor = aten::cat(%3, %4) + return (%5))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt); + auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf); + auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenCatDiffTensorConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor,