Commit ef6394b

chore: cherry-pick of DS feature (#2857)
1 parent 3422c41 commit ef6394b

29 files changed: +416 −158 lines changed

core/runtime/execute_engine.cpp

Lines changed: 23 additions & 3 deletions
@@ -124,6 +124,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }
 
+  // this is a buffer to store shape tensor input addresses throughout the runtime scope
+  std::list<std::vector<int32_t>> inputShapeTensorValues;
   {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
     if (compiled_engine->profile_execution) {
@@ -142,12 +144,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       auto dims = core::util::toDims(inputs[i].sizes());
       auto shape = core::util::toVec(dims);
       LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
-      compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
-      compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+      if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
+        // Shape tensor inputs are casted to int32 explicitly.
+        // Refer to
+        // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
+        auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
+        std::vector<int32_t> inputs_cpu_vec(
+            input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
+        inputShapeTensorValues.emplace_back(inputs_cpu_vec);
+        compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
+      } else {
+        compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
+        compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
+      }
     }
 
+    // Check if input shapes can be inferred.
+    int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
+    std::vector<char const*> names(io_size);
+    int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
     TORCHTRT_CHECK(
-        compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
+        nbNames == 0,
+        "The shapes of the inputs: "
+            << names
+            << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
   }
 
   std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

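For context, a shape tensor input is one whose element values (not just its dimensions) determine downstream shapes, which is why the runtime above copies it into a host-side int32 buffer that stays alive for the whole call. A minimal PyTorch sketch of such a model (hypothetical example, not part of this commit):

import torch

class ResizeToSizes(torch.nn.Module):
    # `sizes` is consumed as values, so TensorRT treats it as a shape tensor input.
    def forward(self, x: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
        out_h, out_w = int(sizes[0]), int(sizes[1])
        return torch.nn.functional.interpolate(x, size=(out_h, out_w))

m = ResizeToSizes().eval()
x = torch.randn(1, 3, 32, 32)
sizes = torch.tensor([64, 64], dtype=torch.int32)
print(m(x, sizes).shape)  # torch.Size([1, 3, 64, 64])
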
py/torch_tensorrt/_Input.py

Lines changed: 9 additions & 1 deletion
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_tensor: torch.Tensor = None
     name: str = ""
+    is_shape_tensor: bool = False
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -161,6 +162,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         else:
             self._explicit_set_dtype = False
 
+        if "is_shape_tensor" in kwargs:
+            self.is_shape_tensor = kwargs["is_shape_tensor"]
+
         if "format" in kwargs:
             self.format = memory_format._from(kwargs["format"])
 
@@ -174,7 +178,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         if "torch_tensor" in kwargs:
            self.torch_tensor = kwargs["torch_tensor"]
         else:
-            if self.shape_mode == Input._ShapeMode.DYNAMIC:
+            if self.is_shape_tensor:
+                self.torch_tensor = torch.tensor(
+                    kwargs["opt_shape"], dtype=kwargs["dtype"]
+                )
+            elif self.shape_mode == Input._ShapeMode.DYNAMIC:
                 self.torch_tensor = self.example_tensor("opt_shape")
             else:
                 self.torch_tensor = self.example_tensor()

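A hedged sketch of how the new flag might be used when constructing an Input: for a shape tensor, min/opt/max describe the values the tensor may take, and the diff above materializes torch_tensor from opt_shape, so opt_shape and a torch dtype must both be supplied (usage inferred from this diff, not from documentation):

import torch
from torch_tensorrt import Input

sizes_spec = Input(
    min_shape=(32, 32),
    opt_shape=(64, 64),
    max_shape=(128, 128),
    dtype=torch.int32,
    is_shape_tensor=True,
)
print(sizes_spec.torch_tensor)  # tensor([64, 64], dtype=torch.int32)
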
py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 3 additions & 7 deletions
@@ -58,13 +58,9 @@ def trace(
 
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_inputs = get_torch_inputs(inputs, device)
-    dynamic_shapes = {}
+    dynamic_shapes = []
     for input in inputs:
         if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
-            if not input.name:
-                raise AssertionError(
-                    f"Expected a name for a dynamic input with shape {input.shape} but found none"
-                )
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
             max_shape = input.shape["max_shape"]
@@ -80,8 +76,8 @@ def trace(
                     max=max_shape[dim],
                 )
 
-            dynamic_shapes[input.name] = dynamic_dims
+            dynamic_shapes.append(dynamic_dims)
 
-    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
+    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
 
     return exp_program

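The tracer now builds dynamic_shapes positionally instead of keying it by Input.name, which torch.export also accepts: one spec per forward() argument, in order. A minimal sketch of the positional form (hypothetical module, assuming torch >= 2.2):

import torch
from torch.export import Dim, export

class MatMul(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a @ b

a, b = torch.randn(8, 16), torch.randn(16, 32)
batch = Dim("batch", min=2, max=64)
# Tuple form: specs line up with (a, b); no argument names are required.
ep = export(MatMul(), (a, b), dynamic_shapes=({0: batch}, None))
print(ep)
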
py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ def _pretraced_backend(
 
         gm = apply_lowering_passes(gm, torch_inputs)
 
+        logger.debug("Lowered Input graph:\n " + str(gm.graph))
+
         torchtrt_inputs = prepare_inputs(
             torch_inputs, disable_memory_format_check=True
         )

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 24 additions & 11 deletions
@@ -23,6 +23,7 @@
     get_node_name,
     get_trt_tensor,
 )
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -370,18 +371,29 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
             max_shape = current_input.shape["max_shape"]
             # TODO: Does not support disjoint optimization profiles?
             assert self.optimization_profiles is not None
-            self.optimization_profiles[0].set_shape(
-                target, min_shape, opt_shape, max_shape
-            )
-
             assert len(min_shape) == len(opt_shape) == len(max_shape)
-            for i in range(len(min_shape)):
-                if min_shape[i] == opt_shape[i] == max_shape[i]:
-                    shape.append(min_shape[i])
-                else:
-                    # -1 to represent the dynamic dimension
-                    shape.append(-1)
-        elif current_input.shape_mode == Input._ShapeMode.STATIC:
+            if current_input.is_shape_tensor:
+                # For shape_tensors, min/opt/max_shapes correspond to actual values
+                # of the shapes provided during runtime
+                self.optimization_profiles[0].set_shape_input(
+                    target, min_shape, opt_shape, max_shape
+                )
+                shape.append(len(opt_shape))
+            else:
+                self.optimization_profiles[0].set_shape(
+                    target, min_shape, opt_shape, max_shape
+                )
+
+                for i in range(len(min_shape)):
+                    if min_shape[i] == opt_shape[i] == max_shape[i]:
+                        shape.append(min_shape[i])
+                    else:
+                        # -1 to represent the dynamic dimension
+                        shape.append(DYNAMIC_DIM)
+        elif (
+            not current_input.is_shape_tensor
+            and current_input.shape_mode == Input._ShapeMode.STATIC
+        ):
             assert isinstance(current_input.shape, tuple)
             shape = list(current_input.shape)
         else:
@@ -393,6 +405,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
         _LOGGER.debug(
             f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]"
         )
+
         return self.ctx.net.add_input(
             name=target,
             shape=tuple(shape),

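The interpreter now distinguishes the two profile APIs: set_shape() constrains the dimensions of an execution tensor, while set_shape_input() constrains the values a shape tensor may hold. A minimal sketch using the TensorRT Python API directly (tensor names "x" and "sizes" are made up; a real profile must match the network's input names):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
profile = builder.create_optimization_profile()

# Execution tensor with dynamic dims: min/opt/max are shapes.
profile.set_shape("x", (1, 3, 32, 32), (8, 3, 64, 64), (16, 3, 128, 128))

# Shape tensor input: min/opt/max are the values it may carry at runtime.
profile.set_shape_input("sizes", (32, 32), (64, 64), (128, 128))
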
py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 8 additions & 8 deletions
@@ -4,7 +4,9 @@
 import logging
 from typing import List, Sequence
 
+import tensorrt as trt
 import torch
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -17,8 +19,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
@@ -28,12 +28,12 @@ def infer_module_output_dtypes(
     device: Device,
     truncate_double: bool = False,
 ) -> List[dtype]:
-    torch_inputs = get_torch_inputs(inputs, device)
-    module = module.to(device.to(torch.device))
-    module_outputs = module(*torch_inputs)
-
-    if not isinstance(module_outputs, (list, tuple)):
-        module_outputs = [module_outputs]
+    with maybe_disable_fake_tensor_mode():
+        torch_inputs = get_torch_inputs(inputs, device)
+        module = module.to(device.to(torch.device))
+        module_outputs = module(*torch_inputs)
+        if not isinstance(module_outputs, (list, tuple)):
+            module_outputs = [module_outputs]
 
     # Int64 outputs can sometimes be generated from within other operators
     # such as aten.sum - such outputs can be truncated

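Running the module for dtype inference is now wrapped in maybe_disable_fake_tensor_mode so it still executes on real tensors when an ambient FakeTensorMode is active (as during export/tracing). A small sketch of the behavior this relies on (illustrative only):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

with FakeTensorMode():
    fake = torch.empty(2, 2)          # created as a FakeTensor, carries no data
    with maybe_disable_fake_tensor_mode():
        real = torch.ones(2, 2)       # mode temporarily unset: a real tensor
        print(type(fake).__name__, type(real).__name__, real.sum().item())
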
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 52 additions & 30 deletions
@@ -129,10 +129,14 @@ def aten_ops_batch_norm_legit_no_training(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
+    torch.ops.aten.native_layer_norm.default,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
 )
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -237,7 +241,10 @@ def aten_ops_cat(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.embedding.default,
+    supports_dynamic_shapes=True,
+)
 def aten_ops_embedding(
     ctx: ConversionContext,
     target: Target,
@@ -427,7 +434,7 @@ def aten_ops_index(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True)
 def aten_ops_tanh(
     ctx: ConversionContext,
     target: Target,
@@ -518,10 +525,10 @@ def aten_ops_hard_sigmoid(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.matmul)
-@dynamo_tensorrt_converter(torch.ops.aten.mm.default)
-@dynamo_tensorrt_converter(torch.ops.aten.mv.default)
-@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
+@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True)
 def aten_ops_matmul(
     ctx: ConversionContext,
     target: Target,
@@ -602,7 +609,9 @@ def aten_ops_erf(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.unsqueeze.default, supports_dynamic_shapes=True
+)
 def aten_ops_unsqueeze(
     ctx: ConversionContext,
     target: Target,
@@ -615,7 +624,9 @@ def aten_ops_unsqueeze(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._softmax.default, supports_dynamic_shapes=True
+)
 def aten_ops_softmax(
     ctx: ConversionContext,
     target: Target,
@@ -730,7 +741,7 @@ def aten_ops_select(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -860,7 +871,7 @@ def aten_ops_as_strided(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
+@dynamo_tensorrt_converter(torch.ops.aten.permute.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -931,10 +942,12 @@ def validator(to_copy_node: Node) -> bool:
 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
     capability_validator=lambda node: not is_only_operator_on_placeholder(node),
+    supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
     torch.ops.aten._to_copy.default,
     capability_validator=to_copy_dtype_validator(placeholder_only=False),
+    supports_dynamic_shapes=True,
 )
 def aten_ops_clone_copy_dtype(
     ctx: ConversionContext,
@@ -983,7 +996,7 @@ def aten_ops_clone_copy_placeholder(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1673,6 +1686,7 @@ def aten_ops_isnan(
 )
 
 
+@dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True)
 def aten_ops_add(
@@ -1705,8 +1719,8 @@ def aten_ops_add(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
 def aten_ops_mul(
     ctx: ConversionContext,
     target: Target,
@@ -1792,11 +1806,11 @@ def aten_ops_sub(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
-@dynamo_tensorrt_converter(torch.ops.prims.div.default)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.prims.div.default, supports_dynamic_shapes=True)
 def aten_ops_div(
     ctx: ConversionContext,
     target: Target,
@@ -1839,9 +1853,13 @@ def aten_ops_div(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.pow.Tensor_Tensor, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True
+)
 def aten_ops_pow(
     ctx: ConversionContext,
     target: Target,
@@ -3046,12 +3064,16 @@ def zero_diag_size_validator(node: Node) -> bool:
         )
         return False
 
-    offset, dim1, dim2 = (
-        node.args[1],
-        node.args[2],
-        node.args[3],
-    )
-
+    if len(node.args) == 1:
+        offset, dim1, dim2 = 0, 0, 1
+    elif len(node.args) == 2:
+        offset, dim1, dim2 = node.args[1], 0, 1
+    else:
+        offset, dim1, dim2 = (
+            node.args[1],
+            node.args[2],
+            node.args[3],
+        )
     num_dims = len(input_shape)
 
     # Adjust dimensions to be positive and canonicalize

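The validator change mirrors the schema defaults of aten.diagonal (offset=0, dim1=0, dim2=1), so FX nodes that omit trailing arguments are still handled; that this validator guards the diagonal converters is an assumption from its name and the chosen defaults. A quick check of the schema and the three call arities:

import torch

print(torch.ops.aten.diagonal.default._schema)
# diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)

x = torch.arange(16).reshape(4, 4)
print(torch.ops.aten.diagonal.default(x))           # node.args == (x,)
print(torch.ops.aten.diagonal.default(x, 1))        # node.args == (x, 1)
print(torch.ops.aten.diagonal.default(x, 1, 0, 1))  # all arguments explicit
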