From 150643d4e3c72a19300fc868ca1fdd48b76cd406 Mon Sep 17 00:00:00 2001 From: inocsin Date: Thu, 1 Apr 2021 15:32:09 +0800 Subject: [PATCH 1/2] fix: support c10::List type of output in MarkOutputs Signed-off-by: inocsin --- core/conversion/conversion.cpp | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index fb42a41dbf..6655b3893d 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -173,6 +173,20 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef inputs } #endif } +void MarkOutputsOfIvalue(ConversionCtx* ctx, c10::IValue out_ivalue, const torch::jit::Value* out) { + if (out_ivalue.isCustomClass()) { + std::string name = std::string("output_") + std::to_string(ctx->num_outputs); + auto output_container = out_ivalue.toCustomClass(); + nvinfer1::ITensor* out_tensor = output_container.get()->tensor(); + out_tensor->setName(name.c_str()); + ctx->net->markOutput(*out_tensor); + LOG_INFO( + ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); + ctx->num_outputs += 1; + } else { + TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported."); + } +} void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { for (auto out : outputs) { @@ -180,17 +194,13 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outp if (it == ctx->value_tensor_map.end()) { if (ctx->evaluated_value_map.find(out) != ctx->evaluated_value_map.end()) { auto out_ivalue = ctx->evaluated_value_map[out]; - if (out_ivalue.isCustomClass()) { - std::string name = std::string("output_") + std::to_string(ctx->num_outputs); - auto output_container = out_ivalue.toCustomClass(); - nvinfer1::ITensor* out_tensor = output_container.get()->tensor(); - out_tensor->setName(name.c_str()); - ctx->net->markOutput(*out_tensor); - LOG_INFO( - ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); - ctx->num_outputs += 1; + if (out_ivalue.isList()) { + c10::List value_list = out_ivalue.toList(); + for(auto it = value_list.begin(); it != value_list.end(); it++) { + MarkOutputsOfIvalue(ctx, *it, out); + } } else { - TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported."); + MarkOutputsOfIvalue(ctx, out_ivalue, out); } } } else { From 053567ec5b3d383f3e9110894b71c43c6f0cc223 Mon Sep 17 00:00:00 2001 From: inocsin Date: Fri, 2 Apr 2021 11:15:46 +0800 Subject: [PATCH 2/2] chore: update error message and function name of MarkOutputs Signed-off-by: inocsin --- core/conversion/conversion.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 6655b3893d..9ad77aa056 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -173,7 +173,8 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef inputs } #endif } -void MarkOutputsOfIvalue(ConversionCtx* ctx, c10::IValue out_ivalue, const torch::jit::Value* out) { + +void MarkIValueOutputs(ConversionCtx* ctx, c10::IValue out_ivalue, const torch::jit::Value* out) { if (out_ivalue.isCustomClass()) { std::string name = std::string("output_") + std::to_string(ctx->num_outputs); auto output_container = out_ivalue.toCustomClass(); @@ -184,7 +185,7 @@ void MarkOutputsOfIvalue(ConversionCtx* ctx, c10::IValue out_ivalue, const torch ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); ctx->num_outputs += 1; } else { - TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported."); + TRTORCH_THROW_ERROR("Unsupported output type, only Tensors or unwrapped collections of Tensors can be marked as engine outputs but found type: " << out_ivalue.tagKind()); } } @@ -197,10 +198,10 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outp if (out_ivalue.isList()) { c10::List value_list = out_ivalue.toList(); for(auto it = value_list.begin(); it != value_list.end(); it++) { - MarkOutputsOfIvalue(ctx, *it, out); + MarkIValueOutputs(ctx, *it, out); } } else { - MarkOutputsOfIvalue(ctx, out_ivalue, out); + MarkIValueOutputs(ctx, out_ivalue, out); } } } else {