6
6
7
7
from libc.stdint cimport intptr_t
8
8
9
- from .utils cimport get_nvvm_dso_version_suffix
10
-
11
9
from .utils import FunctionNotFoundError, NotSupportedError
12
10
13
- import os
14
- import site
11
+ from cuda.bindings import path_finder
15
12
16
13
import win32api
17
14
@@ -40,54 +37,6 @@ cdef void* __nvvmGetProgramLogSize = NULL
40
37
cdef void * __nvvmGetProgramLog = NULL
41
38
42
39
43
- cdef inline list get_site_packages():
44
- return [site.getusersitepackages()] + site.getsitepackages() + [" conda" ]
45
-
46
-
47
- cdef load_library(const int driver_ver):
48
- handle = 0
49
-
50
- for suffix in get_nvvm_dso_version_suffix(driver_ver):
51
- if len (suffix) == 0 :
52
- continue
53
- dll_name = " nvvm64_40_0.dll"
54
-
55
- # First check if the DLL has been loaded by 3rd parties
56
- try :
57
- return win32api.GetModuleHandle(dll_name)
58
- except :
59
- pass
60
-
61
- # Next, check if DLLs are installed via pip or conda
62
- for sp in get_site_packages():
63
- if sp == " conda" :
64
- # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65
- conda_prefix = os.environ.get(" CONDA_PREFIX" )
66
- if conda_prefix is None :
67
- continue
68
- mod_path = os.path.join(conda_prefix, " Library" , " nvvm" , " bin" )
69
- else :
70
- mod_path = os.path.join(sp, " nvidia" , " cuda_nvcc" , " nvvm" , " bin" )
71
- if os.path.isdir(mod_path):
72
- os.add_dll_directory(mod_path)
73
- try :
74
- return win32api.LoadLibraryEx(
75
- # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76
- os.path.join(mod_path, dll_name),
77
- 0 , LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78
- except :
79
- pass
80
-
81
- # Finally, try default search
82
- # Only reached if DLL wasn't found in any site-package path
83
- try :
84
- return win32api.LoadLibrary(dll_name)
85
- except :
86
- pass
87
-
88
- raise RuntimeError (' Failed to load nvvm' )
89
-
90
-
91
40
cdef int _check_or_init_nvvm() except - 1 nogil:
92
41
global __py_nvvm_init
93
42
if __py_nvvm_init:
@@ -110,7 +59,7 @@ cdef int _check_or_init_nvvm() except -1 nogil:
110
59
raise RuntimeError (' something went wrong' )
111
60
112
61
# Load library
113
- handle = load_library(driver_ver)
62
+ handle = path_finder._load_nvidia_dynamic_library( " nvvm " ).handle
114
63
115
64
# Load function
116
65
global __nvvmVersion
0 commit comments