Skip to content

Commit 57943f3

Browse files
committed
utility function to detect tegra platform
1 parent d8063cd commit 57943f3

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.utils import is_tegra_platform
67

78
from .accumulate_fp32_matmul import accumulate_fp32_matmul
89
from .constant_folding import constant_fold
@@ -29,7 +30,7 @@
2930
accumulate_fp32_matmul,
3031
]
3132

32-
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
33+
if not is_tegra_platform():
3334
pass_list.append(fuse_distributed_ops)
3435

3536
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,3 +799,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
799799
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
800800
)
801801
return output_dtypes
802+
803+
804+
def is_tegra_platform() -> bool:
805+
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
806+
return True
807+
return False

0 commit comments

Comments
 (0)