Skip to content

Commit cc6fbdb

Browse files
authored
Add sm_89 and point to nvcuda.dll (huggingface#731)
1 parent ecfdec1 commit cc6fbdb

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

shark/iree_utils/gpu_utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ def get_iree_gpu_args():
2525
# TODO: Give the user_interface to pass the sm_arch.
2626
sm_arch = get_cuda_sm_cc()
2727
if (
28-
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
28+
sm_arch
29+
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
2930
) and (shark_args.enable_tf32 == True):
3031
return [
3132
"--iree-hal-cuda-disable-loop-nounroll-wa",
@@ -56,7 +57,7 @@ def get_iree_rocm_args():
5657

5758

5859
def get_cuda_sm_cc():
59-
libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
60+
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
6061
for libname in libnames:
6162
try:
6263
cuda = ctypes.CDLL(libname)

0 commit comments

Comments
 (0)