From 3c9b77fe9d53c4cb8250ef90a7cfa0c7e4274e12 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 24 Feb 2025 09:26:56 -0800
Subject: [PATCH 1/2] removing the fuse distributed ops lowering pass for
 tegra platforms

---
 .../lowering/passes/_aten_lowering_pass.py | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 676e6e1175..eb7cc94232 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -15,18 +15,22 @@
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 
-ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
-    [
-        remove_input_alias_fixing_clones,
-        constant_fold,
-        repair_input_as_output,
-        fuse_prims_broadcast,
-        fuse_distributed_ops,
-        replace_max_pool_with_indices,
-        remove_assert_nodes,
-        accumulate_fp32_matmul,
-    ]
-)
+pass_list = [
+    remove_input_alias_fixing_clones,
+    constant_fold,
+    repair_input_as_output,
+    fuse_prims_broadcast,
+    replace_max_pool_with_indices,
+    lower_scaled_dot_product_attention,
+    view_to_reshape,
+    remove_assert_nodes,
+    accumulate_fp32_matmul,
+]
+
+if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
+    pass_list.append(fuse_distributed_ops)
+
+ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
 
 ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [

From 55535caae0af9a70ae7bc67419703872934981d7 Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 25 Feb 2025 06:09:33 -0800
Subject: [PATCH 2/2] utility function to detect tegra platform

---
 .../dynamo/lowering/passes/_aten_lowering_pass.py | 5 ++---
 py/torch_tensorrt/dynamo/utils.py                 | 6 ++++++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index eb7cc94232..b66f36c11e 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -3,6 +3,7 @@
 import torch
 
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.utils import is_tegra_platform
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
@@ -21,13 +22,11 @@
     repair_input_as_output,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
-    lower_scaled_dot_product_attention,
-    view_to_reshape,
     remove_assert_nodes,
     accumulate_fp32_matmul,
 ]
 
-if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
+if not is_tegra_platform():
     pass_list.append(fuse_distributed_ops)
 
 ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 557c01667f..e4018ae95c 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -806,3 +806,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
             f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
         )
     return output_dtypes
+
+
+def is_tegra_platform() -> bool:
+    if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
+        return True
+    return False
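
Note: for anyone trying the gating logic outside the library, below is a minimal,
self-contained sketch of what the two patches add. It is illustrative only: the
TEGRA_CAPABILITIES constant and the torch.cuda.is_available() guard are additions
for the example (the patched utility calls get_device_capability() unconditionally),
and the string pass names stand in for the actual pass callables.

import torch

# Compute capabilities reported by NVIDIA Tegra SoCs:
# (8, 7) = Jetson Orin, (7, 2) = Jetson Xavier.
TEGRA_CAPABILITIES = [(8, 7), (7, 2)]  # illustrative constant, not in the patch


def is_tegra_platform() -> bool:
    # Same membership test as the utility added in PATCH 2/2; the
    # is_available() guard is extra so the sketch also runs on CPU-only hosts.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in TEGRA_CAPABILITIES


# Stand-ins for the lowering-pass callables in _aten_lowering_pass.py.
pass_list = ["remove_input_alias_fixing_clones", "constant_fold"]

if not is_tegra_platform():
    # Mirrors the conditional in PATCH 1/2: fuse_distributed_ops is
    # only registered on non-Tegra platforms.
    pass_list.append("fuse_distributed_ops")

print(pass_list)

One caveat worth flagging: in the patched module the is_tegra_platform() check runs
at import time of _aten_lowering_pass.py, so the registered pass list reflects
whichever CUDA device is current when the lowering module is first imported.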