diff --git a/.circleci/config.yml b/.circleci/config.yml index 64a9c9f3b9..0223e697dc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -47,7 +47,7 @@ jobs: command: | pip3 install nvidia-pyindex pip3 install nvidia-tensorrt==8.2.4.2 - pip3 install --pre torch==1.13.0.dev20220618 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113 + pip3 install --pre torch==1.13.0.dev20220621 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113 pip3 install pytest parameterized expecttest # install torch_tensorrt mv WORKSPACE.ci WORKSPACE diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index ef757bf3d1..49765f4fd3 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -92,10 +92,12 @@ def acc_ops_conv1d( kernel=weight, bias=bias, ) - padding = kwargs["padding"] - padding = padding + (0,) - stride = extend_attr_to_tuple(kwargs["stride"], 1) - dilation = extend_attr_to_tuple(kwargs["dilation"], 1) + # expand params to 2d for computation + padding = list(kwargs["padding"]) + padding.append(0) + stride = extend_attr_to_tuple(kwargs["stride"], 2) + dilation = extend_attr_to_tuple(kwargs["dilation"], 2) + set_layer_name(layer, target, name) layer.stride_nd = stride layer.padding_nd = padding diff --git a/py/torch_tensorrt/fx/converters/convolution.py b/py/torch_tensorrt/fx/converters/convolution.py index 94c1ce24f1..5228616219 100644 --- a/py/torch_tensorrt/fx/converters/convolution.py +++ b/py/torch_tensorrt/fx/converters/convolution.py @@ -32,13 +32,13 @@ def common_conv(network, mod, dimension, input_val, layer_name, is_quantized): unsqueeze_layer.name = f"{layer_name}_unsqueeze" input_val = unsqueeze_layer.get_output(0) - padding = padding + (0,) kernel = np.expand_dims(kernel, -1) kernel_size = kernel.shape[2:] if bias is not None: bias = bias[None] - # bias = np.expand_dims(bias, -1) - + stride = (stride[0], 1) + padding = (padding[0], 0) + dilation = (dilation[0], 1) layer = network.add_convolution_nd( input=input_val, num_output_maps=mod.out_channels, diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 782d4ab5a8..79c572d9b6 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -66,7 +66,10 @@ def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec" @classmethod def from_tensors_with_dynamic_batch_size( - cls, tensors: Sequence[torch.Tensor], batch_size_range: Tuple[int, int, int] + cls, + tensors: Sequence[torch.Tensor], + batch_size_range: Tuple[int, int, int], + opt_profile_replica: int = 1, ) -> List["InputTensorSpec"]: """ Produce a list of InputTenosrSpec named tuples which would contain @@ -93,7 +96,7 @@ def from_tensors_with_dynamic_batch_size( ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." shape = list(tensor.shape) shape[0] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] # type: ignore[list-item] + shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] input_specs.append( cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) ) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 318136be56..10b56f31b4 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -100,6 +100,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: self.lower_setting.max_batch_size, self.lower_setting.max_batch_size, ), + self.lower_setting.opt_profile_replica, ) if self.lower_setting.explicit_batch_dimension else InputTensorSpec.from_tensors(input) diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index b1a32c2cff..78d4e3a2e9 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -69,6 +69,8 @@ class LowerSetting(LowerSettingBasic): how presets are applied. Refer to `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how to add a preset. + opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is + only used by explicit batch dim with dynamic shape mode. """ input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -86,3 +88,4 @@ class LowerSetting(LowerSettingBasic): save_timing_cache: bool = False cuda_graph_batch_size: int = -1 preset_lowerer: str = "" + opt_profile_replica: int = 1 diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index 4394ca97b4..6dc2e86f22 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -31,9 +31,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node): return True return False - const_split_mod = split_const_subgraphs( - traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda" - ) + const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant) const_split_mod.run_folding() return const_split_mod diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py index 8f053b9bc9..3b60c551df 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestAdaptiveAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py index 08e676dbcb..5ef417605d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestAnyConverters(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py index 50b5eaa392..dc014a7e6c 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py index 7fabb45ffd..ca69de7afa 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestAvgPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py index 65cdfb3a17..5786f2ecba 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestBatchNormConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py index b0bed3dfd7..56a37b04b0 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py @@ -5,8 +5,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec NEED_TEST_BOTH_CONSTANTS_CASE = True diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py index 485fae5589..9408c5f6bc 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestCatConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py index 39a86fc497..f1bf53dc07 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestChunkConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py index ad3c70156b..0309cf3e3f 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestClampConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py index 5706b1c2c5..d75afeef3c 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestConvolutionConverter(AccTestCase): @@ -44,111 +44,111 @@ def forward(self, x): test_explicit_precision=True, ) - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1), (1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv2d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) - - def test_conv2d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv2d} - ) - - @parameterized.expand( - [ - ("default", 1), - param("no_bias", 1, bias=False), - ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), - param("non_zero_padding", 1, padding=1), - param("dilation", 1, dilation=2), - param("groups", 1, groups=3), - ] - ) - def test_conv3d( - self, - _, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d( - 3, 6, kernel_size, stride, padding, dilation, groups, bias - ) - - def forward(self, x): - return self.conv(x) - - inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) - - def test_conv3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv3d(3, 6, 1) - - def forward(self, x): - return self.conv(x) - - input_specs = [ - InputTensorSpec( - shape=(-1, 3, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.conv3d} - ) + # @parameterized.expand( + # [ + # ("default", 1), + # param("no_bias", 1, bias=False), + # ("tuple_parameters", 1, (1, 1), (1, 1)), + # param("non_zero_padding", 1, padding=1), + # param("dilation", 1, dilation=2), + # param("groups", 1, groups=3), + # ] + # ) + # def test_conv2d( + # self, + # _, + # kernel_size, + # stride=1, + # padding=0, + # dilation=1, + # groups=1, + # bias=True, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv2d( + # 3, 6, kernel_size, stride, padding, dilation, groups, bias + # ) + + # def forward(self, x): + # return self.conv(x) + + # inputs = [torch.randn(1, 3, 32, 32)] + # self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) + + # def test_conv2d_with_dynamic_shape(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv2d(3, 6, 1) + + # def forward(self, x): + # return self.conv(x) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, 3, -1, -1), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(), input_specs, expected_ops={acc_ops.conv2d} + # ) + + # @parameterized.expand( + # [ + # ("default", 1), + # param("no_bias", 1, bias=False), + # ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + # param("non_zero_padding", 1, padding=1), + # param("dilation", 1, dilation=2), + # param("groups", 1, groups=3), + # ] + # ) + # def test_conv3d( + # self, + # _, + # kernel_size, + # stride=1, + # padding=0, + # dilation=1, + # groups=1, + # bias=True, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv3d( + # 3, 6, kernel_size, stride, padding, dilation, groups, bias + # ) + + # def forward(self, x): + # return self.conv(x) + + # inputs = [torch.randn(1, 3, 32, 32, 32)] + # self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) + + # def test_conv3d_with_dynamic_shape(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv3d(3, 6, 1) + + # def forward(self, x): + # return self.conv(x) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, 3, -1, -1, -1), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(), input_specs, expected_ops={acc_ops.conv3d} + # ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py index 9d4a98e28c..9d9a8e4c66 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py @@ -5,8 +5,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec @unittest.skip( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py index 84923e32fa..cd28becdca 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py index de3c1d3400..b42df203c1 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestELUConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py index e0e6843bb2..f75620801c 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py @@ -4,8 +4,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase @unittest.skip( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py index c459a382b0..f74a70e614 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestEqConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py index 53143102aa..49a16d9e1d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_expand.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestExpandConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_flatten.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_flatten.py index f9ce31b3cc..346669d695 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_flatten.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_flatten.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestFlattenConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py index fbe415d1e3..0f0e069841 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestGELU(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py index 547caa4866..484d8d5622 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestGetitemConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_gt.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_gt.py index c1e01295e4..0e8be9c311 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_gt.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_gt.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestGtConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_hard_sigmoid.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_hard_sigmoid.py index 6e41caf279..b5d27db5cd 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_hard_sigmoid.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_hard_sigmoid.py @@ -1,8 +1,8 @@ import torch from parameterized import parameterized from torch import nn -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec from torch_tensorrt.fx.tracer.acc_tracer import acc_ops diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_hardtanh.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_hardtanh.py index 6c1c56796a..f4d6ec01d4 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_hardtanh.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_hardtanh.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestHardtanhConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py index 4046e82bae..b97cacf7d2 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_interpolate.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestInterpolateConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_isinf.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_isinf.py index 21ab076c9b..e75ec16c35 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_isinf.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_isinf.py @@ -3,8 +3,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase @unittest.skip("Implementation is commented out due to accuracy issue T113156424") diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_leaky_relu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_leaky_relu.py index 7347796e23..5cc1ad4294 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_leaky_relu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_leaky_relu.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestLeakyReLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_linear.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_linear.py index 9f369de543..4841e77bf4 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_linear.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_linear.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestLinearConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_and.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_and.py index 51915127b5..dac1a5da1a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_and.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_and.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestAndMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_or.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_or.py index 49599fb181..aaf5879fa8 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_or.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_or.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestLogicalOrMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_xor.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_xor.py index 44ab745dd5..d2c459cf84 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_xor.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_logical_xor.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestLogicalXorMethodSimpleConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_lt.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_lt.py index 4e1490ce52..c2edffc3ec 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_lt.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_lt.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestLtConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py index aa105f512c..337938fc5a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_masked_fill.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMaskedFill(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py index f6ed85c156..9da0161f3f 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_matmul.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMatMulConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py index 45c9224d50..711939a6c1 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_max.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMaxConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_maximum.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_maximum.py index 1fb18319c9..ad9a355063 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_maximum.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_maximum.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMaximumConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_maxpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_maxpool.py index 274fe27d12..8d54b43184 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_maxpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_maxpool.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestMaxPoolConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py index 317f71d3c0..cac8d5778c 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_min.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMinConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_minimum.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_minimum.py index 0e9af3f302..1737fd766b 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_minimum.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_minimum.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMinimumConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py index 8422a88ad7..6c212e4911 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_narrow.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestNarrowConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_ne.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_ne.py index 37b2432e72..a65ef8f724 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_ne.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_ne.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestNeFunctionConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py index 2ff64f55c2..f39357998b 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestNewOnesConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_numel.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_numel.py index 8b30dc9eb9..d074852448 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_numel.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_numel.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestNumelConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py index edfd849790..7a9e9544c3 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_pad.py @@ -6,8 +6,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestPadConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_permute.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_permute.py index 2d8d511749..9916da6953 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_permute.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_permute.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestPermuteConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py index 9f112a8a91..b3ad1bcd12 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_quantize_per_tensor.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_quantize_per_tensor.py index c7084e17fd..2f4758837d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_quantize_per_tensor.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_quantize_per_tensor.py @@ -5,8 +5,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec @unittest.skip( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py index 1d24d332d4..ae9932fd61 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_reduce_ops.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase reduce_ops = [(torch.sum, acc_ops.sum), (torch.mean, acc_ops.mean)] diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_relu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_relu.py index 5757e1a138..2d89d5026b 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_relu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_relu.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestReLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_repeat_interleave.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_repeat_interleave.py index 6dc6d0d108..e933146441 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_repeat_interleave.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_repeat_interleave.py @@ -2,8 +2,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch import nn -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestRepeatInterLeave(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py index fd62c4cebd..86e12c18c3 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestReshapeConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_selu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_selu.py index ddb953abb5..b4c7e3868f 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_selu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_selu.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSeLUConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py index 32ba1009ca..61aa581e1f 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestSigmoid(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_silu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_silu.py index cc9f32e4d5..dd5d3b5b0d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_silu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_silu.py @@ -1,7 +1,7 @@ import torch from torch import nn -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec from torch_tensorrt.fx.tracer.acc_tracer import acc_ops diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_size.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_size.py index dc128f5b68..3c3881d7ce 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_size.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_size.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSizeConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_softmax.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_softmax.py index c6ec1cc578..0aea850546 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_softmax.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_softmax.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSoftmaxConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_softsign.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_softsign.py index d64c8838c4..8ad72002ab 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_softsign.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_softsign.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSoftsignConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py index c71b1b8d65..811be2c05a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSplitConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py index 9460440ee1..3af02f73ab 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSqueeze(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py index 96cf6167ed..900a875e33 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_std.py @@ -1,8 +1,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestMinConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_tanh.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_tanh.py index 7dd611e9fd..9715e1c210 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_tanh.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_tanh.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestTanh(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py index d303f50f2c..d83bef5a67 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_tile.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestTile(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py index 09c321c136..63eb3345d9 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase from torch_tensorrt.fx.utils import LowerPrecision diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py index d8ea1ef396..53dfe63190 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_topk.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestTopKConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_transpose_convolution.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_transpose_convolution.py index e9d408714e..04376a306b 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_transpose_convolution.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_transpose_convolution.py @@ -3,8 +3,8 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestTransposeConvolutionConverter(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py index c065550f51..cf3eb972d5 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py @@ -1,7 +1,7 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase from torch_tensorrt.fx.utils import LowerPrecision diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py index d11fb269bf..c1299d809d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py @@ -5,8 +5,8 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase unary_ops = [ (torch.sin, acc_ops.sin), diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py index e31b4f2481..097cf5435d 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_unsqueeze.py @@ -3,8 +3,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestUnsqueeze(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py index c90bb16041..72fea70265 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py @@ -2,8 +2,8 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestWhere(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/vanilla/test_add.py b/py/torch_tensorrt/fx/test/converters/vanilla/test_add_vanilla.py similarity index 90% rename from py/torch_tensorrt/fx/test/converters/vanilla/test_add.py rename to py/torch_tensorrt/fx/test/converters/vanilla/test_add_vanilla.py index 4e93b79bf0..863753a144 100644 --- a/py/torch_tensorrt/fx/test/converters/vanilla/test_add.py +++ b/py/torch_tensorrt/fx/test/converters/vanilla/test_add_vanilla.py @@ -4,8 +4,8 @@ import torch import torch.fx -from torch.testing._internal.common_fx2trt import VanillaTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import VanillaTestCase class TestAddConverter(VanillaTestCase): diff --git a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution.py b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py similarity index 98% rename from py/torch_tensorrt/fx/test/converters/vanilla/test_convolution.py rename to py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py index 9940cd43da..384d55d44e 100644 --- a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution.py +++ b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py @@ -3,8 +3,8 @@ import torch import torch.fx from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import VanillaTestCase from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import VanillaTestCase class TestConvolutionConverter(VanillaTestCase): diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py index 6eca458ff5..fc4da18db0 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py @@ -2,12 +2,12 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_linear, trt_transposed_linear, ) +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase class TestFusePermuteLinear(AccTestCase): diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py index dab797e02b..11f2cd3ce2 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py @@ -3,12 +3,12 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_matmul, trt_transposed_matmul, ) +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase def tranpose_last_two_dims(x): diff --git a/py/torch_tensorrt/fx/test/passes/test_multi_fuse_trt.py b/py/torch_tensorrt/fx/test/passes/test_multi_fuse_trt.py index 37e7b17b01..fb827d1e50 100644 --- a/py/torch_tensorrt/fx/test/passes/test_multi_fuse_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_multi_fuse_trt.py @@ -3,7 +3,6 @@ import torch import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import ( fuse_permute_linear, @@ -11,6 +10,7 @@ trt_transposed_linear, trt_transposed_matmul, ) +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase def permute021(x): diff --git a/py/torch_tensorrt/fx/test/passes/test_setitem.py b/py/torch_tensorrt/fx/test/passes/test_setitem.py index ddb6d5255b..357d15be30 100644 --- a/py/torch_tensorrt/fx/test/passes/test_setitem.py +++ b/py/torch_tensorrt/fx/test/passes/test_setitem.py @@ -1,9 +1,9 @@ import torch import torchdynamo from parameterized import parameterized -from torch.testing._internal.common_fx2trt import AccTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase from torchdynamo.optimizations import backends diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index 1f93a9f397..4bdc1124f9 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -92,7 +92,7 @@ def forward(self, x): m, qconfig_dict, example_inputs, - prepare_custom_config_dict=prepare_custom_config_dict, + prepare_custom_config=prepare_custom_config_dict, ) self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) mp(torch.randn(1, 1, 4, 4)) @@ -249,7 +249,7 @@ def forward(self, x): original_m_copy, qconfig_dict, example_inputs, - prepare_custom_config_dict=prepare_config, + prepare_custom_config=prepare_config, backend_config_dict=backend_config_dict, ) # calibration @@ -827,7 +827,7 @@ def forward(self, x): m, {"": qconfig}, example_inputs, - prepare_custom_config_dict=prepare_custom_config_dict, + prepare_custom_config=prepare_custom_config_dict, backend_config_dict=backend_config_dict, ) node_occurrence = { diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py new file mode 100644 index 0000000000..d6c635b402 --- /dev/null +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -0,0 +1,299 @@ +import time +import unittest +from typing import Callable, List, Tuple + +import torch +import torch.fx + +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch.fx.experimental.normalize import NormalizeArgs +from torch.fx.passes import shape_prop +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule +from torch_tensorrt.fx.passes.pass_utils import chain_passes +from torch_tensorrt.fx.utils import LowerPrecision + + +def fetch_attr(mod, target): + """ + Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. + + Args: + target (str): The fully-qualfiied name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") +class TRTTestCase(TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(3) + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops, + interpreter, + rtol, + atol, + precision=LowerPrecision.FP32, + ): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(mod, unexpected_ops) + start = time.perf_counter() + interpreter_result = interpreter.run(lower_precision=precision) + sec = time.perf_counter() - start + print("Interpreter run time(s):", sec) + trt_mod = TRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + + ref_outputs = mod(*inputs) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + outputs = trt_mod(*cuda_inputs) + end_event.record() + torch.cuda.synchronize() + print("TRT run time(s)=", (start_event.elapsed_time(end_event) * 1.0e-3)) + + if isinstance(outputs, torch.Tensor): + ref_outputs = [ref_outputs] + outputs = [outputs] + for out, ref in zip(outputs, ref_outputs): + if not isinstance(ref, torch.Tensor): + ref = torch.tensor([ref]) + ref = ref.cpu() # to_dtype test has cases with gpu output + torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + """ + Runs the test and compares the result using the provided comparators. + The size of comparators must be equal to the number of outputs from 'mod'. + + mod - a model to run. + inputs - a list of the model inputs. + expected ops - a list of ops that should be verified. + interpreter - used for converting the model to TRT. + comparators - a list of (func, args) pairs corresponding to each of + the module outputs. usage: func(x, y, *args) + + """ + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + + interpreter_result = interpreter.run( + lower_precision=LowerPrecision.FP16 + if fp16_mode + else LowerPrecision.FP32 + ) + trt_mod = TRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + res_trt = trt_mod(*cuda_inputs).cpu() + res_cpu = mod(*inputs) + assert len(res_trt) == len(res_cpu) + assert len(res_cpu) == len(comparators) + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) + + def run_test_with_error(self, mod, inputs, interpreter, expect_error): + with self.assertRaises(expect_error): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + interpreter.run(lower_precision=LowerPrecision.FP32) + + def assert_has_op(self, mod, ops): + ops_in_mod = set() + + for node in mod.graph.nodes: + if node.op == "call_module": + ops_in_mod.add(type(fetch_attr(mod, node.target))) + elif node.op in {"call_function", "call_method"}: + ops_in_mod.add(node.target) + + self.assertTrue( + ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" + ) + + def assert_unexpected_op(self, mod, ops): + for node in mod.graph.nodes: + if node.op == "call_module": + if type(fetch_attr(mod, node.target)) in ops: + return False + elif node.op in {"call_function", "call_method"}: + if node.target in ops: + return False + return True + + +class VanillaTestCase(TRTTestCase): + def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03): + mod = torch.fx.symbolic_trace(mod) + shape_prop.ShapeProp(mod).propagate(*inputs) + mod = NormalizeArgs(mod).transform() + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + # interpreter is ignored, we do not need this for Vanilla tests + # Note this is different from internal version, we need to fix the test case + # after we refactor the internal callsites to use this file + mod = torch.fx.symbolic_trace(mod) + shape_prop.ShapeProp(mod).propagate(*inputs) + mod = NormalizeArgs(mod).transform() + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test_custom_compare_results( + mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode + ) + + +class AccTestCase(TRTTestCase): + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + test_explicit_batch_dim=True, + test_implicit_batch_dim=True, + test_explicit_precision=False, + rtol=1e-03, + atol=1e-03, + precision=LowerPrecision.FP32, + ): + mod.eval() + mod = acc_tracer.trace(mod, inputs) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_precision: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol + ) + + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + def run_test_with_assert_error( + self, + mod, + inputs, + expect_error, + test_explicit_batch_dim=True, + test_implicit_batch_dim=True, + ): + mod.eval() + mod = acc_tracer.trace(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) + super().run_test_with_error(mod, inputs, interp, expect_error) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True + ) + super().run_test_with_error(mod, inputs, interp, expect_error) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = InputTensorSpec.create_inputs_from_specs(input_specs) + mod = acc_tracer.trace(mod, inputs) + interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) + super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) diff --git a/py/torch_tensorrt/fx/trt_module.py b/py/torch_tensorrt/fx/trt_module.py index 574333df7b..099bbfcdc9 100644 --- a/py/torch_tensorrt/fx/trt_module.py +++ b/py/torch_tensorrt/fx/trt_module.py @@ -39,12 +39,14 @@ def _initialize(self): primary_input_outputs.update(self.output_binding_indices_in_order) self.hidden_output_binding_indices_in_order: Sequence[int] = [] self.hidden_output_names: Sequence[str] = [] - for i in range(self.engine.num_bindings): + for i in range( + self.engine.num_bindings // self.engine.num_optimization_profiles + ): if i not in primary_input_outputs: self.hidden_output_binding_indices_in_order.append(i) self.hidden_output_names.append(self.engine.get_binding_name(i)) - assert self.engine.num_bindings == ( + assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == ( len(self.input_names) + len(self.output_names) + len(self.hidden_output_names)