Skip to content

Commit 7680456

Browse files
committed
feat(//py): Use TensorRT to fill in .so libraries automatically if
possible Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 1625cd3 commit 7680456

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

py/setup.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222

2323
JETPACK_VERSION = None
2424

25-
__version__ = '1.2.0a0'
2625
FX_ONLY = False
2726

27+
__version__ = '1.2.0a0'
28+
__cuda_version__ = '11.3'
29+
__cudnn_version__ = '8.2'
30+
__tensorrt_version__ = '8.2'
31+
2832
def get_git_revision_short_hash() -> str:
2933
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
3034

@@ -51,8 +55,10 @@ def get_git_revision_short_hash() -> str:
5155
JETPACK_VERSION = "4.5"
5256
elif version == "4.6":
5357
JETPACK_VERSION = "4.6"
58+
elif version == "5.0":
59+
JETPACK_VERSION = "4.6"
5460
if not JETPACK_VERSION:
55-
warnings.warn("Assuming jetpack version to be 4.6, if not use the --jetpack-version option")
61+
warnings.warn("Assuming jetpack version to be 4.6 or greater, if not use the --jetpack-version option")
5662
JETPACK_VERSION = "4.6"
5763

5864

@@ -103,7 +109,7 @@ def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=F
103109
print("Jetpack version: 4.5")
104110
elif JETPACK_VERSION == "4.6":
105111
cmd.append("--platforms=//toolchains:jetpack_4.6")
106-
print("Jetpack version: 4.6")
112+
print("Jetpack version: >=4.6")
107113

108114
print("building libtorchtrt")
109115
status_code = subprocess.run(cmd).returncode
@@ -118,7 +124,10 @@ def gen_version_file():
118124

119125
with open(dir_path + '/torch_tensorrt/_version.py', 'w') as f:
120126
print("creating version file")
121-
f.write("__version__ = \"" + __version__ + '\"')
127+
f.write("__version__ = \"" + __version__ + '\"\n')
128+
f.write("__cuda_version__ = \"" + __cuda_version__ + '\"\n')
129+
f.write("__cudnn_version__ = \"" + __cudnn_version__ + '\"\n')
130+
f.write("__tensorrt_version__ = \"" + __tensorrt_version__ + '\"\n')
122131

123132

124133
def copy_libtorchtrt(multilinux=False):

py/torch_tensorrt/__init__.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,88 @@
1+
import ctypes
2+
import glob
13
import os
24
import sys
5+
import platform
6+
import warnings
7+
from torch_tensorrt._version import __version__, __cuda_version__, __cudnn_version__, __tensorrt_version__
8+
39

410
if sys.version_info < (3,):
511
raise Exception("Python 2 has reached end-of-life and is not supported by Torch-TensorRT")
612

7-
import ctypes
13+
def _parse_semver(version):
14+
split = version.split(".")
15+
if len(split) < 3:
16+
split.append("")
17+
18+
return {
19+
"major": split[0],
20+
"minor": split[1],
21+
"patch": split[2]
22+
}
23+
24+
def _find_lib(name, paths):
25+
for path in paths:
26+
libpath = os.path.join(path, name)
27+
if os.path.isfile(libpath):
28+
return libpath
29+
30+
raise FileNotFoundError(
31+
f"Could not find {name}\n Search paths: {paths}"
32+
)
33+
34+
try:
35+
import tensorrt
36+
except:
37+
cuda_version = _parse_semver(__cuda_version__)
38+
cudnn_version = _parse_semver(__cudnn_version__)
39+
tensorrt_version = _parse_semver(__tensorrt_version__)
40+
41+
CUDA_MAJOR = cuda_version["major"]
42+
CUDNN_MAJOR = cudnn_version["major"]
43+
TENSORRT_MAJOR = tensorrt_version["major"]
44+
45+
if sys.platform.startswith("win"):
46+
WIN_LIBS = [
47+
f"cublas64_{CUDA_MAJOR}.dll",
48+
f"cublasLt64_{CUDA_MAJOR}.dll",
49+
f"cudnn64_{CUDNN_MAJOR}.dll",
50+
"nvinfer.dll",
51+
"nvinfer_plugin.dll",
52+
]
53+
54+
WIN_PATHS = os.environ["PATH"].split(os.path.pathsep)
55+
56+
57+
for lib in WIN_LIBS:
58+
ctypes.CDLL(_find_lib(lib, WIN_PATHS))
59+
60+
elif sys.platform.startswith("linux"):
61+
LINUX_PATHS = [
62+
"/usr/local/cuda/lib64",
63+
] + os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep)
64+
65+
if platform.uname().processor == "x86_64":
66+
LINUX_PATHS += [
67+
"/usr/lib/x86_64-linux-gnu",
68+
]
69+
70+
elif platform.uname().processor == "aarch64":
71+
LINUX_PATHS += [
72+
"/usr/lib/aarch64-linux-gnu"
73+
]
74+
75+
LINUX_LIBS = [
76+
f"libcudnn.so.{CUDNN_MAJOR}",
77+
f"libnvinfer.so.{TENSORRT_MAJOR}",
78+
f"libnvinfer_plugin.so.{TENSORRT_MAJOR}",
79+
]
80+
81+
for lib in LINUX_LIBS:
82+
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
83+
884
import torch
985

10-
from torch_tensorrt._version import __version__
1186
from torch_tensorrt._compile import *
1287
from torch_tensorrt._util import *
1388
from torch_tensorrt import ts

0 commit comments

Comments
 (0)