From 4baaea1839edd1be7658dcf884c079bb4fa58f05 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Thu, 31 Jul 2025 15:49:38 -0700
Subject: [PATCH 1/2] add strong type test cases

---
 .../dynamo/conversion/_TRTInterpreter.py     | 18 +++---
 .../dynamo/lowering/_decompositions.py       |  4 +-
 tests/py/dynamo/models/test_dtype_support.py | 22 ++++----
 tests/py/dynamo/models/test_dyn_models.py    | 19 ++++---
 tests/py/dynamo/models/test_models.py        | 52 ++++++++++++-------
 tests/py/dynamo/models/test_models_export.py | 44 ++++++++--------
 6 files changed, 93 insertions(+), 66 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 8d7a914836..b8d4994fca 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -277,18 +277,18 @@ def _populate_trt_builder_config(
                 trt.MemoryPoolType.DLA_GLOBAL_DRAM,
                 self.compilation_settings.dla_global_dram_size,
             )
+        if not self.compilation_settings.use_explicit_typing:
+            if dtype.float16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP16)
 
-        if dtype.float16 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.FP16)
+            if dtype.int8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.INT8)
 
-        if dtype.int8 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.INT8)
+            if dtype.fp8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP8)
 
-        if dtype.fp8 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.FP8)
-
-        if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.BF16)
+            if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.BF16)
 
         if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index b0cfdee4f0..feefcc4c4c 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -483,7 +483,9 @@ def scaled_dot_product_attention_decomposition(
     attn_weight = query @ key.transpose(-2, -1)
 
     if scale is None:
-        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
+        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
+            query.dtype
+        )
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale
diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py
index 37b40574a1..62bcacc94a 100644
--- a/tests/py/dynamo/models/test_dtype_support.py
+++ b/tests/py/dynamo/models/test_dtype_support.py
@@ -42,6 +42,7 @@ def forward(self, x):
             use_python_runtime=False,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(in_tensor)
@@ -82,12 +83,13 @@ def forward(self, x):
             use_python_runtime=True,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(in_tensor)
         with torch_tensorrt.logging.debug():
             optimized_model_results = trt_mod(in_tensor)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
@@ -128,11 +130,12 @@ def forward(self, x):
             use_python_runtime=False,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(in_tensor)
         optimized_model_results = trt_mod(in_tensor)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
@@ -169,11 +172,12 @@ def forward(self, x):
             use_python_runtime=True,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(in_tensor)
         optimized_model_results = trt_mod(in_tensor)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
@@ -218,16 +222,16 @@ def forward(self, x):
             exp_mod,
             inputs=[in_tensor],
             pass_through_build_failures=True,
-            enabled_precisions={torch.float, torch.bfloat16, torch.half},
             min_block_size=1,
             use_python_runtime=False,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(in_tensor)
         optimized_model_results = trt_mod(in_tensor)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
@@ -258,16 +262,16 @@ def forward(self, x):
             exp_mod,
             inputs=[in_tensor],
             pass_through_build_failures=True,
-            enabled_precisions={torch.float, torch.bfloat16, torch.half},
             min_block_size=1,
             use_python_runtime=True,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
        )
 
         torch_model_results = mod(in_tensor)
         optimized_model_results = trt_mod(in_tensor)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
@@ -296,16 +300,16 @@ def forward(self, x):
             mod,
             ir="torch_compile",
             inputs=inputs,
-            enabled_precisions={torch.bfloat16},
             min_block_size=1,
             device=device,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            use_explicit_typing=True,
         )
 
         torch_model_results = mod(*inputs)
         optimized_model_results = trt_mod(*inputs)
-
+        assert torch_model_results.dtype == optimized_model_results.dtype
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 052a268a64..fb3a3b8688 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -178,26 +178,27 @@ def forward(self, x):
     not importlib.util.find_spec("torchvision"), "torchvision not installed"
 )
 @pytest.mark.unit
-def test_resnet_dynamic(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_resnet_dynamic(ir, dtype):
     """
     Tests the Resnet18 model (which is fully convertible) with dynamic shapes
     """
     import torchvision.models as models
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)
 
     compile_spec = {
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "min_block_size": 1,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     if ir == "torch_compile":
-        input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda")
+        input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda").to(dtype)
         torch._dynamo.mark_dynamic(input_bs2, 0, min=1, max=8)
         # Compile the model
         trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
@@ -208,14 +209,18 @@ def test_resnet_dynamic(ir):
                 min_shape=(1, 3, 224, 224),
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
-                dtype=torch.float32,
+                dtype=dtype,
                 name="x",
             )
         ]
         trt_model = torchtrt.compile(model, **compile_spec)
 
-    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
-    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda").to(dtype)
+    pyt_output = model(input_bs6)
+    trt_output = trt_model(input_bs6)
+    assert pyt_output.dtype == trt_output.dtype
+    assert trt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index 90d3cc637b..2584141ef3 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -136,28 +136,31 @@ def test_resnet18_torch_exec_ops(ir):
     not importlib.util.find_spec("torchvision"),
     "torchvision is not installed",
 )
-def test_mobilenet_v2(ir):
-    model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_mobilenet_v2(ir, dtype):
+    model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -172,28 +175,36 @@ def test_mobilenet_v2(ir):
     not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
     "timm or torchvision not installed",
 )
-def test_efficientnet_b0(ir):
-    model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_efficientnet_b0(ir, dtype):
+    model = (
+        timm.create_model("efficientnet_b0", pretrained=True)
+        .eval()
+        .to("cuda")
+        .to(dtype)
+    )
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -208,10 +219,11 @@ def test_efficientnet_b0(ir):
     not importlib.util.find_spec("transformers"),
     "transformers is required to run this test",
 )
