
Commit dba4988

Merge branch 'main' into bose_fx2trt_converters_slice_select

2 parents: ab89d2b + 6f7627f


43 files changed: +1,261 / −357 lines

README.md

Lines changed: 9 additions & 2 deletions
````diff
@@ -73,6 +73,7 @@ import torch_tensorrt
 ...
 
 trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    # If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
     inputs = [example_tensor, # Provide example tensor for input shape or...
         torch_tensorrt.Input( # Specify input object with shape and dtype
             min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
             # For static size shape=[1, 3, 224, 224]
             dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
     ],
+
+    # For inputs containing tuples or lists of tensors, use the `input_signature` argument:
+    # Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
+    # input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
+    #                      torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),
+
     enabled_precisions = {torch.half}, # Run with FP16
 )
@@ -114,7 +121,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 5.2.0
-- Libtorch 2.0.0.dev20230103 (built with CUDA 11.7)
+- Libtorch 2.1.0.dev20230314 (built with CUDA 11.7)
 - CUDA 11.7
 - cuDNN 8.5.0
 - TensorRT 8.5.1.7
@@ -124,7 +131,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
 Releases: https://github.com/pytorch/TensorRT/releases
 
 ```
-pip install torch-tensorrt==1.2.0 --find-links https://github.com/pytorch/TensorRT/releases/expanded_assets/v1.2.0
+pip install torch-tensorrt
 ```
 
 ## Compiling Torch-TensorRT
````
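As a companion to the README change above, here is a minimal sketch of when `input_signature` is needed instead of `inputs`. The `AddPair` module is invented for illustration and is not part of this commit; the nested tuple structure mirrors the README's own example.

```python
import torch
import torch_tensorrt
from typing import Tuple

class AddPair(torch.nn.Module):
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return xs[0] + xs[1]

mod = torch.jit.script(AddPair().eval().half().cuda())

# `inputs=[...]` only describes flat Tensor arguments, so the nested
# `input_signature` mirrors the structure of forward()'s Tuple argument:
trt_mod = torch_tensorrt.compile(
    mod,
    input_signature=(
        (
            torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
            torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
        ),
    ),
    enabled_precisions={torch.half},
)
```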

WORKSPACE

Lines changed: 4 additions & 4 deletions
```diff
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "8b3b48615169c83c1b643c0efade078ea080b1da598e15fcf01bc59421f3095e",
+    sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230219%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "aa7fd06079d260ff83c344d043fb84fbd9cf831cf375ed8b5a1b62416817af31",
+    sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230219%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
 )
 
 # Download these tarballs manually from the NVIDIA website
```
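When bumping these pins, the digest of each downloaded archive must match the `sha256` attribute or Bazel rejects the fetch. A generic helper for computing the digest locally, as a hedged sketch (the filename below is hypothetical):

```python
import hashlib

def sha256_of(path: str) -> str:
    # Stream in 1 MiB chunks so multi-GB libtorch archives need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare against the value pinned in WORKSPACE:
print(sha256_of("libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314+cu117.zip"))
```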

core/compiler.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));
 
   // Determine whether user specifications necessitate partitioning
   auto isFallbackRequested = userRequestedFallback(cfg);
```
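The widened predicate means a fully convertible graph whose *input* (not just output) is a collection now also takes the collection-handling path. A hedged Python illustration of the kind of module this catches; `CollectionInput` is invented for this example:

```python
import torch
from typing import Tuple

class CollectionInput(torch.nn.Module):
    # The argument is a TupleType, so under the new check
    # (inputIsCollection || outputIsCollection) this graph now qualifies
    # for collection handling even though its output is a plain Tensor.
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return xs[0] + xs[1]

g = torch.jit.script(CollectionInput()).graph
# Skip the implicit `self` input; the remaining input is Tuple[Tensor, Tensor].
print([str(i.type()) for i in list(g.inputs())[1:]])
```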

core/conversion/conversion.cpp

Lines changed: 12 additions & 2 deletions
```diff
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
       return {};
     }
   }
-  auto eval = evaluators::EvalNode(n, eval_args);
+  auto eval = evaluators::EvalNode(ctx, n, eval_args);
   return eval;
 }
 
@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
```
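`OutputIsCollection` now also recognizes `DictType`, so a module returning a dict of tensors is classified as having a collection output. A minimal sketch; the `DictOutput` module is invented for illustration:

```python
import torch
from typing import Dict

class DictOutput(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"doubled": x * 2, "halved": x / 2}

g = torch.jit.script(DictOutput()).graph
# The block's single output is a DictType, which OutputIsCollection now matches.
print(list(g.outputs())[0].type())  # Dict(str, Tensor)
```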

core/conversion/conversion.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);
 
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
```

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 79 additions & 0 deletions
```diff
@@ -72,6 +72,85 @@ auto mm_registrations TORCHTRT_UNUSED =
                mm_layer->setName(util::node_info(n).c_str());
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
 
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+               return true;
+             }})
+        .pattern(
+            {"aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto bat1 = args[1].ITensorOrFreeze(ctx);
+               auto bat2 = args[2].ITensorOrFreeze(ctx);
+               nvinfer1::Dims batch1Dims = bat1->getDimensions();
+               nvinfer1::Dims batch2Dims = bat2->getDimensions();
+
+               // check dimensions
+               TORCHTRT_CHECK(
+                   batch1Dims.nbDims == 3,
+                   "Expected 3-dimensional tensor, but got "
+                       << batch1Dims.nbDims
+                       << "-dimensional tensor for argument 'batch1' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch2Dims.nbDims == 3,
+                   "Expected 3-dimensional tensor, but got "
+                       << batch2Dims.nbDims
+                       << "-dimensional tensor for argument 'batch2' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch1Dims.d[0] == batch2Dims.d[0],
+                   "Expected tensor to have size " << batch1Dims.d[0] << " at dimension 0, but got size "
+                                                   << batch2Dims.d[0]
+                                                   << " for argument 'batch2' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch1Dims.d[2] == batch2Dims.d[1],
+                   "Expected tensor to have size " << batch1Dims.d[2] << " at dimension 1, but got size "
+                                                   << batch2Dims.d[1]
+                                                   << " for argument 'batch2' (while checking arguments for baddbmm)");
+
+               auto mm_layer = ctx->net->addMatrixMultiply(
+                   *bat1, nvinfer1::MatrixOperation::kNONE, *bat2, nvinfer1::MatrixOperation::kNONE);
+               TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication for node: " << *n);
+               mm_layer->setName((util::node_info(n) + "_matmul").c_str());
+
+               auto mm_out = mm_layer->getOutput(0);
+
+               auto alpha = args[4].unwrapToScalar();
+               if (alpha.to<float>() != 1.) {
+                 auto alpha_tensor = scalar_to_tensor(ctx, alpha);
+                 auto alpha_layer = add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kPROD,
+                     mm_out,
+                     alpha_tensor,
+                     util::node_info(n) + std::string("_alpha_mul"));
+                 TORCHTRT_CHECK(alpha_layer, "Unable to create alpha_mul layer from node: " << *n);
+                 mm_out = alpha_layer->getOutput(0);
+               }
+
+               auto beta = args[3].unwrapToScalar();
+               // If beta is 0, then input will be ignored, and nan and inf in it will not be propagated.
+               if (beta.to<float>() != 0.) {
+                 if (beta.to<float>() != 1.) {
+                   auto beta_tensor = scalar_to_tensor(ctx, beta);
+                   auto beta_layer = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kPROD,
+                       self,
+                       beta_tensor,
+                       util::node_info(n) + std::string("_beta_mul"));
+                   TORCHTRT_CHECK(beta_layer, "Unable to create beta_mul layer from node: " << *n);
+                   self = beta_layer->getOutput(0);
+                 }
+                 auto self_add_layer = add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kSUM,
+                     self,
+                     mm_out,
+                     util::node_info(n) + std::string("_self_add"));
+                 TORCHTRT_CHECK(self_add_layer, "Unable to create self_add layer from node: " << *n);
+                 mm_out = self_add_layer->getOutput(0);
+               }
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_out);
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
                return true;
              }});
```
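The new converter decomposes `aten::baddbmm` as `beta * self + alpha * (batch1 @ batch2)`, skipping the multiplies when `alpha == 1` or `beta == 1`, and skipping the add entirely when `beta == 0` so that NaN/Inf in `self` are not propagated (matching PyTorch's documented rule). A quick check of that decomposition against PyTorch itself, with illustrative shapes:

```python
import torch

self_t = torch.randn(4, 2, 5)
batch1 = torch.randn(4, 2, 3)
batch2 = torch.randn(4, 3, 5)
alpha, beta = 0.5, 2.0

# Same arithmetic the converter emits as TensorRT layers:
manual = beta * self_t + alpha * torch.bmm(batch1, batch2)
assert torch.allclose(manual, torch.baddbmm(self_t, batch1, batch2, beta=beta, alpha=alpha))

# With beta == 0 the converter drops the add, matching PyTorch's rule
# that NaN/Inf in `self` must not leak into the result:
nan_self = torch.full_like(self_t, float("nan"))
out = torch.baddbmm(nan_self, batch1, batch2, beta=0.0, alpha=alpha)
assert not torch.isnan(out).any()
```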

core/conversion/converters/impl/shuffle.cpp

Lines changed: 24 additions & 12 deletions
```diff
@@ -70,25 +70,37 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
            auto in = args[0].ITensorOrFreeze(ctx);
            auto in_shape = util::toVec(in->getDimensions());
            std::vector<int64_t> new_shape;
+           nvinfer1::ITensor* shape_tensor;
            if (ctx->input_is_dynamic) {
-             new_shape = util::toVec(args[1].unwrapToIntList().vec());
-             int nbDynamicDims = 0;
-             for (size_t i = 0; i < new_shape.size(); i++) {
-               if (in_shape[i] == -1)
-                 nbDynamicDims++;
-             }
-             if (nbDynamicDims > 1) {
-               TORCHTRT_THROW_ERROR(
-                   "Resize is currently not supported when target shape contains more than one dynamic dimension");
+             LOG_DEBUG("Using dynamic version of reshape layer");
+             if (args[1].isITensorList()) {
+               LOG_DEBUG("Shape tensor is an ITensorList");
+               auto new_shape = args[1].unwrapToITensorList();
+               auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
+               TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+               concat_layer->setAxis(static_cast<int32_t>(0));
+               shape_tensor = concat_layer->getOutput(0);
+             } else if (args[1].isIntList()) {
+               LOG_DEBUG("Shape tensor is an IntList");
+               auto shape_vec = args[1].unwrapToIntList().vec();
+               shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
+             } else {
+               LOG_ERROR(
+                   "Invalid IValue type of " << args[1].IValue()->type()
+                                             << " detected for shape tensor from node: " << *n);
              }
            } else {
              new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
            }
-
            auto shuffle = ctx->net->addShuffle(*in);
-           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-           shuffle->setReshapeDimensions(util::toDims(new_shape));
            shuffle->setName(util::node_info(n).c_str());
+           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+           if (ctx->input_is_dynamic) {
+             shuffle->setInput(1, *shape_tensor);
+           } else {
+             shuffle->setReshapeDimensions(util::toDims(new_shape));
+           }
 
            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
```
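The rewritten converter handles dynamic input shapes by building a runtime shape tensor: when the target shape arrives as a list of scalar tensors (`isITensorList()`) it concatenates them, and when it is a constant `IntList` it bakes it into a constant tensor; either way the shape is fed to the shuffle layer via `setInput(1, ...)`. In TorchScript terms, the two cases roughly correspond to the sketch below; both modules are invented for illustration:

```python
import torch

class RuntimeShape(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The target shape depends on x, so with dynamic input shapes the
        # converter sees a list of scalar ITensors (the new isITensorList()
        # branch, concatenated into a single shape tensor).
        return x.reshape(x.size(0), -1)

class ConstantShape(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A literal shape stays an IntList; with dynamic inputs it becomes a
        # constant shape tensor via the new isIntList() branch.
        return x.reshape(2, 6)

print(torch.jit.script(RuntimeShape())(torch.randn(2, 3, 2)).shape)   # (2, 6)
print(torch.jit.script(ConstantShape())(torch.randn(2, 3, 2)).shape)  # (2, 6)
```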

core/conversion/converters/impl/stack.cpp

Lines changed: 1 addition & 2 deletions
```diff
@@ -43,10 +43,9 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
           auto cont = t.toCustomClass<TensorContainer>();
           itensor = cont->tensor();
         }
-
         auto shuffle_layer = ctx->net->addShuffle(*itensor);
         TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-        shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim));
+        shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim, 1, false));
 
         tensors.push_back(shuffle_layer->getOutput(0));
       }
```
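For reference, `aten::stack` is implemented here by unsqueezing each input at `dim` and concatenating; judging by the call site, the extra arguments to `util::unsqueezeDims` pass an explicit fill value of 1 and disable zero-placeholder dimensions in the reshape. The tensor-level identity the converter relies on:

```python
import torch

tensors = [torch.randn(2, 3) for _ in range(4)]
dim = 1

# What the converter builds: unsqueeze every input at `dim`, then concat there.
manual = torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)
assert torch.equal(manual, torch.stack(tensors, dim=dim))
print(manual.shape)  # torch.Size([2, 4, 3])
```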

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
   return get_evaluator_registry().GetRegisteredEvaluatorList();
 }
 
-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   auto evaluator = get_evaluator_registry().GetEvaluator(n);
-  return evaluator(n, args);
+  return evaluator(ctx, n, args);
 }
 
 void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
```
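This signature change threads the `ConversionCtx` into every registered evaluator, so evaluators that need access to the network being built (for example, to emit shape tensors at evaluation time) can reach it. A language-neutral sketch of the registry pattern in Python; all names here are illustrative stand-ins, not the actual API:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

@dataclass
class ConversionCtx:
    """Stand-in for the real conversion context, which owns the TensorRT network."""
    net: Any = None

@dataclass
class Node:
    kind: str

# Evaluators now take the context first, mirroring the C++ change from
# evaluator(n, args) to evaluator(ctx, n, args).
Evaluator = Callable[[ConversionCtx, Node, Dict[str, Any]], Optional[Any]]

_registry: Dict[str, Evaluator] = {}

def register_evaluator(kind: str, fn: Evaluator) -> None:
    _registry[kind] = fn

def eval_node(ctx: ConversionCtx, n: Node, args: Dict[str, Any]) -> Optional[Any]:
    # Mirrors EvalNode: look up the evaluator by node kind and thread ctx through.
    return _registry[n.kind](ctx, n, args)

register_evaluator("aten::add", lambda ctx, n, args: args["a"] + args["b"])
print(eval_node(ConversionCtx(), Node("aten::add"), {"a": 1, "b": 2}))  # 3
```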
