Skip to content

Commit 0717c40

Browse files
authored
[SYCL][UR] Bump UR and implement adapter handles (#10349)
Bump the Unified Runtime commit, and make adapter changes needed for the newly added adapter handles (see oneapi-src/unified-runtime#715 for details) This fixes #10066 by providing an implementation of `piPluginGetLastError` in pi2ur.
1 parent 7bf1f57 commit 0717c40

27 files changed

+618
-197
lines changed

sycl/plugins/cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ add_sycl_plugin(cuda
5555
"../unified_runtime/ur/ur.cpp"
5656
"../unified_runtime/ur/usm_allocator.cpp"
5757
"../unified_runtime/ur/usm_allocator.hpp"
58+
"../unified_runtime/ur/adapters/cuda/adapter.cpp"
59+
"../unified_runtime/ur/adapters/cuda/adapter.hpp"
5860
"../unified_runtime/ur/adapters/cuda/command_buffer.cpp"
5961
"../unified_runtime/ur/adapters/cuda/command_buffer.hpp"
6062
"../unified_runtime/ur/adapters/cuda/common.cpp"

sycl/plugins/hip/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ add_sycl_plugin(hip
9292
"../unified_runtime/pi2ur.cpp"
9393
"../unified_runtime/ur/ur.hpp"
9494
"../unified_runtime/ur/ur.cpp"
95+
"../unified_runtime/ur/adapters/hip/adapter.cpp"
96+
"../unified_runtime/ur/adapters/hip/adapter.hpp"
9597
"../unified_runtime/ur/adapters/hip/command_buffer.cpp"
9698
"../unified_runtime/ur/adapters/hip/command_buffer.hpp"
9799
"../unified_runtime/ur/adapters/hip/common.cpp"

sycl/plugins/level_zero/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ add_sycl_plugin(level_zero
103103
"../unified_runtime/ur/usm_allocator_config.hpp"
104104
"../unified_runtime/ur/adapters/level_zero/ur_level_zero.hpp"
105105
"../unified_runtime/ur/adapters/level_zero/command_buffer.hpp"
106+
"../unified_runtime/ur/adapters/level_zero/adapter.hpp"
106107
"../unified_runtime/ur/adapters/level_zero/common.hpp"
107108
"../unified_runtime/ur/adapters/level_zero/context.hpp"
108109
"../unified_runtime/ur/adapters/level_zero/device.hpp"
@@ -116,6 +117,7 @@ add_sycl_plugin(level_zero
116117
"../unified_runtime/ur/adapters/level_zero/sampler.hpp"
117118
"../unified_runtime/ur/adapters/level_zero/usm.hpp"
118119
"../unified_runtime/ur/adapters/level_zero/ur_level_zero.cpp"
120+
"../unified_runtime/ur/adapters/level_zero/adapter.cpp"
119121
"../unified_runtime/ur/adapters/level_zero/command_buffer.cpp"
120122
"../unified_runtime/ur/adapters/level_zero/common.cpp"
121123
"../unified_runtime/ur/adapters/level_zero/context.cpp"

sycl/plugins/native_cpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_sycl_plugin(native_cpu
55
"../unified_runtime/pi2ur.cpp"
66
"../unified_runtime/ur/ur.hpp"
77
"../unified_runtime/ur/ur.cpp"
8+
"../unified_runtime/ur/adapters/native_cpu/adapter.cpp"
89
"../unified_runtime/ur/adapters/native_cpu/common.cpp"
910
"../unified_runtime/ur/adapters/native_cpu/common.hpp"
1011
"../unified_runtime/ur/adapters/native_cpu/context.cpp"
@@ -24,7 +25,6 @@ add_sycl_plugin(native_cpu
2425
"../unified_runtime/ur/adapters/native_cpu/queue.cpp"
2526
"../unified_runtime/ur/adapters/native_cpu/queue.hpp"
2627
"../unified_runtime/ur/adapters/native_cpu/sampler.cpp"
27-
"../unified_runtime/ur/adapters/native_cpu/runtime.cpp"
2828
"../unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp"
2929
"../unified_runtime/ur/adapters/native_cpu/usm.cpp"
3030
"../unified_runtime/ur/adapters/native_cpu/usm_p2p.cpp"

sycl/plugins/unified_runtime/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if (NOT DEFINED UNIFIED_RUNTIME_LIBRARY OR NOT DEFINED UNIFIED_RUNTIME_INCLUDE_D
44
include(FetchContent)
55

66
set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
7-
set(UNIFIED_RUNTIME_TAG 3c6f02c7a76a0448a83932d93c2dbeff25af70aa)
7+
set(UNIFIED_RUNTIME_TAG 974a7d64dd1a26ede1ff27919b3b8713b848c376)
88

99
message(STATUS "Will fetch Unified Runtime from ${UNIFIED_RUNTIME_REPO}")
1010
FetchContent_Declare(unified-runtime
@@ -86,6 +86,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED
8686
"ur/adapters/level_zero/ur_level_zero.hpp"
8787
"ur/adapters/level_zero/ur_level_zero.cpp"
8888
"ur/adapters/level_zero/ur_interface_loader.cpp"
89+
"ur/adapters/level_zero/adapter.hpp"
8990
"ur/adapters/level_zero/command_buffer.hpp"
9091
"ur/adapters/level_zero/common.hpp"
9192
"ur/adapters/level_zero/context.hpp"
@@ -100,6 +101,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED
100101
"ur/adapters/level_zero/queue.hpp"
101102
"ur/adapters/level_zero/sampler.hpp"
102103
"ur/adapters/level_zero/usm.hpp"
104+
"ur/adapters/level_zero/adapter.cpp"
103105
"ur/adapters/level_zero/command_buffer.cpp"
104106
"ur/adapters/level_zero/common.cpp"
105107
"ur/adapters/level_zero/context.cpp"
@@ -135,6 +137,8 @@ if ("cuda" IN_LIST SYCL_ENABLE_PLUGINS)
135137
"ur/ur.cpp"
136138
"ur/usm_allocator.cpp"
137139
"ur/usm_allocator.hpp"
140+
"ur/adapters/cuda/adapter.cpp"
141+
"ur/adapters/cuda/adapter.hpp"
138142
"ur/adapters/cuda/command_buffer.cpp"
139143
"ur/adapters/cuda/command_buffer.hpp"
140144
"ur/adapters/cuda/common.cpp"
@@ -186,6 +190,8 @@ if ("hip" IN_LIST SYCL_ENABLE_PLUGINS)
186190
"ur/ur.cpp"
187191
"ur/usm_allocator.cpp"
188192
"ur/usm_allocator.hpp"
193+
"ur/adapters/hip/adapter.cpp"
194+
"ur/adapters/hip/adapter.hpp"
189195
"ur/adapters/hip/command_buffer.cpp"
190196
"ur/adapters/hip/command_buffer.hpp"
191197
"ur/adapters/hip/common.cpp"
@@ -243,6 +249,7 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS)
243249
SOURCES
244250
"ur/ur.cpp"
245251
"ur/ur.hpp"
252+
"ur/adapters/native_cpu/adapter.cpp"
246253
"ur/adapters/native_cpu/common.cpp"
247254
"ur/adapters/native_cpu/common.hpp"
248255
"ur/adapters/native_cpu/context.cpp"
@@ -262,7 +269,6 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS)
262269
"ur/adapters/native_cpu/queue.cpp"
263270
"ur/adapters/native_cpu/queue.hpp"
264271
"ur/adapters/native_cpu/sampler.cpp"
265-
"ur/adapters/native_cpu/runtime.cpp"
266272
"ur/adapters/native_cpu/ur_interface_loader.cpp"
267273
"ur/adapters/native_cpu/usm.cpp"
268274
"ur/adapters/native_cpu/usm_p2p.cpp"

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,22 @@ namespace pi2ur {
719719

720720
inline pi_result piTearDown(void *PluginParameter) {
721721
std::ignore = PluginParameter;
722+
// Fetch the single known adapter (the one which is statically linked) so we
723+
// can release it. Fetching it for a second time (after piPlatformsGet)
724+
// increases the reference count, so we need to release it twice.
725+
// pi_unified_runtime has its own implementation of piTearDown.
726+
static std::once_flag AdapterReleaseFlag;
727+
ur_adapter_handle_t Adapter;
728+
ur_result_t Ret = UR_RESULT_SUCCESS;
729+
std::call_once(AdapterReleaseFlag, [&]() {
730+
Ret = urAdapterGet(1, &Adapter, nullptr);
731+
if (Ret == UR_RESULT_SUCCESS) {
732+
Ret = urAdapterRelease(Adapter);
733+
Ret = urAdapterRelease(Adapter);
734+
}
735+
});
736+
HANDLE_ERRORS(Ret);
737+
722738
// TODO: Dont check for errors in urTearDown, since
723739
// when using Level Zero plugin, the second urTearDown
724740
// will fail as ur_loader.so has already been unloaded,
@@ -731,9 +747,20 @@ inline pi_result piTearDown(void *PluginParameter) {
731747
inline pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms,
732748
pi_uint32 *NumPlatforms) {
733749

734-
urInit(0);
750+
urInit(0, nullptr);
751+
// We're not going through the UR loader so we're guaranteed to have exactly
752+
// one adapter (whichever is statically linked). The PI plugin for UR has its
753+
// own implementation of piPlatformsGet.
754+
static ur_adapter_handle_t Adapter;
755+
static std::once_flag AdapterGetFlag;
756+
ur_result_t Ret = UR_RESULT_SUCCESS;
757+
std::call_once(AdapterGetFlag,
758+
[&Ret]() { Ret = urAdapterGet(1, &Adapter, nullptr); });
759+
HANDLE_ERRORS(Ret);
760+
735761
auto phPlatforms = reinterpret_cast<ur_platform_handle_t *>(Platforms);
736-
HANDLE_ERRORS(urPlatformGet(NumEntries, phPlatforms, NumPlatforms));
762+
HANDLE_ERRORS(
763+
urPlatformGet(&Adapter, 1, NumEntries, phPlatforms, NumPlatforms));
737764
return PI_SUCCESS;
738765
}
739766

@@ -894,8 +921,18 @@ inline pi_result piDeviceRelease(pi_device Device) {
894921
return PI_SUCCESS;
895922
}
896923

897-
inline pi_result piPluginGetLastError(char **message) {
898-
std::ignore = message;
924+
inline pi_result piPluginGetLastError(char **Message) {
925+
// We're not going through the UR loader so we're guaranteed to have exactly
926+
// one adapter (whichever is statically linked). The PI plugin for UR has its
927+
// own implementation of piPluginGetLastError. Materialize the adapter
928+
// reference for the urAdapterGetLastError call, then release it.
929+
ur_adapter_handle_t Adapter;
930+
urAdapterGet(1, &Adapter, nullptr);
931+
int32_t ErrorCode;
932+
urAdapterGetLastError(Adapter, const_cast<const char **>(Message),
933+
&ErrorCode);
934+
urAdapterRelease(Adapter);
935+
899936
return PI_SUCCESS;
900937
}
901938

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,34 @@ static void DieUnsupported() {
1717
die("Unified Runtime: functionality is not supported");
1818
}
1919

20+
// Adapters may be released by piTearDown being called, or the global dtors
21+
// being called first. Handle releasing the adapters exactly once.
22+
static void releaseAdapters(std::vector<ur_adapter_handle_t> &Vec) {
23+
static std::once_flag ReleaseFlag{};
24+
std::call_once(ReleaseFlag, [&]() {
25+
for (auto Adapter : Vec) {
26+
urAdapterRelease(Adapter);
27+
}
28+
urTearDown(nullptr);
29+
});
30+
}
31+
32+
struct AdapterHolder {
33+
~AdapterHolder() { releaseAdapters(Vec); }
34+
std::vector<ur_adapter_handle_t> Vec{};
35+
} Adapters;
36+
2037
// All PI API interfaces are C interfaces
2138
extern "C" {
2239
__SYCL_EXPORT pi_result piPlatformsGet(pi_uint32 NumEntries,
2340
pi_platform *Platforms,
2441
pi_uint32 *NumPlatforms) {
25-
return pi2ur::piPlatformsGet(NumEntries, Platforms, NumPlatforms);
42+
// Get all the platforms from all available adapters
43+
urPlatformGet(Adapters.Vec.data(), static_cast<uint32_t>(Adapters.Vec.size()),
44+
NumEntries, reinterpret_cast<ur_platform_handle_t *>(Platforms),
45+
NumPlatforms);
46+
47+
return PI_SUCCESS;
2648
}
2749

2850
__SYCL_EXPORT pi_result piPlatformGetInfo(pi_platform Platform,
@@ -1122,6 +1144,12 @@ __SYCL_EXPORT pi_result piextPeerAccessGetInfo(
11221144
ParamValueSizeRet);
11231145
}
11241146

1147+
__SYCL_EXPORT pi_result piTearDown(void *) {
1148+
releaseAdapters(Adapters.Vec);
1149+
urTearDown(nullptr);
1150+
return PI_SUCCESS;
1151+
}
1152+
11251153
__SYCL_EXPORT pi_result piextMemImageAllocate(pi_context Context,
11261154
pi_device Device,
11271155
pi_image_format *ImageFormat,
@@ -1256,11 +1284,6 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
12561284
Queue, SemHandle, NumEventsInWaitList, EventWaitList, Event);
12571285
}
12581286

1259-
// This interface is not in Unified Runtime currently
1260-
__SYCL_EXPORT pi_result piTearDown(void *PluginParameter) {
1261-
return pi2ur::piTearDown(PluginParameter);
1262-
}
1263-
12641287
// This interface is not in Unified Runtime currently
12651288
__SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
12661289
PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE);
@@ -1279,6 +1302,15 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
12791302

12801303
strncpy(PluginInit->PluginVersion, SupportedVersion, PluginVersionSize);
12811304

1305+
// Initialize UR and discover adapters
1306+
HANDLE_ERRORS(urInit(0, nullptr));
1307+
uint32_t NumAdapters;
1308+
HANDLE_ERRORS(urAdapterGet(0, nullptr, &NumAdapters));
1309+
if (NumAdapters > 0) {
1310+
Adapters.Vec.resize(NumAdapters);
1311+
HANDLE_ERRORS(urAdapterGet(NumAdapters, Adapters.Vec.data(), nullptr));
1312+
}
1313+
12821314
// Bind interfaces that are already supported and "die" for unsupported ones
12831315
#define _PI_API(api) \
12841316
(PluginInit->PiFunctionTable).api = (decltype(&::api))(&DieUnsupported);
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===--------- adapter.cpp - CUDA Adapter ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------===//
8+
9+
#include <ur_api.h>
10+
11+
#include "common.hpp"
12+
13+
void enableCUDATracing();
14+
void disableCUDATracing();
15+
16+
struct ur_adapter_handle_t_ {
17+
std::atomic<uint32_t> RefCount = 0;
18+
std::mutex Mutex;
19+
};
20+
21+
ur_adapter_handle_t_ adapter{};
22+
23+
UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t,
24+
ur_loader_config_handle_t) {
25+
return UR_RESULT_SUCCESS;
26+
}
27+
28+
UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) {
29+
return UR_RESULT_SUCCESS;
30+
}
31+
32+
UR_APIEXPORT ur_result_t UR_APICALL
33+
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
34+
uint32_t *pNumAdapters) {
35+
if (NumEntries > 0 && phAdapters) {
36+
std::lock_guard<std::mutex> Lock{adapter.Mutex};
37+
if (adapter.RefCount++ == 0) {
38+
enableCUDATracing();
39+
}
40+
41+
*phAdapters = &adapter;
42+
}
43+
44+
if (pNumAdapters) {
45+
*pNumAdapters = 1;
46+
}
47+
48+
return UR_RESULT_SUCCESS;
49+
}
50+
51+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
52+
adapter.RefCount++;
53+
54+
return UR_RESULT_SUCCESS;
55+
}
56+
57+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
58+
std::lock_guard<std::mutex> Lock{adapter.Mutex};
59+
if (--adapter.RefCount == 0) {
60+
disableCUDATracing();
61+
}
62+
return UR_RESULT_SUCCESS;
63+
}
64+
65+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
66+
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
67+
*ppMessage = ErrorMessage;
68+
*pError = ErrorMessageCode;
69+
return UR_RESULT_SUCCESS;
70+
}
71+
72+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
73+
ur_adapter_info_t propName,
74+
size_t propSize,
75+
void *pPropValue,
76+
size_t *pPropSizeRet) {
77+
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
78+
79+
switch (propName) {
80+
case UR_ADAPTER_INFO_BACKEND:
81+
return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
82+
case UR_ADAPTER_INFO_REFERENCE_COUNT:
83+
return ReturnValue(adapter.RefCount.load());
84+
default:
85+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
86+
}
87+
88+
return UR_RESULT_SUCCESS;
89+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//===--------- adapter.cpp - CUDA Adapter ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------===//
8+
9+
struct ur_adapter_handle_t_;
10+
11+
extern ur_adapter_handle_t_ adapter;

sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,3 @@ void setPluginSpecificMessage(CUresult cu_res) {
134134
setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
135135
free(message);
136136
}
137-
138-
// Returns plugin specific error and warning messages; common implementation
139-
// that can be shared between adapters
140-
ur_result_t urGetLastResult(ur_platform_handle_t, const char **ppMessage) {
141-
*ppMessage = &ErrorMessage[0];
142-
return ErrorMessageCode;
143-
}

0 commit comments

Comments
 (0)