Skip to content

Commit c97e365

Browse files
committed
[SYCL] Allow overriding plugin libraries
1 parent 7fc8aa0 commit c97e365

File tree

3 files changed

+47
-14
lines changed

3 files changed

+47
-14
lines changed

sycl/source/detail/config.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,7 @@ CONFIG(SYCL_CACHE_THRESHOLD, 16, __SYCL_CACHE_THRESHOLD)
3131
CONFIG(SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE)
3232
CONFIG(SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE)
3333
CONFIG(INTEL_ENABLE_OFFLOAD_ANNOTATIONS, 1, __SYCL_INTEL_ENABLE_OFFLOAD_ANNOTATIONS)
34+
CONFIG(SYCL_OVERRIDE_PI_OPENCL, 1024, __SYCL_OVERRIDE_PI_OPENCL)
35+
CONFIG(SYCL_OVERRIDE_PI_LEVEL_ZERO, 1024, __SYCL_OVERRIDE_PI_LEVEL_ZERO)
36+
CONFIG(SYCL_OVERRIDE_PI_CUDA, 1024, __SYCL_OVERRIDE_PI_CUDA)
37+
CONFIG(SYCL_OVERRIDE_PI_ROCM, 1024, __SYCL_OVERRIDE_PI_ROCM)

sycl/source/detail/pi.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,19 +225,34 @@ std::string memFlagsToString(pi_mem_flags Flags) {
225225
std::shared_ptr<plugin> GlobalPlugin;
226226

227227
// Find the plugin at the appropriate location and return the location.
228-
bool findPlugins(std::vector<std::pair<std::string, backend>> &PluginNames) {
228+
std::vector<std::pair<std::string, backend>> findPlugins() {
229+
std::vector<std::pair<std::string, backend>> PluginNames;
230+
229231
// TODO: Based on final design discussions, change the location where the
230232
// plugin must be searched; how to identify the plugins etc. Currently the
231233
// search is done for libpi_opencl.so/pi_opencl.dll file in LD_LIBRARY_PATH
232234
// env only.
233235
//
236+
const char *OpenCLPluginName =
237+
SYCLConfig<SYCL_OVERRIDE_PI_OPENCL>::get()
238+
? SYCLConfig<SYCL_OVERRIDE_PI_OPENCL>::get()
239+
: __SYCL_OPENCL_PLUGIN_NAME;
240+
const char *L0PluginName =
241+
SYCLConfig<SYCL_OVERRIDE_PI_LEVEL_ZERO>::get()
242+
? SYCLConfig<SYCL_OVERRIDE_PI_LEVEL_ZERO>::get()
243+
: __SYCL_LEVEL_ZERO_PLUGIN_NAME;
244+
const char *CUDAPluginName = SYCLConfig<SYCL_OVERRIDE_PI_CUDA>::get()
245+
? SYCLConfig<SYCL_OVERRIDE_PI_CUDA>::get()
246+
: __SYCL_CUDA_PLUGIN_NAME;
247+
const char *ROCMPluginName = SYCLConfig<SYCL_OVERRIDE_PI_ROCM>::get()
248+
? SYCLConfig<SYCL_OVERRIDE_PI_ROCM>::get()
249+
: __SYCL_ROCM_PLUGIN_NAME;
234250
device_filter_list *FilterList = SYCLConfig<SYCL_DEVICE_FILTER>::get();
235251
if (!FilterList) {
236-
PluginNames.emplace_back(__SYCL_OPENCL_PLUGIN_NAME, backend::opencl);
237-
PluginNames.emplace_back(__SYCL_LEVEL_ZERO_PLUGIN_NAME,
238-
backend::level_zero);
239-
PluginNames.emplace_back(__SYCL_CUDA_PLUGIN_NAME, backend::cuda);
240-
PluginNames.emplace_back(__SYCL_ROCM_PLUGIN_NAME, backend::rocm);
252+
PluginNames.emplace_back(OpenCLPluginName, backend::opencl);
253+
PluginNames.emplace_back(L0PluginName, backend::level_zero);
254+
PluginNames.emplace_back(CUDAPluginName, backend::cuda);
255+
PluginNames.emplace_back(ROCMPluginName, backend::rocm);
241256
} else {
242257
std::vector<device_filter> Filters = FilterList->get();
243258
bool OpenCLFound = false;
@@ -248,26 +263,25 @@ bool findPlugins(std::vector<std::pair<std::string, backend>> &PluginNames) {
248263
backend Backend = Filter.Backend;
249264
if (!OpenCLFound &&
250265
(Backend == backend::opencl || Backend == backend::all)) {
251-
PluginNames.emplace_back(__SYCL_OPENCL_PLUGIN_NAME, backend::opencl);
266+
PluginNames.emplace_back(OpenCLPluginName, backend::opencl);
252267
OpenCLFound = true;
253268
}
254269
if (!LevelZeroFound &&
255270
(Backend == backend::level_zero || Backend == backend::all)) {
256-
PluginNames.emplace_back(__SYCL_LEVEL_ZERO_PLUGIN_NAME,
257-
backend::level_zero);
271+
PluginNames.emplace_back(L0PluginName, backend::level_zero);
258272
LevelZeroFound = true;
259273
}
260274
if (!CudaFound && (Backend == backend::cuda || Backend == backend::all)) {
261-
PluginNames.emplace_back(__SYCL_CUDA_PLUGIN_NAME, backend::cuda);
275+
PluginNames.emplace_back(CUDAPluginName, backend::cuda);
262276
CudaFound = true;
263277
}
264278
if (!RocmFound && (Backend == backend::rocm || Backend == backend::all)) {
265-
PluginNames.emplace_back(__SYCL_ROCM_PLUGIN_NAME, backend::rocm);
279+
PluginNames.emplace_back(ROCMPluginName, backend::rocm);
266280
RocmFound = true;
267281
}
268282
}
269283
}
270-
return true;
284+
return PluginNames;
271285
}
272286

273287
// Load the Plugin by calling the OS dependent library loading call.
@@ -321,8 +335,7 @@ const std::vector<plugin> &initialize() {
321335
}
322336

323337
static void initializePlugins(std::vector<plugin> *Plugins) {
324-
std::vector<std::pair<std::string, backend>> PluginNames;
325-
findPlugins(PluginNames);
338+
std::vector<std::pair<std::string, backend>> PluginNames = findPlugins();
326339

327340
if (PluginNames.empty() && trace(PI_TRACE_ALL))
328341
std::cerr << "SYCL_PI_TRACE[all]: "
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: env SYCL_OVERRIDE_PI_OPENCL=opencl_test env SYCL_OVERRIDE_PI_LEVEL_ZERO=l0_test env SYCL_OVERRIDE_PI_CUDA=cuda_test env SYCL_OVERRIDE_PI_ROCM=rocm_test env SYCL_PI_TRACE=-1 %t.out > %t.log 2>&1
3+
// FileCheck %s --input-file %t.log
4+
5+
#include <sycl/sycl.hpp>
6+
7+
int main() {
8+
sycl::queue Q;
9+
10+
return 0;
11+
}
12+
13+
// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: opencl_test
14+
// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: l0_test
15+
// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: cuda_test
16+
// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: rocm_test

0 commit comments

Comments
 (0)