From e611e701173f53fc97b5a006333f2342bdab6add Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 9 May 2023 14:21:02 -0700 Subject: [PATCH 1/2] Converter reorg and adding rsqrt converter --- .../fx/converters/aten_ops_converters.py | 37 +++++++++++++++++++ .../fx/converters/impl/elementwise/ops.py | 30 +++++++++++++++ .../converters/aten_op/test_rsqrt_aten.py | 29 +++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 1f9ee6fa42..1fd42f4bce 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -24,6 +24,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch_tensorrt.fx.converters.impl import activation from torch_tensorrt.fx.converters.impl.elementwise import trunc_div +from torch_tensorrt.fx.converters.impl.elementwise import rsqrt _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -300,6 +301,42 @@ def aten_ops_relu( ) +@tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return activation.relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@tensorrt_converter(torch.ops.aten.rsqrt.default) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return rsqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py index ae44ce838c..67f2db4cf1 100644 --- a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -109,3 +109,33 @@ def trunc_div( ) return output + + +def rsqrt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + + sqrt_trt_output = convert_unary( + network, + target, + source_ir, + f"{name}"_sqrt, + trt.UnaryOperation.SQRT, + input, + ) + + output = convert_binary_elementwise( + network, + 1, + sqrt_trt_output, + trt.ElementWiseOperation.DIV, + target, + f"{name}_outpur", + ) + + return output \ No newline at end of file diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py new file mode 100644 index 0000000000..5effe38d8a --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops={torch.ops.aten.rsqrt.default}, + ) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file From e431d701458db6840aa732aa0ce04829f6312d0d Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 12 May 2023 00:07:25 -0700 Subject: [PATCH 2/2] Rsqrt and linting error --- .../fx/converters/aten_ops_converters.py | 8 ++++---- .../fx/converters/impl/elementwise/ops.py | 16 ++++++++-------- .../test/converters/aten_op/test_rsqrt_aten.py | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 1fd42f4bce..8c8492158f 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -327,11 +327,11 @@ def aten_ops_rsqrt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - + return rsqrt( - network, - target, - SourceIR.ATEN, + network, + target, + SourceIR.ATEN, name, args[0], ) diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py index 67f2db4cf1..8fddb426a6 100644 --- a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -117,25 +117,25 @@ def rsqrt( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - other: TRTTensor, ) -> TRTTensor: - + sqrt_trt_output = convert_unary( network, target, source_ir, - f"{name}"_sqrt, + f"{name}_sqrt", trt.UnaryOperation.SQRT, input, ) output = convert_binary_elementwise( network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.DIV, 1, sqrt_trt_output, - trt.ElementWiseOperation.DIV, - target, - f"{name}_outpur", ) - - return output \ No newline at end of file + + return output diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py index 5effe38d8a..3fa27af1a0 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -5,7 +5,7 @@ from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec -class TestRSubConverter(DispatchTestCase): +class TestRSqrtConverter(DispatchTestCase): @parameterized.expand( [ ("2d_dim_alpha", (2, 1), 2), @@ -26,4 +26,4 @@ def forward(self, input): if __name__ == "__main__": - run_tests() \ No newline at end of file + run_tests()