From e7267a128bbc45585c2d7ac7110b2ad33805e246 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 28 Apr 2023 15:14:02 -0700 Subject: [PATCH] Combination: 16 commits with aten improvements refactor: Moving elementwise and unary core to impl Signed-off-by: Naren Dasan new file: ../converters/impl/unary/base.py Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) (#1905) Converter reorg fmod Converter reorg and rsub Rsub error fixes and linting error fixed Rsub test case to include different inputs Converter reorg batch norm batch norm error fix and linting issue error fix layer_norm converter Layer norm linting correction ops file correction fixing lint Acc_ops layer_norm correction Converter reorg and softmax operation softmax linting error fix Converter reorg and gelu Linting error Converter reorg and squeeze operator Correcting squeeze operator implementation, linting error and acc squeeze test Adding the condition to convert dim to int and removing the comment Converter reorg and select operation select operation correction and linting changes converter reorg and slice converter reorg slice op Correcting linting error and slice changes Correcting the slice operation converter reorg and matmul Matmul issue fixes and lint error check moving matmul to individual file Converter reorg and where operator adding where aten op aten::where correction and linting error changes aten::unsqueeze impl refactor Signed-off-by: Boris Fomitchev Moved clamp to impl Signed-off-by: Boris Fomitchev fixed method name Signed-off-by: Boris Fomitchev fix: Add automatic type promotion for FX ops - Implement functionality to cast tensors to alternative types - Add functionality to elementwise ops to promote types and perform necessary casts - Address issues in FX ops where mixed-precision computations can cause errors - Add test cases to validate fix Revert all changes to py/torch_tensorrt/fx Revert "fix: Add automatic type promotion for FX ops" This reverts commit f1f371663b222e58ec15c335080365f4b2a44a40. Revert "Moved clamp to impl" This reverts commit df401dd95ba8f6bbe5777e7c36086b30d1eeea26. Revert "aten::unsqueeze impl refactor" This reverts commit b4247358540e353ab1d81252a0000a2131d68528. Revert "Converter reorg and where operator" This reverts commit b4da15e9af5e4b92fe2ee76ed6fd71696452bd3b. Revert "converter reorg and matmul" This reverts commit 7551eeecfd3c7679cbed048f910bed2e792910c8. Revert "converter reorg and slice" This reverts commit 9bbdc9ecf6c8fd04b915dc40567bc67d13c86d38. Revert "Converter reorg and select operation" This reverts commit fb70253c1372020c55812f4b22df5f64eb59018a. Revert "Converter reorg and squeeze operator" This reverts commit 294545c6361b79656f74f742a84abd6b9b1f1d24. Revert "Converter reorg and gelu" This reverts commit 37d11682092d64f1f250581944faddd7c36de176. Revert "Converter reorg and softmax operation" This reverts commit 1ba6d139d100b7cc1ba7c95400047ab993e97ce2. Revert "layer_norm converter" This reverts commit e0b34b1c35be30e145c2bc484da661c43c23cf9b. Revert "Converter reorg batch norm" This reverts commit 59354e5416ee14ca3145b28b389be2ffae2c1821. Revert "Converter reorg and rsub" This reverts commit db15d2704f7b12ff4fcee4a267b42bc7c5d37bf6. Revert "Converter reorg fmod" This reverts commit ce3fa67655cc0924dc2b046aad8b583c90e283ae. Revert "Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) (#1905)" This reverts commit 7158ca54bd1827e5e0de68107cf19a20238c36f6. Revert "refactor: Moving elementwise and unary core to impl" This reverts commit 45e43ca7ccd6ae8fa3d67b124a2a453989d2de47. fix: Replay all FX changes in Dynamo - Add multiple fixes to make FX changes appear in Dynamo directory, using Dynamo registry - All converters with open PRs are linked and shown - Update references, imports, code, merges, rebases accordingly - Add new test cases to Dynamo for converters Temporarily removing rsub pending fix Fixing clamp to not use Torch Signed-off-by: Boris Fomitchev Fixing select to not use torch fix: Reorganize folders in latest implementation - Update test references and imports accordingly Embedding operator in dynamo reciprocal lowering pass fix: Fix for Dynamic Shape Tests + Input class feat: Add permute operation implementation chore: Move converter registry, update imports --- .circleci/config.yml | 16 + py/torch_tensorrt/_Input.py | 11 + py/torch_tensorrt/dynamo/__init__.py | 3 +- .../dynamo/conversion/SourceIR.py | 24 ++ .../dynamo/conversion/__init__.py | 2 + .../dynamo/conversion/aten_ops_converters.py | 378 ++++++++++++++++++ .../{ => conversion}/converter_registry.py | 0 .../dynamo/conversion/converter_utils.py | 97 +++++ .../dynamo/conversion/impl/__init__.py | 14 + .../dynamo/conversion/impl/activation.py | 65 +++ .../conversion/impl/condition/__init__.py | 1 + .../dynamo/conversion/impl/condition/ops.py | 108 +++++ .../conversion/impl/elementwise/__init__.py | 2 + .../conversion/impl/elementwise/base.py | 162 ++++++++ .../conversion/impl/elementwise/clamp.py | 78 ++++ .../dynamo/conversion/impl/elementwise/ops.py | 177 ++++++++ .../dynamo/conversion/impl/embedding.py | 73 ++++ .../dynamo/conversion/impl/matmul.py | 54 +++ .../conversion/impl/normalization/__init__.py | 1 + .../conversion/impl/normalization/ops.py | 313 +++++++++++++++ .../dynamo/conversion/impl/permutation.py | 34 ++ .../dynamo/conversion/impl/select.py | 64 +++ .../dynamo/conversion/impl/shape.py | 77 ++++ .../dynamo/conversion/impl/slice/__init__.py | 1 + .../dynamo/conversion/impl/slice/base.py | 39 ++ .../dynamo/conversion/impl/slice/ops.py | 96 +++++ .../dynamo/conversion/impl/squeeze.py | 63 +++ .../dynamo/conversion/impl/unary/__init__.py | 1 + .../dynamo/conversion/impl/unary/base.py | 44 ++ .../dynamo/conversion/impl/unary/ops.py | 98 +++++ .../dynamo/conversion/impl/unsqueeze.py | 52 +++ .../dynamo/conversion/trt_interpreter.py | 14 +- .../dynamo/lowering/_decompositions.py | 17 + py/torch_tensorrt/dynamo/test_utils.py | 310 ++++++++++++++ .../dynamo/backend/test_backend_compiler.py | 2 +- .../py/dynamo/backend/test_decompositions.py | 76 +++- .../converters/test_adaptive_avgpool_aten.py | 127 ++++++ .../dynamo/converters/test_batchnorm_aten.py | 66 +++ .../dynamo/converters/test_binary_ops_aten.py | 263 ++++++++++++ tests/py/dynamo/converters/test_cat_aten.py | 94 +++++ tests/py/dynamo/converters/test_clamp_aten.py | 71 ++++ .../converters/test_convolution_aten.py | 203 ++++++++++ tests/py/dynamo/converters/test_elu_aten.py | 52 +++ .../dynamo/converters/test_embedding_aten.py | 99 +++++ .../py/dynamo/converters/test_expand_aten.py | 31 ++ tests/py/dynamo/converters/test_gelu_aten.py | 52 +++ .../dynamo/converters/test_hardtanh_aten.py | 54 +++ .../dynamo/converters/test_layer_norm_aten.py | 45 +++ .../dynamo/converters/test_leaky_relu_aten.py | 54 +++ .../py/dynamo/converters/test_linear_aten.py | 71 ++++ .../py/dynamo/converters/test_matmul_aten.py | 97 +++++ tests/py/dynamo/converters/test_mean_aten.py | 85 ++++ .../converters/test_permutation_aten.py | 73 ++++ tests/py/dynamo/converters/test_relu_aten.py | 52 +++ .../py/dynamo/converters/test_reshape_aten.py | 103 +++++ tests/py/dynamo/converters/test_rsqrt_aten.py | 30 ++ .../py/dynamo/converters/test_select_aten.py | 79 ++++ tests/py/dynamo/converters/test_selu_aten.py | 52 +++ .../py/dynamo/converters/test_sigmoid_aten.py | 68 ++++ tests/py/dynamo/converters/test_slice_aten.py | 86 ++++ .../py/dynamo/converters/test_softmax_aten.py | 45 +++ .../py/dynamo/converters/test_squeeze_aten.py | 68 ++++ tests/py/dynamo/converters/test_tanh_aten.py | 52 +++ .../dynamo/converters/test_unsqueeze_aten.py | 62 +++ tests/py/dynamo/converters/test_where_aten.py | 33 ++ tests/py/dynamo/models/test_models.py | 10 +- 66 files changed, 4827 insertions(+), 17 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/SourceIR.py create mode 100644 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py rename py/torch_tensorrt/dynamo/{ => conversion}/converter_registry.py (100%) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/activation.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/embedding.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/matmul.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/permutation.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/select.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/shape.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/slice/base.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/squeeze.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/unary/base.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py create mode 100644 py/torch_tensorrt/dynamo/test_utils.py create mode 100644 tests/py/dynamo/converters/test_adaptive_avgpool_aten.py create mode 100644 tests/py/dynamo/converters/test_batchnorm_aten.py create mode 100644 tests/py/dynamo/converters/test_binary_ops_aten.py create mode 100644 tests/py/dynamo/converters/test_cat_aten.py create mode 100644 tests/py/dynamo/converters/test_clamp_aten.py create mode 100644 tests/py/dynamo/converters/test_convolution_aten.py create mode 100644 tests/py/dynamo/converters/test_elu_aten.py create mode 100644 tests/py/dynamo/converters/test_embedding_aten.py create mode 100644 tests/py/dynamo/converters/test_expand_aten.py create mode 100644 tests/py/dynamo/converters/test_gelu_aten.py create mode 100644 tests/py/dynamo/converters/test_hardtanh_aten.py create mode 100644 tests/py/dynamo/converters/test_layer_norm_aten.py create mode 100644 tests/py/dynamo/converters/test_leaky_relu_aten.py create mode 100644 tests/py/dynamo/converters/test_linear_aten.py create mode 100644 tests/py/dynamo/converters/test_matmul_aten.py create mode 100644 tests/py/dynamo/converters/test_mean_aten.py create mode 100644 tests/py/dynamo/converters/test_permutation_aten.py create mode 100644 tests/py/dynamo/converters/test_relu_aten.py create mode 100644 tests/py/dynamo/converters/test_reshape_aten.py create mode 100644 tests/py/dynamo/converters/test_rsqrt_aten.py create mode 100644 tests/py/dynamo/converters/test_select_aten.py create mode 100644 tests/py/dynamo/converters/test_selu_aten.py create mode 100644 tests/py/dynamo/converters/test_sigmoid_aten.py create mode 100644 tests/py/dynamo/converters/test_slice_aten.py create mode 100644 tests/py/dynamo/converters/test_softmax_aten.py create mode 100644 tests/py/dynamo/converters/test_squeeze_aten.py create mode 100644 tests/py/dynamo/converters/test_tanh_aten.py create mode 100644 tests/py/dynamo/converters/test_unsqueeze_aten.py create mode 100644 tests/py/dynamo/converters/test_where_aten.py diff --git a/.circleci/config.yml b/.circleci/config.yml index ae4261ac43..d46c695678 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -780,6 +780,21 @@ commands: - store_artifacts: path: /tmp/testlogs + test-dynamo-converters: + description: "Test the Dynamo aten converters" + steps: + - run: + name: Run Dynamo converter tests + command: | + cd tests/py/dynamo/converters + TESTS_TO_RUN=$(circleci tests glob "test_*.py" | circleci tests split --split-by=timings) + pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/test_results.xml $TESTS_TO_RUN + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + # =================== Dynamo tests end ======================== # # Define a job to be invoked later in a workflow. @@ -1036,6 +1051,7 @@ jobs: command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - dump-test-env + - test-dynamo-converters - test-dynamo-torch_compile - test-dynamo-models_torch_compile - test-dynamo-models_torch_export diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 1ea87c5a4e..8d3a842a47 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -68,6 +68,17 @@ def __init__(self, *args, **kwargs): - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW """ + # Compatibility code for switching over from InputTensorSpec + if "shape" in kwargs and "shape_ranges" in kwargs: + assert ( + len(kwargs["shape_ranges"]) == 1 and len(kwargs["shape_ranges"][0]) == 3 + ) + del kwargs["shape"] + + kwargs["min_shape"] = kwargs["shape_ranges"][0][0] + kwargs["opt_shape"] = kwargs["shape_ranges"][0][1] + kwargs["max_shape"] = kwargs["shape_ranges"][0][2] + if len(args) == 1: if not Input._supported_input_size_type(args[0]): raise TypeError( diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 63a3308fe2..5918bad806 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -3,8 +3,9 @@ if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from ._settings import * + from .conversion import * from .aten_tracer import trace - from .converter_registry import ( + from .conversion.converter_registry import ( DYNAMO_CONVERTERS, dynamo_tensorrt_converter, ) diff --git a/py/torch_tensorrt/dynamo/conversion/SourceIR.py b/py/torch_tensorrt/dynamo/conversion/SourceIR.py new file mode 100644 index 0000000000..c0547986c4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/SourceIR.py @@ -0,0 +1,24 @@ +from enum import Enum, auto + + +class SourceIR(Enum): + NN = auto() + ACC = auto() + ATEN = auto() + PRIM = auto() + TORCHTRT_LOWERED = auto() + UNKNOWN = auto() + + def __str__(self): + if self == SourceIR.NN: + return "nn" + elif self == SourceIR.ACC: + return "acc" + elif self == SourceIR.ATEN: + return "aten" + elif self == SourceIR.PRIM: + return "prim" + elif self == SourceIR.TORCHTRT_LOWERED: + return "torchtrt_lowered" + else: + return "unknown_ir" diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index f50b22f27d..d201665a5b 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,3 +1,5 @@ +from .SourceIR import SourceIR +from .aten_ops_converters import * from .trt_interpreter import * from .conversion import * from .truncate_long_and_double import repair_long_or_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py new file mode 100644 index 0000000000..38f8692852 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -0,0 +1,378 @@ +import logging +from typing import Dict, Sequence, Tuple, Union +import torch +import tensorrt as trt +from torch_tensorrt.fx.converters import acc_ops_converters +from .converter_registry import dynamo_tensorrt_converter +from torch.fx.node import Argument, Target, Node + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR, impl +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def args_bounds_check(args, i, replacement=None): + return args[i] if len(args) > i else replacement + + +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) +def aten_ops_batch_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.batch_norm( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + args[7], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.div.default) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) +def aten_ops_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + # If both are TRTTensor, both are cast to float32 + if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor): + kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor( + network, + kwargs_new["input"], + kwargs_new["other"], + name, + ) + # If one is TRTTensor, it is cast to float32 + elif isinstance(args[0], TRTTensor) and ( + kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 + ): + kwargs_new["input"] = cast_trt_tensor( + network, kwargs_new["input"], trt.float32, name + ) + elif isinstance(args[1], TRTTensor) and ( + kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 + ): + kwargs_new["other"] = cast_trt_tensor( + network, kwargs_new["other"], trt.float32, name + ) + rounding_mode = kwargs.get("rounding_mode") + if rounding_mode is None: + return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) + elif rounding_mode == "floor": + return acc_ops_converters.acc_ops_floor_div( + network, target, None, kwargs_new, name + ) + elif rounding_mode == "trunc": + return impl.elementwise.trunc_div( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + else: + raise RuntimeError( + f"Target {target} does not support rounding mode {rounding_mode}" + ) + + +def embedding_param_validator(embedding_node: Node): + + max_norm = args_bounds_check(embedding_node.args, 2) + norm_type = args_bounds_check(embedding_node.args, 3) + scale_grad_by_freq = args_bounds_check(embedding_node.args, 4) + sparse = args_bounds_check(embedding_node.args, 5) + + if max_norm is not None: + _LOGGER.debug( + f"Currently we don't support specifying max_norm, got {max_norm}." + ) + return False + + if norm_type is not None and norm_type != 2.0: + _LOGGER.debug( + f"Currently we don't support specifying norm_type, got {norm_type}." + ) + return False + + if scale_grad_by_freq is not None: + _LOGGER.debug( + f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}." + ) + return False + + if sparse is not None: + _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.") + return False + + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.embedding.default, capability_validator=embedding_param_validator +) +def aten_ops_embedding( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.embedding.embedding( + network, + target, + SourceIR.ATEN, + name, + input=args[1], + weight=args[0], + max_norm=args_bounds_check(args, 2), + norm_type=args_bounds_check(args, 3), + scale_grad_by_freq=args_bounds_check(args, 4), + sparse=args_bounds_check(args, 5), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) +def aten_ops_fmod( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1]) + + +@dynamo_tensorrt_converter(torch.ops.aten.gelu.default) +def aten_ops_gelu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.gelu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.matmul) +@dynamo_tensorrt_converter(torch.ops.aten.mm.default) +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.matmul.matrix_multiply( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) +def aten_ops_layernorm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.layer_norm( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return impl.activation.relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return impl.elementwise.rsqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) +@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) +def aten_ops_squeeze( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1]) + + +@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) +def aten_ops_unsqueeze( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unsqueeze.unsqueeze( + network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten._softmax.default) +def aten_ops_softmax( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.softmax( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.where.self) +def aten_ops_where( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.condition.where( + network, + target, + SourceIR.ATEN, + name, + args[1], + args[2], + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.clamp.default) +def aten_ops_clamp( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.clamp( + network, + target, + SourceIR.ATEN, + name, + input_val=args[0], + min_val=args_bounds_check(args, 1), + max_val=args_bounds_check(args, 2), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.select.int) +def aten_ops_select( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.select( + network, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) +def aten_ops_slice( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.slice_op( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args_bounds_check(args, 4, replacement=1), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.permute.default) +def aten_ops_permute( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.permutation.permute( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) diff --git a/py/torch_tensorrt/dynamo/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py similarity index 100% rename from py/torch_tensorrt/dynamo/converter_registry.py rename to py/torch_tensorrt/dynamo/conversion/converter_registry.py diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4471931e4c..584e15b263 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,5 +1,19 @@ import torch +from torch_tensorrt.fx.types import ( + TRTDataType, + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + unified_dtype_converter, + Frameworks, +) + +import tensorrt as trt +from typing import List + def dynamic_unsupported(node: torch.fx.Node) -> bool: # Validate that none of the inputs to the node have Dynamic shapes @@ -28,3 +42,86 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool: return False return True + + +def cast_trt_tensor( + network: TRTNetwork, + input_val: TRTTensor, + dtype: TRTDataType, + name: str, +) -> TRTTensor: + """ + Given a TRT Tensor, convert that Tensor to the specified dtype + Adds an Identity layer to the network which performs the conversion + Args: + network (TRTNetwork): A TensorRT network + input_val (TRTTensor): A TRT Tensor to cast to a new data type + dtype (TRTDataType): The TRTDataType to cast the input Tensor to + name (str): Name of the calling layer + Returns: + A TensorRT ITensor which has been casted to the specified dtype + """ + trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + + if input_val.dtype != trt_dtype: + identity_layer = network.add_identity(input_val) + identity_layer.set_output_type(0, trt_dtype) + identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}" + return identity_layer.get_output(0) + else: + return input_val + + +def cast_int_int_div_trt_tensor( + network: TRTNetwork, + lhs_val: TRTTensor, + rhs_val: TRTTensor, + name: str, +) -> List[TRTTensor]: + """ + Given two `int` data type TRT Tensor to div operation, cast the TRT Tensor to float type + Args: + network (TRTNetwork): A TensorRT network + lhs_val (TRTTensor): A TRT Tensor numerator + rhs_val (TRTTensor): A TRT Tensor numerator + name (str): Name of calling layer + Returns: + A list of lhs_val and rhs_val casted to the approriate datatype + """ + if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and ( + rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32 + ): + lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name) + rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name) + return list((lhs_val, rhs_val)) + + +def broadcastable( + a: TRTTensor, + b: TRTTensor, +) -> bool: + "Check if two tensors are broadcastable according to torch rules" + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + # check from the trailing + diff = len(a_shape) - len(b_shape) + if diff == 0: + return True + if diff > 0: + max = len(a_shape) + min = len(b_shape) + greater_tensor = a_shape + lesser_tensor = b_shape + elif diff < 0: + max = len(b_shape) + min = len(a_shape) + greater_tensor = b_shape + lesser_tensor = a_shape + j = min - 1 + for i in range(max - 1, diff - 1, -1): + if not ( + greater_tensor[i] != lesser_tensor[j] + and (greater_tensor[i] == 1 or lesser_tensor[i] == 1) + ): + return False + return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py new file mode 100644 index 0000000000..db6e405978 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -0,0 +1,14 @@ +from torch_tensorrt.fx.converters.impl import convolution +from . import condition +from . import elementwise +from . import embedding +from . import normalization +from . import slice +from . import unary +from . import activation +from . import matmul +from . import select +from . import shape +from . import squeeze +from . import unsqueeze +from . import permutation diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation.py b/py/torch_tensorrt/dynamo/conversion/impl/activation.py new file mode 100644 index 0000000000..ec3e078820 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation.py @@ -0,0 +1,65 @@ +import numpy as np +from typing import Any, Optional +import math + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.converters.impl.activation import * +from torch_tensorrt.fx.converters.converter_utils import ( + mark_as_int8_layer, + set_layer_name, + get_trt_plugin, +) +from torch_tensorrt.dynamo.conversion import SourceIR + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, + TRTPluginFieldCollection, +) + + +def gelu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any] = None, +): + approximate = alpha + if approximate is not None: + raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"GELU received input {input_val} that is not part " + "of the TensorRT region!" + ) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "GeLU converter currently doesn't support implicit batch dimension" + ) + plugin_name = "CustomGeluPluginDynamic" + # type_id 0 for float32, 1 for float16 + type_id = trt.PluginField( + "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 + ) + field_collection = TRTPluginFieldCollection([type_id]) + plugin_version = "1" + + plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) + + layer = network.add_plugin_v2([input_val], plugin) + + def gelu_dyn_range_fn(dyn_range): + return ( + dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))) + ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))) + + if input_val.dynamic_range is not None: + dyn_range = gelu_dyn_range_fn(input_val.dynamic_range) + mark_as_int8_layer(layer, dyn_range) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py new file mode 100644 index 0000000000..79472fa2e7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -0,0 +1,108 @@ +from typing import Optional + + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable +from torch_tensorrt.fx.converters.converter_utils import ( + broadcast, + get_trt_tensor, + set_layer_name, +) +from torch_tensorrt.dynamo.conversion.impl.slice import expand + + +def where( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, + condition: TRTTensor, +) -> TRTTensor: + input_dim = len(tuple(input.shape)) + other_dim = len(tuple(other.shape)) + condition_dim = len(tuple(condition.shape)) + + if type(input) != TRTTensor: + assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!" + + if type(other) != TRTTensor: + assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!" + + if not (broadcastable(input, other)): + assert f"The two torch tensors should be broadcastable" + + # get output shape + # purpose of this is to bring input and other rank same as + # output_shape to input it to the add_expand operation + # condition will have dimension of either input or other + input, other = broadcast(network, input, other, f"{name}_x", f"{name}_y") + if len(tuple(condition.shape)) != len(tuple(input.shape)): + condition, input = broadcast( + network, condition, input, f"{name}_condition", f"{name}_x" + ) + + x_shape = list(input.shape) + y_shape = list(other.shape) + condition_shape = list(condition.shape) + output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) + + # expand shape + if type(condition) != TRTTensor: + assert condition.dtype == torch.bool, "condition dtype is not bool" + if condition_shape != output_shape: + condition.expand(output_shape) + condition = condition.to(torch.int32) + condition_const = get_trt_tensor(network, condition, f"{name}_condition") + condition_layer = network.add_identity(condition_const) + condition_layer.set_output_type(0, trt.bool) + set_layer_name(condition_layer, target, f"{name}_condition") + condition_val = condition_layer.get_output(0) + else: + assert condition.dtype == trt.bool, "mask dtype is not bool!" + if condition_shape != condition_dim: + condition_val = expand( + network, target, source_ir, f"{name}_expand", condition, output_shape + ) + else: + condition_val = condition + + if type(input) != TRTTensor: + if x_shape != input_dim: + # special case where 1 element in input + if len(input.shape) == 0: + input = input.unsqueeze(0) + input = input.expand(output_shape) + x_val = get_trt_tensor(network, input, f"{name}_x") + else: + x_val = input + if x_shape != output_shape: + x_val = expand( + network, target, source_ir, f"{name}_x_expand", input, output_shape + ) + + if type(other) != TRTTensor: + if y_shape != output_shape: + # special case where 1 element in other + if len(other.shape) == 0: + other = other.unsqueeze(0) + other = other.expand(output_shape) + y_val = get_trt_tensor(network, other, f"{name}_y") + else: + y_val = other + if y_shape != other_dim: + y_val = expand( + network, target, source_ir, f"{name}_y_expand", y_val, output_shape + ) + + select_layer = network.add_select(condition_val, x_val, y_val) + + set_layer_name(select_layer, target, f"{name}_select") + + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py new file mode 100644 index 0000000000..25d71e3702 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py @@ -0,0 +1,2 @@ +from .ops import * +from .clamp import clamp diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py new file mode 100644 index 0000000000..9b15ebd4c4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -0,0 +1,162 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, +) +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + broadcast, + squeeze_left, + get_trt_tensor, +) + + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + + +def convert_binary_elementwise( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + op_type: trt.ElementWiseOperation, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + We also promote the types of the two tensors to avoid dtype errors in TRT. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): Target of fx node. + source_ir (SourceIR): The IR that is calling the function. + name (str): The name we want to assign to the created TensorRT layer. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. + rhs_val (TRTTensor): Right operand of the binary operation. Similar + to lhs_val. + op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. + + Returns: + The output of TensorRT Elementwise layer. + """ + lhs_dtype = None + rhs_dtype = None + is_lhs_trt_tensor = False + is_rhs_trt_tensor = False + + if isinstance(lhs_val, TRTTensor): + lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) + is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) + is_rhs_trt_tensor = True + + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: + warnings.warn( + f"Both operands of the binary elementwise op {name} " + "are constant. In this case, please consider constant fold the model first." + ) + return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) + + # If the following conditions are true: + # 1. the network has implicit batch dimension, + # 2. one operand has shape [] (real shape is [batch_size]), + # 3. another operand is a scalar, + # then the result should also have shape [] (real shape is [batch_size]). + # + # In such case, we need to convert the scalar operand to tensor, because + # this way the shape will become [1], and then will be properly squeezed + # into [], meaning that the result will have shape [], which is what we + # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + + # When lhs is scalar, and rhs has shape [1,], then currently the assert + # will fail because lhs shape has fewer dimensions than rhs shape. This + # happens when using implicit batch dimension, when we removed the 1st + # dimension from input tensor, causing it to have shape [] - a scalar. We + # fix it by reducing the rhs constant with a squeeze_left, so it becomes a + # scalar too. More generally, we squeeze_left on input if it's a constant + # tensor. This is safe because broadcast will pad dimensions on the left + # (prepend) to make lhs and rhs shape compatible. + if network.has_implicit_batch_dimension: + if isinstance(lhs_val, torch.Tensor): + lhs_val = squeeze_left(lhs_val) + if isinstance(rhs_val, torch.Tensor): + rhs_val = squeeze_left(rhs_val) + + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + + promoted_type = torch.promote_types( + unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH), + unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH), + ) + trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) + + if trt_promoted_type != lhs_val.dtype: + lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name) + if trt_promoted_type != rhs_val.dtype: + rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name) + + # Check the limitation in the doc string. + if network.has_implicit_batch_dimension: + if is_lhs_trt_tensor and not is_rhs_trt_tensor: + assert len(lhs_val.shape) >= len( + rhs_val.shape + ), f"{lhs_val.shape} >= {rhs_val.shape}" + elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + assert len(rhs_val.shape) >= len( + lhs_val.shape + ), f"{rhs_val.shape} >= {lhs_val.shape}" + + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py new file mode 100644 index 0000000000..59e1b0f723 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py @@ -0,0 +1,78 @@ +import numpy as np +from typing import Optional +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + squeeze_left, + get_trt_tensor, +) + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +def add_clamp(network, input, val, op, name): + if not len(input.shape): + # clamping scalar + acc_ops_clamp_trt = get_trt_tensor( + network, + squeeze_left( + np.array( + [val], dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY) + ) + ), + f"{name}_clamp_{val}", + ) + else: + acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions + acc_ops_clamp_tensor = np.full( + acc_ops_clamp_shape, + val, + dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), + ) + acc_ops_clamp_trt = network.add_constant( + acc_ops_clamp_shape, acc_ops_clamp_tensor + ).get_output(0) + layer = network.add_elementwise(input, acc_ops_clamp_trt, op) + return layer + + +def clamp( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val, + min_val=None, + max_val=None, +) -> TRTTensor: + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"Clamp received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if min_val is not None: + clamp_min_layer = add_clamp( + network, input_val, min_val, trt.ElementWiseOperation.MAX, name + ) + set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") + input_val = clamp_min_layer.get_output(0) + if max_val is not None: + clamp_max_layer = add_clamp( + network, input_val, max_val, trt.ElementWiseOperation.MIN, name + ) + set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") + input_val = clamp_max_layer.get_output(0) + + return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py new file mode 100644 index 0000000000..089fcf223c --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -0,0 +1,177 @@ +from typing import Any, Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.dynamo.conversion.impl.unary import sign + + +def trunc_div( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform trunc divide on Tensor, result of divide will be round toward zero. + This means for positive number, it will be floor round; for negative number, + it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. + + Args: + network: INetworkDefinition. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: divisor. + other: dividend. + + Returns: + A TensorRT tensor represent the result of trunc divide. + """ + prod_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_prod", + trt.ElementWiseOperation.PROD, + input, + other, + ) + + sign_output = sign( + network, + target, + source_ir, + name, + prod_output, + ) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, + other, + f"{name}_other", + dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + ) + + abs_input_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_input", + trt.UnaryOperation.ABS, + input, + ) + abs_other_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_other", + trt.UnaryOperation.ABS, + other, + ) + abs_floor_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + abs_input_output, + abs_other_output, + ) + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.PROD, + abs_floor_output, + sign_output, + ) + + return output + + +def rsqrt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + + sqrt_trt_output = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + input, + ) + + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.DIV, + 1, + sqrt_trt_output, + ) + + return output + + +def fmod( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it + trunc_div_value = trunc_div( + network, + target, + source_ir, + name + "_trunc_div", + input, + other, + ) + prod_value = convert_binary_elementwise( + network, + target, + source_ir, + name + "_prod", + trt.ElementWiseOperation.PROD, + trunc_div_value, + other, + ) + sub_value = convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name + "_sub", + trt.ElementWiseOperation.SUB, + input, + prod_value, + ) + return sub_value diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py new file mode 100644 index 0000000000..a68d2455ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -0,0 +1,73 @@ +import operator +import warnings +from typing import Optional, cast, Any + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, +) + +from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor + + +def embedding( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + weight: TRTTensor, + max_norm: None, + norm_type: None, + scale_grad_by_freq: bool, + sparse: bool, +) -> TRTTensor: + + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `embedding` function should be called with explicit batch dimension." + ) + + indices_tensor = input + embedding_tensor = weight + if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64: + raise RuntimeError( + "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT." + ) + indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor") + embedding_tensor = get_trt_tensor( + network, embedding_tensor, f"{name}_embedding_tensor" + ) + # unsupported parameters + # ignore padding_idx since it is meaningful for training only + + if max_norm is not None: + raise RuntimeError( + f"Currently we don't support specifying max_norm, got {max_norm}." + ) + + if norm_type is not None and norm_type != 2.0: + raise RuntimeError( + f"Currently we don't support specifying max_norm, got {norm_type} for norm_type." + ) + + if scale_grad_by_freq: + raise RuntimeError( + "Currently we don't support scale gradient by word frequency." + ) + + if sparse: + raise RuntimeError("Currently we don't support sparse gradient.") + + # Implement embedding lookup with gather layer + gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0) + set_layer_name(gather_layer, target, name + "_gather") + return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py new file mode 100644 index 0000000000..846f4ab2ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -0,0 +1,54 @@ +from typing import Optional + + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_tensor, + broadcast, + set_layer_name, +) + + +def matrix_multiply( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, + other, + f"{name}_other", + dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + ) + + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input, other = broadcast( + network, input, other, f"{name}_input", f"{name}_other", preset_diff + ) + layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py new file mode 100644 index 0000000000..9d193fdf92 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -0,0 +1,313 @@ +from typing import cast, Union, Any, Optional, Sequence + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +import logging + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import get_dynamic_dims +from torch_tensorrt.dynamo.conversion import SourceIR + +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_plugin, + set_layer_name, + to_numpy, + has_dynamic_shape, + get_positive_dim, +) + +from torch_tensorrt.dynamo.conversion.impl.unary.base import ( + convert_unary, +) + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def batch_norm( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + weight: torch.Tensor, + bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + training: torch.Tensor, + momentum: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"BatchNorm2d received input {input} that is not part " + "of the TensorRT region!" + ) + + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." + + scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, weight))) / np.sqrt( + cast(torch.Tensor, to_numpy(cast(torch.Tensor, running_var))) + cast(float, eps) + ) + + bias = ( + to_numpy(cast(torch.Tensor, bias)) + - to_numpy(cast(torch.Tensor, running_mean)) * scale + ) + power = np.ones_like(scale) + + # For BatchNorm1d, reshape 1d to 2d + output_shape = input.shape + if not network.has_implicit_batch_dimension and len(input.shape) < 4: + assert ( + len(get_dynamic_dims(input.shape)) <= 1 + ), "BatchNorm1D with more than one dynamic dims is not currently supported." + reshape_layer = network.add_shuffle(input) + if len(input.shape) == 2: + reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1) + else: # len(input_val.shape) == 3 + reshape_layer.reshape_dims = ( + input.shape[0], + input.shape[1], + input.shape[2], + 1, + ) + set_layer_name(reshape_layer, target, f"{name}_reshape_2d") + input = reshape_layer.get_output(0) + layer = network.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) + set_layer_name(layer, target, name) + + # For BatchNorm1d, reshape output back to 1d + if not network.has_implicit_batch_dimension and len(output_shape) < 4: + reshape_output_layer = network.add_shuffle(layer.get_output(0)) + reshape_output_layer.reshape_dims = tuple(output_shape) + set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d") + layer = reshape_output_layer + return layer.get_output(0) + + +def layer_norm( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + normalized_shape: list, + weight: torch.Tensor, + bias: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if not isinstance(input, trt.tensorrt.ITensor): + raise RuntimeError( + f"LayerNorm received input {input} that is not part " + "of the TensorRT region!" + ) + + gamma = weight.detach().cpu().float().numpy() + gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) + beta = bias.detach().cpu().float().numpy() + beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) + eps_field = trt.PluginField( + "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 + ) + try: + normalized_shape = np.array(normalized_shape, dtype=np.int32) + except TypeError: + _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") + normalized_shape = np.array([], dtype=np.int32) + + normalized_shape_filed = trt.PluginField( + "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 + ) + field_collection = trt.PluginFieldCollection( + [gamma_field, beta_field, eps_field, normalized_shape_filed] + ) + + try: + if network.has_implicit_batch_dimension: + plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") + else: + plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") + except AssertionError: + _LOGGER.error( + "Unable to find layer norm plugin, fall back to TensorRT implementation." + ) + return layer_norm_no_plugin( + network, target, source_ir, name, input, normalized_shape, weight, bias, eps + ) + layer = network.add_plugin_v2([input], plugin) + layer.name = name + return layer.get_output(0) + + +def layer_norm_no_plugin( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + normalized_shape: list, + weight: torch.Tensor, + bias: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"LayerNorm received input {input} that is not part " + "of the TensorRT region!" + ) + + shape = weight.shape # type: ignore[union-attr] + broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape + gamma = to_numpy(weight.reshape(*shape)) # type: ignore[union-attr] + beta = to_numpy(bias.reshape(*shape)) # type: ignore[union-attr] + + axes = 0 + for d in range(len(shape)): + axes |= 1 << (len(input.shape) - d - 1) + + # E[x] + mean_expected_layer = network.add_reduce( + input, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") + + # X-E[x] + sub_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sub", + trt.ElementWiseOperation.SUB, + input, + mean_expected_layer.get_output(0), + ) + # Variance = mean(pow(x_sub_mean,2)) + pow_tensor = network.add_constant( + (1,) * len(input.shape), + trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + ) + pow_tensor.name = f"{name}_power" + pow_var = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_pow_var", + trt.ElementWiseOperation.POW, + sub_trt, + pow_tensor.get_output(0), + ) + mean_trt_layer = network.add_reduce( + pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_trt_layer, target, f"{name}_mean") + # Variance + eps + eps_tensor = network.add_constant( + (1,) * len(input.shape), + trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + ) + eps_tensor.name = f"{name}_eps" + add_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_add", + trt.ElementWiseOperation.SUM, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), + ) + # SQRT((Var + eps)) + sqrt_trt = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + add_trt, + ) + # (x - E[x]) / sqrt((var + eps)) + div_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_div_trt", + trt.ElementWiseOperation.DIV, + sub_trt, + sqrt_trt, + ) + + assert gamma is not None + gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] + gamma_tensor.name = f"{name}_gamma" + assert beta is not None + beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] + beta_tensor.name = f"{name}_beta" + # y * gamma + beta + scale_layer = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_scale", + trt.ElementWiseOperation.PROD, + div_trt, + gamma_tensor.get_output(0), + ) + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.SUM, + scale_layer, + beta_tensor.get_output(0), + ) + + +def softmax( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[Any] = None, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"softmax received input {input} that is not part " + "of the TensorRT region!" + ) + + # Used to get dim when dim is None. Copied from PyTorch softmax implementation. + def get_softmax_dim(ndim: int) -> int: + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + if dim is None: + dim = get_softmax_dim(input_ranks) + else: + dim = cast(int, dim) + + dim = get_positive_dim(dim, input_ranks) + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." + dim -= 1 + + layer = network.add_softmax(input) + layer.axes = 1 << dim + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py new file mode 100644 index 0000000000..492e35ba97 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -0,0 +1,34 @@ +from typing import Optional, Sequence, cast + + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + get_positive_dim, +) + + +def permute( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + permutation: Sequence[int], +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"permute received input {input} that is not a TensorRT ITensor" + ) + + permutation = [ + get_positive_dim(i, len(input.shape)) for i in cast(Sequence[int], permutation) + ] + + layer = network.add_shuffle(input) + layer.second_transpose = tuple(permutation) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py new file mode 100644 index 0000000000..26ad175104 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -0,0 +1,64 @@ +from typing import Optional, cast + +import numpy as np +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + has_dynamic_shape, + to_numpy, +) +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape + + +def select( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Shape, + index: Shape, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, dim), ranks) + dynamic_shape = has_dynamic_shape(input.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't select on negative shape dimension!" + index = index + + if index >= input.shape[dim]: + raise RuntimeError( + f"cannot have index greater than the dimension length! {input.shape[dim]}" + ) + output_shape = list(input.shape) + output_shape[dim] = 1 + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, target, source_ir, name, output_shape, input + ) + index_value = np.array(index, dtype=np.int32) + indices_tensor = network.add_constant( + index_value.shape, to_numpy(index_value) + ).get_output(0) + layer = network.add_gather(input, indices_tensor, dim) + out = layer.get_output(0) + if len(out.shape) != 1: + layer = network.add_shuffle(out) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py new file mode 100644 index 0000000000..7f122f5646 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -0,0 +1,77 @@ +from typing import Union + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + to_numpy, +) + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def get_shape_with_dynamic_shape( + network: TRTNetwork, + target: Target, + source_ir: SourceIR, + name: str, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + Returns: + TensorRT ITensors that represents the actual shape of the input_val + """ + # Ger real shape info for input_val + input_shape = network.add_shape(input_val).get_output(0) + + scale_layer = network.add_constant( + input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + ) + set_layer_name(scale_layer, target, f"{name}_scale") + scale_res = scale_layer.get_output(0) + + length = input_shape.shape[0] + zero_layer = network.add_constant( + input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + ) + set_layer_name(zero_layer, target, f"{name}_zeros") + + condition_val = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_shape", + trt.ElementWiseOperation.LESS, + scale_res, + zero_layer.get_output(0), + ) + select_layer = network.add_select(condition_val, input_shape, scale_res) + set_layer_name(select_layer, target, f"{name}_select") + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py new file mode 100644 index 0000000000..97cc0d1404 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py @@ -0,0 +1,39 @@ +from typing import Optional + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) + +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape + + +def slice( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + start: Shape, + shape: Shape, + stride: Shape, +) -> TRTTensor: + dynamic_shape = has_dynamic_shape(input.shape) + if dynamic_shape: + shape = get_shape_with_dynamic_shape( + network, target, source_ir, name, shape, input + ) + layer = network.add_slice( + input, + start=start, + shape=[] if dynamic_shape else shape, + stride=stride, + ) + if dynamic_shape: + layer.set_input(2, shape) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py new file mode 100644 index 0000000000..848e13ba4b --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -0,0 +1,96 @@ +from typing import Optional, cast +import math + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + has_dynamic_shape, + broadcast, + get_trt_tensor, +) +from torch_tensorrt.dynamo.conversion.impl.slice.base import slice + + +def slice_op( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + start: int, + stop: int, + step: int, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, dim), ranks) + dynamic_shape = has_dynamic_shape(input.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + start_int = cast(int, start) + stop_int = cast(int, stop) + if stop_int == 2**63 - 1: + stop_int = input.shape[dim] + step_int = cast(int, step) + start = [0] * len(input.shape) + start[dim] = start_int + stride = [1] * len(start) + stride[dim] = step_int + output_shape = list(input.shape) + output_shape[dim] = math.ceil((stop_int - start_int) / step_int) + + return slice(network, target, source_ir, name, input, start, output_shape, stride) + + +def expand( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + sizes: Shape, +) -> TRTTensor: + shape = list(sizes) + + input_val = get_trt_tensor(network, input, f"{name}_input") + + if network.has_implicit_batch_dimension: + shape = shape[1:] + + ranks = len(input_val.shape) + # TRT does not support different dimension size + # though this condition is not seen in the case of bmm + # where input_t and shape dimensions are not equal + assert len(shape) >= ranks + if len(shape) != ranks: + shape_tuple = tuple([0] * len(shape)) + shape_tensor = get_trt_tensor(network, input, f"{name}_shape") + input_val, shape_tensor = broadcast( + network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val" + ) + ranks = len(shape) + + inshape = tuple(input_val.shape) + shape = tuple(shape) + start = tuple([0] * ranks) + stride = tuple( + [int(i == o) for i, o in zip(inshape, shape)] + ) # stride == 1 if dimensions match, 0 otherwise + return slice(network, target, source_ir, name, input_val, start, shape, stride) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py new file mode 100644 index 0000000000..4c5ad200ad --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py @@ -0,0 +1,63 @@ +from typing import Optional, cast, Any + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + set_layer_name, +) + +from torch_tensorrt.fx.utils import get_dynamic_dims + + +def squeeze( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[Any] = None, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"squeeze received input {input} that is not part " + "of the TensorRT region!" + ) + dims = [] + if dim is not None: + if isinstance(dim, int): + dims.append(cast(Optional[int], dim)) + else: + for dim in dim: + dims.append(cast(Optional[int], dim)) + + # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic + # dim, which is a very rare case. For now we just claim not supporting dim=None. + assert not (len(dims) == 0), "We don't support dim=None right now for squeeze." + + for dim in dims: + dim = cast(Optional[int], dim) + dim = get_positive_dim( + dim, + len(input.shape) + (1 if network.has_implicit_batch_dimension else 0), + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 + + assert input.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." + + output_shape = [] + for i, s in enumerate(input.shape): + if (i in dims) and s == 1: + continue + output_shape.append(s) + layer = network.add_shuffle(input) + layer.reshape_dims = tuple(output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py new file mode 100644 index 0000000000..0ee1185850 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py @@ -0,0 +1,44 @@ +from typing import Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import set_layer_name + + +def convert_unary( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.UnaryOperation, + input_val: TRTTensor, +) -> TRTTensor: + """ + Add a TensorRT Unary layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Unary layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py new file mode 100644 index 0000000000..e0a255f800 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -0,0 +1,98 @@ +from typing import Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.dynamo.conversion import SourceIR + + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary + + +def sign( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Sign is calculated as below: + x = input + sign = (exp(x) // exp(abs(x))) * 2 - 1 + For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. + With multiply 2, the value become 2(for pos and 0) and 0(for neg). + Finally minus 1, the value become 1(for pos and 0) and -1(for neg). + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + A TensorRT tensor represent the result of sign operator. + """ + input_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_exp", + trt.UnaryOperation.EXP, + input_val, + ) + input_abs_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs", + trt.UnaryOperation.ABS, + input_val, + ) + input_abs_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs_exp", + trt.UnaryOperation.EXP, + input_abs_output, + ) + + floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_exp_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + input_exp_output, + input_abs_exp_output, + ) + + double_floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div*2", + trt.ElementWiseOperation.PROD, + floor_div_output, + 2, + ) + + return convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sign", + trt.ElementWiseOperation.SUB, + double_floor_div_output, + 1, + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py new file mode 100644 index 0000000000..d1559ef324 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -0,0 +1,52 @@ +from typing import Optional, cast + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + get_trt_tensor, + set_layer_name, +) + +from torch_tensorrt.fx.utils import get_dynamic_dims + + +def unsqueeze( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_t, + dim, +) -> TRTTensor: + input_val = get_trt_tensor(network, input_t, f"{name}_input_t") + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"unsqueeze received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dim = cast(int, dim) + input_shape = input_val.shape + input_shape_size = ( + len(input_val.shape) + 1 + if network.has_implicit_batch_dimension + else len(input_val.shape) + ) + dim = get_positive_dim(dim, input_shape_size + 1) + + if network.has_implicit_batch_dimension: + assert dim != 0 + dim -= 1 + + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently we don't support unsqueeze with more than one dynamic dims." + layer = network.add_shuffle(input_val) + layer.reshape_dims = ( + tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:] + ) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py index 6b72d87ff6..4293fb65eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py @@ -13,7 +13,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata -from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS +from .converter_registry import DYNAMO_CONVERTERS as CONVERTERS from torch_tensorrt import Input from torch_tensorrt.fx.observer import Observer from torch_tensorrt.fx.utils import ( @@ -64,7 +64,15 @@ def __init__( + "\n".join(f"{i}" for i in missing_ops) ) - self.optimization_profiles: Optional[List] = None + self.optimization_profiles = ( + [self.builder.create_optimization_profile()] + if any( + input_spec.shape_mode == Input._ShapeMode.DYNAMIC + for input_spec in input_specs + ) + else None + ) + self.input_specs = input_specs self.input_specs_iter = 0 self._cur_node_name: Optional[str] = None @@ -257,7 +265,7 @@ def placeholder(self, target, args, kwargs): opt_shape = current_input.shape["opt_shape"] max_shape = current_input.shape["max_shape"] self.optimization_profiles[0].set_shape( - target, [min_shape, opt_shape, max_shape] + target, min_shape, opt_shape, max_shape ) assert len(min_shape) == len(opt_shape) == len(max_shape) for i in range(len(min_shape)): diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 7021b55518..d56a3a8616 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -46,6 +46,16 @@ def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: return torch.reciprocal(torch.sqrt(*args, **kwargs)) +@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS) +def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return torch.reshape(x, *args, **kwargs) + + +@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS) +def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + @register_decomposition(aten.alias, registry=DECOMPOSITIONS) def alias_replacement(x: torch.Tensor) -> torch.Tensor: return x @@ -60,5 +70,12 @@ def addmm_replacement( ) +@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS) +def reciprocal_replacement( + input_: torch.Tensor, +) -> torch.Tensor: + return torch.div(1, input_) + + def get_decompositions(): return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/test_utils.py b/py/torch_tensorrt/dynamo/test_utils.py new file mode 100644 index 0000000000..a3d742c70a --- /dev/null +++ b/py/torch_tensorrt/dynamo/test_utils.py @@ -0,0 +1,310 @@ +import time +import unittest +import torch +import logging +from typing import Callable, List, Optional, Set, Tuple +from torch.testing._internal.common_utils import TestCase + +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( + compose_bmm, + compose_chunk, + compose_getitem_slice, + remove_ops, + replace_aten_op_with_indices, + replace_aten_reshape_alias_with_replace, + replace_builtin_ops, + replace_native_layernorm_with_layernorm, + replace_transpose_mm_op_with_linear, + run_const_fold, +) +from torch.fx.passes.infra.pass_base import PassResult +from torch_tensorrt.fx.passes.pass_utils import chain_passes + +# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry +from torch_tensorrt.dynamo.conversion.trt_interpreter import TRTInterpreter +from torch_tensorrt.dynamo.runtime._PythonTorchTRTModule import PythonTorchTRTModule +from torch_tensorrt import Input + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def fetch_attr(mod, target): + """ + Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. + + Args: + target (str): The fully-qualfiied name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") +class TRTTestCase(TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(3) + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops, + interpreter, + rtol, + atol, + precision=torch.float, + check_dtype=True, + ): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(mod, unexpected_ops) + start = time.perf_counter() + interpreter_result = interpreter.run(precision=precision) + sec = time.perf_counter() - start + _LOGGER.info(f"Interpreter run time(s): {sec}") + trt_mod = PythonTorchTRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + + ref_outputs = mod(*inputs) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + outputs = trt_mod(*cuda_inputs) + end_event.record() + torch.cuda.synchronize() + _LOGGER.info( + f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" + ) + + if type(outputs) not in (list, tuple): + outputs = [outputs] + if type(ref_outputs) not in ( + list, + tuple, + torch.return_types.max, + torch.return_types.min, + ): + ref_outputs = [ref_outputs] + for out, ref in zip(outputs, ref_outputs): + if not isinstance(ref, torch.Tensor): + ref = torch.tensor([ref]) + ref = ref.cpu() # to_dtype test has cases with gpu output + if ref.dtype == torch.int64: + ref = ref.int() # convert torch.max's index output tensor to int32 + torch.testing.assert_close( + out.cpu(), + ref, + rtol=rtol, + atol=atol, + equal_nan=True, + check_dtype=check_dtype, + ) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + """ + Runs the test and compares the result using the provided comparators. + The size of comparators must be equal to the number of outputs from 'mod'. + + mod - a model to run. + inputs - a list of the model inputs. + expected ops - a list of ops that should be verified. + interpreter - used for converting the model to TRT. + comparators - a list of (func, args) pairs corresponding to each of + the module outputs. usage: func(x, y, *args) + + """ + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + + interpreter_result = interpreter.run( + precision=torch.half if fp16_mode else torch.float + ) + trt_mod = PythonTorchTRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + res_trt = trt_mod(*cuda_inputs).cpu() + res_cpu = mod(*inputs) + assert len(res_trt) == len(res_cpu) + assert len(res_cpu) == len(comparators) + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) + + def run_test_with_error(self, mod, inputs, interpreter, expect_error): + with self.assertRaises(expect_error): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + interpreter.run(precision=torch.float) + + def assert_has_op(self, mod, ops): + ops_in_mod = set() + + for node in mod.graph.nodes: + if node.op == "call_module": + ops_in_mod.add(type(fetch_attr(mod, node.target))) + elif node.op in {"call_function", "call_method"}: + ops_in_mod.add(node.target) + + self.assertTrue( + ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" + ) + + def assert_unexpected_op(self, mod, ops): + for node in mod.graph.nodes: + if node.op == "call_module": + if type(fetch_attr(mod, node.target)) in ops: + return False + elif node.op in {"call_function", "call_method"}: + if node.target in ops: + return False + return True + + +class DispatchTestCase(TRTTestCase): + def generate_graph( + self, + mod: torch.nn.Module, + original_inputs: List[torch.Tensor], + expected_ops: Set[Callable], + unexpected_ops: Optional[Set[Callable]] = None, + customized_passes: List[Callable] = None, + ): + # Torchdynamo+aot proxytensor tracer + # Below are common passes + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + ] + # Combine with customized passes specific to any model + if customized_passes: + passes_list.extend(customized_passes) + fx_module, _ = aten_tracer.trace(mod, original_inputs) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + fx_module(*original_inputs) + + fx_module = run_const_fold(fx_module) + _LOGGER.info(f"FX graph= {fx_module.graph}") + + if len(expected_ops): + self.assert_has_op(fx_module, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(fx_module, unexpected_ops) + + return fx_module + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + rtol=1e-03, + atol=1e-03, + precision=torch.float, + check_dtype=True, + ): + mod.eval() + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + interp = TRTInterpreter( + mod, + Input.from_tensors(inputs), + ) + super().run_test( + mod, + inputs, + expected_ops, + unexpected_ops, + interp, + rtol, + atol, + precision, + check_dtype, + ) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = [spec.example_tensor("opt_shape") for spec in input_specs] + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + interp = TRTInterpreter( + mod, + input_specs, + ) + # Since the lowering is based on optimal shape. We need to test with + # different shape(for ex. max shape) for testing dynamic shape + inputs_max = [spec.example_tensor("max_shape") for spec in input_specs] + super().run_test( + mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol + ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 282fcdbfd2..fc9ba5a232 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -222,7 +222,7 @@ def test_int64_input_partial_support(self): class PartiallySupportedMultiOp(torch.nn.Module): def forward(self, x, y): return torch.ops.aten.div.Tensor_mode( - x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor" + x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None ) fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) diff --git a/tests/py/dynamo/backend/test_decompositions.py b/tests/py/dynamo/backend/test_decompositions.py index a9578e4ed8..0e11bfd2b1 100644 --- a/tests/py/dynamo/backend/test_decompositions.py +++ b/tests/py/dynamo/backend/test_decompositions.py @@ -78,15 +78,14 @@ def forward(self, x): return y # Operations expected to be removed in the traced graph after decompositions - expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default} - unexpected_ops = {torch.ops.aten.rsqrt.default} + expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.div.Tensor} + unexpected_ops = { + torch.ops.aten.rsqrt.default, + torch.ops.aten.reciprocal.default, + } inputs = [ - torch.randint( - 1, - 10, - (5,), - ), + torch.randint(1, 10, (5,), dtype=torch.int32), ] fx_graph = torch.fx.symbolic_trace(Rsqrt()) @@ -182,6 +181,69 @@ def forward(self, x, y, z): f"AddMM TRT outputs don't match with the original model.", ) + def test_lowering_reciprocal(self): + class Reciprocal(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + y = torch.ops.aten.reciprocal.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.div.Tensor} + unexpected_ops = {torch.ops.aten.reciprocal.default} + + inputs = [ + torch.randn( + 5, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(Reciprocal()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Reciprocal TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py new file mode 100644 index 0000000000..68ce24c20f --- /dev/null +++ b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py @@ -0,0 +1,127 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestAdaptiveAvgPoolConverter(DispatchTestCase): + def test_adaptive_avgpool_mean(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.mean.dim}, + ) + + @parameterized.expand( + [ + ((64, 64),), + ((128, 64),), + (64,), + ] + ) + def test_adaptive_avgpool( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + def test_adaptive_avgpool_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + Input( + shape=(-1, -1, 256, 256), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + @parameterized.expand( + [ + ((16, 16, 16),), + ((32, 16, 4),), + (32,), + ] + ) + def test_adaptive_avgpool3d( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 32, 64, 64)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + def test_adaptive_avgpool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + Input( + shape=(-1, -1, 32, 64, 64), + dtype=torch.float32, + shape_ranges=[ + ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) + ], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_batchnorm_aten.py b/tests/py/dynamo/converters/test_batchnorm_aten.py new file mode 100644 index 0000000000..c39f14abfe --- /dev/null +++ b/tests/py/dynamo/converters/test_batchnorm_aten.py @@ -0,0 +1,66 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestBatchNormConverter(DispatchTestCase): + def test_batchnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm}) + + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + Input( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + def test_batchnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_binary_ops_aten.py b/tests/py/dynamo/converters/test_binary_ops_aten.py new file mode 100644 index 0000000000..19fa02721c --- /dev/null +++ b/tests/py/dynamo/converters/test_binary_ops_aten.py @@ -0,0 +1,263 @@ +from typing import Callable +import unittest + +import torch +import torch.nn as nn + +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + +NEED_TEST_BOTH_CONSTANTS_CASE = True + +elementwise_ops = [ + ((lambda x, y: x + y), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: torch.add(x, y)), + torch.ops.aten.add.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x.add(y)), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x - y), torch.ops.aten.sub.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x.sub(y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x / y), torch.ops.aten.div.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: x // y), + torch.ops.aten.floor_divide.default, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="trunc")), + torch.ops.aten.div.Tensor_mode, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="floor")), + torch.ops.aten.div.Tensor_mode, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y)), + torch.ops.aten.div.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.fmod(x, y)), + torch.ops.aten.fmod.Tensor, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ## torch.floor_divide rounds result toward zero, rather than -Inf. + ## https://github.com/pytorch/pytorch/issues/43874 + ( + (lambda x, y: torch.floor_divide(x, y)), + torch.ops.aten.floor_divide.default, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x * y), torch.ops.aten.mul.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + (torch.pow, torch.ops.aten.pow.Tensor_Tensor, not NEED_TEST_BOTH_CONSTANTS_CASE), +] + + +class TestBinaryOpConverters(DispatchTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x) + + m = TestModule(orig_op) + # Avoid dividing by 0. + inputs = [torch.rand(1, 1) + 1] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + @unittest.skip("Pending reimplementation of all binary converters in Dynamo") + def test_elementwise_ops_mismatched_dtypes( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x, y): + return self.orig_op(x, y) + + m = TestModule(orig_op) + # Avoid dividing by 0. + inputs = [ + 2 * torch.rand(1, 1, dtype=torch.float) + 1, + torch.randint(1, 3, (1, 1), dtype=torch.int), + ] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops_with_one_constant( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] + ) + def test_elementwise_op_with_both_constants( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant0 = torch.nn.Parameter(torch.randn(1)) + self.constant1 = torch.nn.Parameter(torch.randn(1)) + self.orig_op = orig_op + + def forward(self, x): + const = self.orig_op(self.constant0, self.constant1) + return self.orig_op(x, const) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)]) + def test_elementwise_op_div_with_two_ints(self, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x + 1) + + m = TestModule(orig_op) + inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)]) + def test_elementwise_op_div_with_one_int_one_constant( + self, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant1 = torch.nn.Parameter( + torch.randn( + 5, + ) + ) + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, self.constant1) + + m = TestModule(orig_op) + inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)] + self.run_test(m, inputs, expected_ops={expected_op}) + + # Dynamic shape test + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + (-1, -1, -1), + ((1, 1, 1), (2, 2, 2), (3, 3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape( + self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + Input( + shape=x_shape, + dtype=torch.float32, + shape_ranges=[x_shape_ranges], + ), + Input( + shape=y_shape, + dtype=torch.float32, + shape_ranges=[y_shape_ranges], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_cat_aten.py b/tests/py/dynamo/converters/test_cat_aten.py new file mode 100644 index 0000000000..d9d107de89 --- /dev/null +++ b/tests/py/dynamo/converters/test_cat_aten.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestCatConverter(DispatchTestCase): + @parameterized.expand( + [ + ("pos", 1), + ("neg", -2), + ] + ) + def test_cat(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z), dim) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + @parameterized.expand( + [ + ("pos", 1), + ("neg", -2), + ] + ) + def test_cat_dynamic_shape(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y), dim) + + input_specs = [ + Input( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))], + ), + Input( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z)) + + inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_dynamic_shape_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y)) + + input_specs = [ + Input( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + Input( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_clamp_aten.py b/tests/py/dynamo/converters/test_clamp_aten.py new file mode 100644 index 0000000000..05716c1657 --- /dev/null +++ b/tests/py/dynamo/converters/test_clamp_aten.py @@ -0,0 +1,71 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestClampConverter(DispatchTestCase): + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + param("float32Boundary", min=-3.4028234663852886e38), + ] + ) + def test_clamp( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + inputs = [torch.randn(3, 4)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.clamp.default}) + + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + ] + ) + def test_clamp_with_dynamic_shape_four_dimensions( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + class TestScalarModule(torch.nn.Module): + def forward(self, x): + y = torch.mean(x) + return torch.clamp(y, min, max) + + input_specs = [ + Input( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + self.run_test_with_dynamic_shape( + TestScalarModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_convolution_aten.py b/tests/py/dynamo/converters/test_convolution_aten.py new file mode 100644 index 0000000000..a906d70d43 --- /dev/null +++ b/tests/py/dynamo/converters/test_convolution_aten.py @@ -0,0 +1,203 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestConvolutionConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.convolution.default}, + ) + + def test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, + 6, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + ## TODO TRT 8.4.1 will trigger issue with this test. T127981773 + # param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_elu_aten.py b/tests/py/dynamo/converters/test_elu_aten.py new file mode 100644 index 0000000000..dfaf2db5a6 --- /dev/null +++ b/tests/py/dynamo/converters/test_elu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestELUConverter(DispatchTestCase): + def test_elu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_elu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_elu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_embedding_aten.py b/tests/py/dynamo/converters/test_embedding_aten.py new file mode 100644 index 0000000000..4d36478303 --- /dev/null +++ b/tests/py/dynamo/converters/test_embedding_aten.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from parameterized import param, parameterized +from torch_tensorrt import Input + + +class TestEmbeddingConverter(DispatchTestCase): + @parameterized.expand( + [ + param( + test_name="1d_indices", + indices_tensor=torch.tensor([3, 1, 2]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="2d_indices", + indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="3d_indices", + indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]), + weights_tensor=torch.randn(5, 10), + ), + ] + ) + def test_embedding( + self, + test_name, + indices_tensor, + weights_tensor, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, indices, weights): + return torch.nn.functional.embedding( + input=indices, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + self.run_test( + TestEmbedding(), + inputs=[indices_tensor.int(), weights_tensor.float()], + expected_ops={torch.ops.aten.embedding.default}, + ) + + def test_embedding_with_dynamic_shape_four_dimensions( + self, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, input, weights): + return torch.nn.functional.embedding( + input=input, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.int, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + Input( + shape=(-1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1), (2, 3), (2, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestEmbedding(), + input_specs, + expected_ops={torch.ops.aten.embedding.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_expand_aten.py b/tests/py/dynamo/converters/test_expand_aten.py new file mode 100644 index 0000000000..1b1f3d1c14 --- /dev/null +++ b/tests/py/dynamo/converters/test_expand_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestExpandConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (2, 3), (2, 1)), + ("3d_dim", (2, 3, 4), (2, 1, 1)), + ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), + ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), + ] + ) + def test_expand(self, _, sizes, init_size): + class Expand(nn.Module): + def forward(self, x): + return x.expand(*sizes) + + inputs = [torch.randn(*init_size)] + self.run_test( + Expand(), + inputs, + expected_ops={torch.ops.aten.expand.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_gelu_aten.py b/tests/py/dynamo/converters/test_gelu_aten.py new file mode 100644 index 0000000000..c62a028c0e --- /dev/null +++ b/tests/py/dynamo/converters/test_gelu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestGeLUConverter(DispatchTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default}) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_hardtanh_aten.py b/tests/py/dynamo/converters/test_hardtanh_aten.py new file mode 100644 index 0000000000..8401dd17a9 --- /dev/null +++ b/tests/py/dynamo/converters/test_hardtanh_aten.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestHardTanHConverter(DispatchTestCase): + def test_hardtanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_layer_norm_aten.py b/tests/py/dynamo/converters/test_layer_norm_aten.py new file mode 100644 index 0000000000..a4766bd030 --- /dev/null +++ b/tests/py/dynamo/converters/test_layer_norm_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestLayerNormConverter(DispatchTestCase): + def test_layer_norm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + def test_layernorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + input_specs = [ + Input( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_leaky_relu_aten.py b/tests/py/dynamo/converters/test_leaky_relu_aten.py new file mode 100644 index 0000000000..aa3d56641b --- /dev/null +++ b/tests/py/dynamo/converters/test_leaky_relu_aten.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestLeakyReLUConverter(DispatchTestCase): + def test_leaky_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_linear_aten.py b/tests/py/dynamo/converters/test_linear_aten.py new file mode 100644 index 0000000000..b9e3261642 --- /dev/null +++ b/tests/py/dynamo/converters/test_linear_aten.py @@ -0,0 +1,71 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestLinearConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", [1, 512], True, torch.ops.aten.linear), + ("matrix", [5, 512], True, torch.ops.aten.linear), + ("no_bias", [1, 512], False, torch.ops.aten.linear), + ( + "multi_dim_matrix", + [4, 5, 512], + True, + torch.ops.aten.linear, + ), + ( + "multi_dim_matrix", + [4, 5, 512], + False, + torch.ops.aten.linear, + ), + ] + ) + def test_linear(self, test_name, shape, bias, op): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias) + + def forward(self, x): + return self.linear(x) + + inputs = [torch.randn(shape)] + self.run_test(TestModule(), inputs, expected_ops={op}) + + # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern + # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below. + + # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now. + + # def test_linear_with_dynamic_shape(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.linear = torch.nn.Linear(512, 256) + + # def forward(self, x): + # return self.linear(x) + + # input_specs = [ + # Input( + # shape=(-1, 3, 512), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(), + # input_specs, + # expected_ops={torch.ops.aten.addmm.default}, + # ) + + ## Testing with (-1, -1, 512) results into following error: + ## AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_matmul_aten.py b/tests/py/dynamo/converters/test_matmul_aten.py new file mode 100644 index 0000000000..f01325fb10 --- /dev/null +++ b/tests/py/dynamo/converters/test_matmul_aten.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestMatMulConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("2_2", (2, 3), (3, 1)), + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), + ] + ) + def test_matmul_other_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.other = nn.Parameter(torch.randn(*other_shape)) + + def forward(self, input): + return torch.matmul(input, self.other) + + inputs = [torch.randn(*input_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("1_2", (1, 3), (3, 2)), + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), + ] + ) + def test_matmul_input_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.input = nn.Parameter(torch.randn(*input_shape)) + + def forward(self, other): + return torch.matmul(self.input, other) + + inputs = [torch.randn(*other_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + # ("2_3", (2, 3), (2, 3, 4)), + # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)), + # ("4_2", (2, 1, 2, 3), (3, 2)), + # ("2_1", (2, 3), (3,)), + # ("1_2", (3,), (3, 2)), + # ("1_1", (3,), (3,)), + ] + ) + def test_matmul(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def forward(self, input, other): + return torch.matmul(input, other) + + inputs = [torch.randn(*input_shape), torch.randn(*other_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + # FIXME: dynamic shape is giving bmm + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_mean_aten.py b/tests/py/dynamo/converters/test_mean_aten.py new file mode 100644 index 0000000000..fe31d90a24 --- /dev/null +++ b/tests/py/dynamo/converters/test_mean_aten.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestMeanDimConverter(DispatchTestCase): + def test_mean_dim_keepdims(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=[0, 1], keepdim=True) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim}) + + def test_mean_dim_keepdims_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=[0, 1, 2], keepdim=True) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim} + ) + + def test_mean_dim_keepdims_false(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=0, keepdim=False) + + inputs = [torch.randn(3, 5, 7)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim}) + + def test_mean_dim_keepdims_false_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=-1, keepdim=False) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim} + ) + + +class TestMeanConverter(DispatchTestCase): + def test_mean(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x) + + inputs = [torch.randn(3, 8, 5, 7, 1)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default}) + + def test_mean_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_permutation_aten.py b/tests/py/dynamo/converters/test_permutation_aten.py new file mode 100644 index 0000000000..f9d614ae68 --- /dev/null +++ b/tests/py/dynamo/converters/test_permutation_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestPermuteConverter(DispatchTestCase): + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute_list(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(*permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + def test_permute_with_dynamic_shape(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 0) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + def test_permute_with_dynamic_shape_four_dimensions(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 3, 0) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_relu_aten.py b/tests/py/dynamo/converters/test_relu_aten.py new file mode 100644 index 0000000000..08ab04014d --- /dev/null +++ b/tests/py/dynamo/converters/test_relu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestReLUConverter(DispatchTestCase): + def test_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.relu.default}) + + def test_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + def test_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_reshape_aten.py b/tests/py/dynamo/converters/test_reshape_aten.py new file mode 100644 index 0000000000..1df71abc1a --- /dev/null +++ b/tests/py/dynamo/converters/test_reshape_aten.py @@ -0,0 +1,103 @@ +import unittest + +import tensorrt as trt +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestReshapeConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 20),), + ((1, 10, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + inputs = [torch.randn(1, 2, 10)] + self.run_test( + TestModule(target_shape), + inputs, + expected_ops={torch.ops.aten.view.default}, + ) + + @parameterized.expand( + [ + ((-1, 10),), + ((-1, 5),), + ((2, 2, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + input_specs = [ + Input( + shape=(-1, 2, 5), + dtype=torch.float32, + shape_ranges=[((1, 2, 5), (10, 2, 5), (10, 2, 5))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(target_shape), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape_size(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + shape_y = y.shape + t = shape_y[1] + return torch.reshape(x, [-1, t, 3]) + + input_specs = [ + Input( + shape=(-1, 5, 6), + dtype=torch.float32, + shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 6))], + ), + Input( + shape=(-1, 5), + dtype=torch.float32, + shape_ranges=[((1, 5), (3, 5), (3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_rsqrt_aten.py b/tests/py/dynamo/converters/test_rsqrt_aten.py new file mode 100644 index 0000000000..5770e697fc --- /dev/null +++ b/tests/py/dynamo/converters/test_rsqrt_aten.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestRSqrtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops={torch.ops.aten.rsqrt.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_select_aten.py b/tests/py/dynamo/converters/test_select_aten.py new file mode 100644 index 0000000000..049cd9c7e6 --- /dev/null +++ b/tests/py/dynamo/converters/test_select_aten.py @@ -0,0 +1,79 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSelectConverterOne(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(1, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + ) + + +class TestSelectConverterTwo(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(4, 4, 4, 4)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + ) + + +class TestSelectConverterWithDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select_with_dynamic_shape(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input_spec = [ + Input( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_spec, expected_ops={torch.ops.aten.select.int} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_selu_aten.py b/tests/py/dynamo/converters/test_selu_aten.py new file mode 100644 index 0000000000..7fb6afda76 --- /dev/null +++ b/tests/py/dynamo/converters/test_selu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSeLUConverter(DispatchTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_selu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_sigmoid_aten.py b/tests/py/dynamo/converters/test_sigmoid_aten.py new file mode 100644 index 0000000000..37bbea1730 --- /dev/null +++ b/tests/py/dynamo/converters/test_sigmoid_aten.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSigmoidConverter(DispatchTestCase): + def test_sigmoid(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_fp16(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.sigmoid.default}, + precision=torch.half, + check_dtype=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_slice_aten.py b/tests/py/dynamo/converters/test_slice_aten.py new file mode 100644 index 0000000000..86de36d351 --- /dev/null +++ b/tests/py/dynamo/converters/test_slice_aten.py @@ -0,0 +1,86 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 0, 0, 7, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 2, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step_exact", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input_specs = [ + Input( + shape=(1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_softmax_aten.py b/tests/py/dynamo/converters/test_softmax_aten.py new file mode 100644 index 0000000000..8d33f3ebe0 --- /dev/null +++ b/tests/py/dynamo/converters/test_softmax_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSoftMaxConverter(DispatchTestCase): + def test_softmax(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(1) + + def forward(self, x): + return self.softmax(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default} + ) + + def test_softmax_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(2) + + def forward(self, x): + return self.softmax(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_squeeze_aten.py b/tests/py/dynamo/converters/test_squeeze_aten.py new file mode 100644 index 0000000000..152fe86300 --- /dev/null +++ b/tests/py/dynamo/converters/test_squeeze_aten.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (0), (2, 1)), + ("3d_one_dim", (0), (2, 2, 1)), + ("3d_two_dim", (0, 1), (2, 1, 1)), + ("4d_dim", (0, 1, 2), (2, 2, 1, 1)), + ] + ) + def test_squeeze(self, _, dim, init_size): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + inputs = [torch.randn(*init_size)] + expected_op = {} + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + self.run_test( + Squeeze(), + inputs, + expected_ops=expected_op, + ) + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), + ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), + # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), + ] + ) + def test_squeeze(self, _, dim, init_size, shape_range): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + input_specs = [ + Input( + shape=init_size, + dtype=torch.float32, + shape_ranges=shape_range, + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), + input_specs, + expected_ops=expected_op, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_tanh_aten.py b/tests/py/dynamo/converters/test_tanh_aten.py new file mode 100644 index 0000000000..f9aa94a7bc --- /dev/null +++ b/tests/py/dynamo/converters/test_tanh_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestTanhConverter(DispatchTestCase): + def test_tanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default}) + + def test_tanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + def test_tanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_unsqueeze_aten.py b/tests/py/dynamo/converters/test_unsqueeze_aten.py new file mode 100644 index 0000000000..db8ae7151f --- /dev/null +++ b/tests/py/dynamo/converters/test_unsqueeze_aten.py @@ -0,0 +1,62 @@ +import torch +import torch.fx +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestUnsqueeze(DispatchTestCase): + @parameterized.expand( + [ + ("negative_dim", -2), + ("positive_dim", 2), + ] + ) + def test_unsqueeze(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + inputs = [torch.randn(1, 2, 3)] + self.run_test( + Unsqueeze(dim), inputs, expected_ops={torch.ops.aten.unsqueeze.default} + ) + + # Testing with more than one dynamic dims results in following error: + # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. + + @parameterized.expand( + [ + ("negative_dim_dynamic", -4), + ("positive_dim_dynamic", 1), + ] + ) + def test_unsqueeze_with_dynamic_shape(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + input_specs = [ + Input( + shape=(-1, 2, 3), + dtype=torch.float32, + shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Unsqueeze(dim), input_specs, expected_ops={torch.ops.aten.unsqueeze.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_where_aten.py b/tests/py/dynamo/converters/test_where_aten.py new file mode 100644 index 0000000000..39ba0500b9 --- /dev/null +++ b/tests/py/dynamo/converters/test_where_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), + ] + ) + def test_(self, _, x_size, y_size): + class Where(nn.Module): + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 + self.run_test( + Where(), + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 1141d54a7b..0fdfcb3fd0 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -32,7 +32,7 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -66,7 +66,7 @@ def test_mobilenet_v2(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -100,7 +100,7 @@ def test_efficientnet_b0(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -143,7 +143,7 @@ def test_bert_base_uncased(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -181,7 +181,7 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", }