-def test_bert_base_uncased(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_bert_base_uncased_lan(ir, dtype):
     from transformers import BertModel
 
-    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
+    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
     input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
 
@@ -229,7 +241,6 @@ def test_bert_base_uncased(ir):
             ),
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "truncate_double": True,
         "ir": ir,
         "pass_through_build_failures": True,
@@ -237,6 +248,7 @@ def test_bert_base_uncased(ir):
         "min_block_size": 15,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -244,6 +256,8 @@ def test_bert_base_uncased(ir):
     trt_model_outputs = trt_mod(input, input2)
     for key in model_outputs.keys():
         out, trt_out = model_outputs[key], trt_model_outputs[key]
+        assert out.dtype == trt_out.dtype
+        assert out.dtype == dtype
         cos_sim = cosine_similarity(out, trt_out)
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index f7a851676f..53be48879a 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -386,7 +386,8 @@ def calibrate_loop(model):
     "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux",
 )
 @pytest.mark.unit
-def test_base_int8(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_base_int8(ir, dtype):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.quantization.utils import export_torch_mode
 
@@ -406,27 +407,29 @@ def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
 
-    input_tensor = torch.randn(1, 10).cuda()
-    model = SimpleNetwork().eval().cuda()
-
+    input_tensor = torch.randn(1, 10).cuda().to(dtype)
+    model = SimpleNetwork().eval().cuda().to(dtype)
     quant_cfg = mtq.INT8_DEFAULT_CFG
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has INT8 qdq nodes at this point
     output_pyt = model(input_tensor)
 
     with torch.no_grad():
         with export_torch_mode():
             exp_program = torch.export.export(model, (input_tensor,), strict=False)
-            trt_model = torchtrt.dynamo.compile(
-                exp_program,
-                inputs=[input_tensor],
-                enabled_precisions={torch.int8},
-                min_block_size=1,
-                cache_built_engines=False,
-                reuse_cached_engines=False,
-                truncate_double=True,
-            )
+            with torchtrt.logging.debug():
+                trt_model = torchtrt.dynamo.compile(
+                    exp_program,
+                    inputs=[input_tensor],
+                    min_block_size=1,
+                    cache_built_engines=False,
+                    reuse_cached_engines=False,
+                    truncate_double=True,
+                    use_explicit_typing=True,
+                )
             outputs_trt = trt_model(input_tensor)
+            assert output_pyt.dtype == outputs_trt.dtype
+            assert outputs_trt.dtype == dtype
             assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
 
 
@@ -437,17 +440,16 @@ def calibrate_loop(model):
     "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux",
 )
 @pytest.mark.unit
-def test_base_int8_dynamic_shape(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_base_int8_dynamic_shape(ir, dtype):
     import modelopt.torch.quantization as mtq
     from modelopt.torch.quantization.utils import export_torch_mode
 
-    dtype = torch.bfloat16
-
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
-            self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
-            self.linear = torch.nn.Linear(222, 222, dtype=dtype)
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(222, 222)
 
         def forward(self, x):
             return self.linear(self.conv(x))
@@ -459,7 +461,7 @@ def calibrate_loop(model):
     BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=2, max=16)
     batch_size = 8
     input_tensor = torch.randn(batch_size, 3, 224, 224, dtype=dtype).cuda()
-    model = SimpleNetwork().eval().cuda()
+    model = SimpleNetwork().eval().cuda().to(dtype)
 
     quant_cfg = mtq.INT8_DEFAULT_CFG
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
@@ -475,11 +477,11 @@ def calibrate_loop(model):
             trt_model = torchtrt.dynamo.compile(
                 exp_program,
                 inputs=[input_tensor],
-                enabled_precisions={torch.int8, dtype},
                 min_block_size=1,
                 cache_built_engines=False,
                 reuse_cached_engines=False,
                 truncate_double=True,
+                use_explicit_typing=True,
             )
             outputs_trt = trt_model(input_tensor)
             assert torch.allclose(output_pyt, outputs_trt, rtol=5e-2, atol=5e-2)

From 19d516c97e9d2c2ab26fae0ddf0ff9a8f17efb13 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Fri, 1 Aug 2025 11:19:15 -0700
Subject: [PATCH 2/2] fix batch_norm constant dtypes, restore bert test name

---
 .../dynamo/conversion/impl/normalization/ops.py | 10 +++++-----
 tests/py/dynamo/models/test_models.py           |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 7cf9422330..f9b47542a8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -60,28 +60,28 @@ def batch_norm(
 ):
     # We name the weight here according to the state_dict name
     weight = (
-        get_trt_tensor(ctx, 1.0, f"{name}_weight")
+        get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
         if weight is None
         else get_trt_tensor(ctx, weight, f"{name}_weight")
     )
     bias = (
-        get_trt_tensor(ctx, 0.0, f"{name}_bias")
+        get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
         if bias is None
         else get_trt_tensor(ctx, bias, f"{name}_bias")
     )
     running_mean = (
-        get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+        get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
         if running_mean is None
         else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
     )
     running_var = (
-        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+        get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
         if running_var is None
         else get_trt_tensor(ctx, running_var, f"{name}_running_var")
     )
 
     # eps_tensor for numerical stability
-    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
 
     # adjusted_var = running_var + eps
     adjusted_var = impl.elementwise.add(
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index 2584141ef3..84f36e48ce 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -220,7 +220,7 @@ def test_efficientnet_b0(ir, dtype):
     "transformers is required to run this test",
 )
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
-def test_bert_base_uncased_lan(ir, dtype):
+def test_bert_base_uncased(ir, dtype):
     from transformers import BertModel
 
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)