Skip to content

Commit 6530d39

Browse files
callumfarefabiomestre
authored andcommitted
[SYCL][UR] Bump UR and implement adapter handles (#10349)
Bump the Unified Runtime commit, and make adapter changes needed for the newly added adapter handles (see #715 for details) This fixes #10066 by providing an implementation of `piPluginGetLastError` in pi2ur.
1 parent 056f1f3 commit 6530d39

File tree

7 files changed

+117
-25
lines changed

7 files changed

+117
-25
lines changed

adapter.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===--------- adapter.cpp - CUDA Adapter ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------===//
8+
9+
#include <ur_api.h>
10+
11+
#include "common.hpp"
12+
13+
void enableCUDATracing();
14+
void disableCUDATracing();
15+
16+
struct ur_adapter_handle_t_ {
17+
std::atomic<uint32_t> RefCount = 0;
18+
std::mutex Mutex;
19+
};
20+
21+
ur_adapter_handle_t_ adapter{};
22+
23+
UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t,
24+
ur_loader_config_handle_t) {
25+
return UR_RESULT_SUCCESS;
26+
}
27+
28+
UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) {
29+
return UR_RESULT_SUCCESS;
30+
}
31+
32+
UR_APIEXPORT ur_result_t UR_APICALL
33+
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
34+
uint32_t *pNumAdapters) {
35+
if (NumEntries > 0 && phAdapters) {
36+
std::lock_guard<std::mutex> Lock{adapter.Mutex};
37+
if (adapter.RefCount++ == 0) {
38+
enableCUDATracing();
39+
}
40+
41+
*phAdapters = &adapter;
42+
}
43+
44+
if (pNumAdapters) {
45+
*pNumAdapters = 1;
46+
}
47+
48+
return UR_RESULT_SUCCESS;
49+
}
50+
51+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
52+
adapter.RefCount++;
53+
54+
return UR_RESULT_SUCCESS;
55+
}
56+
57+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
58+
std::lock_guard<std::mutex> Lock{adapter.Mutex};
59+
if (--adapter.RefCount == 0) {
60+
disableCUDATracing();
61+
}
62+
return UR_RESULT_SUCCESS;
63+
}
64+
65+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
66+
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
67+
*ppMessage = ErrorMessage;
68+
*pError = ErrorMessageCode;
69+
return UR_RESULT_SUCCESS;
70+
}
71+
72+
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
73+
ur_adapter_info_t propName,
74+
size_t propSize,
75+
void *pPropValue,
76+
size_t *pPropSizeRet) {
77+
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
78+
79+
switch (propName) {
80+
case UR_ADAPTER_INFO_BACKEND:
81+
return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
82+
case UR_ADAPTER_INFO_REFERENCE_COUNT:
83+
return ReturnValue(adapter.RefCount.load());
84+
default:
85+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
86+
}
87+
88+
return UR_RESULT_SUCCESS;
89+
}

adapter.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//===--------- adapter.cpp - CUDA Adapter ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------===//
8+
9+
struct ur_adapter_handle_t_;
10+
11+
extern ur_adapter_handle_t_ adapter;

common.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,3 @@ void setPluginSpecificMessage(CUresult cu_res) {
134134
setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
135135
free(message);
136136
}
137-
138-
// Returns plugin specific error and warning messages; common implementation
139-
// that can be shared between adapters
140-
ur_result_t urGetLastResult(ur_platform_handle_t, const char **ppMessage) {
141-
*ppMessage = &ErrorMessage[0];
142-
return ErrorMessageCode;
143-
}

device.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <cassert>
1010
#include <sstream>
1111

12+
#include "adapter.hpp"
1213
#include "context.hpp"
1314
#include "device.hpp"
1415
#include "platform.hpp"
@@ -1206,13 +1207,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
12061207

12071208
// Get list of platforms
12081209
uint32_t NumPlatforms = 0;
1209-
ur_result_t Result = urPlatformGet(0, nullptr, &NumPlatforms);
1210+
ur_adapter_handle_t AdapterHandle = &adapter;
1211+
ur_result_t Result =
1212+
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
12101213
if (Result != UR_RESULT_SUCCESS)
12111214
return Result;
12121215

12131216
ur_platform_handle_t *Plat = static_cast<ur_platform_handle_t *>(
12141217
malloc(NumPlatforms * sizeof(ur_platform_handle_t)));
1215-
Result = urPlatformGet(NumPlatforms, Plat, nullptr);
1218+
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, Plat, nullptr);
12161219
if (Result != UR_RESULT_SUCCESS)
12171220
return Result;
12181221

platform.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
#include <cuda.h>
1616
#include <sstream>
1717

18-
void enableCUDATracing();
19-
void disableCUDATracing();
20-
2118
UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
2219
ur_platform_handle_t hPlatform, ur_platform_info_t PlatformInfoType,
2320
size_t Size, void *pPlatformInfo, size_t *pSizeRet) {
@@ -57,8 +54,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
5754
/// However because multiple devices in a context is not currently supported,
5855
/// place each device in a separate platform.
5956
UR_APIEXPORT ur_result_t UR_APICALL
60-
urPlatformGet(uint32_t NumEntries, ur_platform_handle_t *phPlatforms,
61-
uint32_t *pNumPlatforms) {
57+
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
58+
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
6259

6360
try {
6461
static std::once_flag InitFlag;
@@ -188,16 +185,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle(
188185
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
189186
}
190187

191-
UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t) {
192-
enableCUDATracing();
193-
return UR_RESULT_SUCCESS;
194-
}
195-
196-
UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) {
197-
disableCUDATracing();
198-
return UR_RESULT_SUCCESS;
199-
}
200-
201188
// Get CUDA plugin specific backend option.
202189
// Current support is only for optimization options.
203190
// Return empty string for cuda.

ur_interface_loader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,12 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable(
202202
}
203203
pDdiTable->pfnInit = urInit;
204204
pDdiTable->pfnTearDown = urTearDown;
205+
pDdiTable->pfnAdapterGet = urAdapterGet;
206+
pDdiTable->pfnAdapterRelease = urAdapterRelease;
207+
pDdiTable->pfnAdapterRetain = urAdapterRetain;
208+
pDdiTable->pfnAdapterGetLastError = urAdapterGetLastError;
209+
pDdiTable->pfnAdapterGetInfo = urAdapterGetInfo;
210+
205211
return UR_RESULT_SUCCESS;
206212
}
207213

usm.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <cassert>
1010

11+
#include "adapter.hpp"
1112
#include "common.hpp"
1213
#include "context.hpp"
1314
#include "device.hpp"
@@ -204,7 +205,9 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
204205
// the same index
205206
std::vector<ur_platform_handle_t> Platforms;
206207
Platforms.resize(DeviceIndex + 1);
207-
Result = urPlatformGet(DeviceIndex + 1, Platforms.data(), nullptr);
208+
ur_adapter_handle_t AdapterHandle = &adapter;
209+
Result = urPlatformGet(&AdapterHandle, 1, DeviceIndex + 1,
210+
Platforms.data(), nullptr);
208211

209212
// get the device from the platform
210213
ur_device_handle_t Device = Platforms[DeviceIndex]->Devices[0].get();

0 commit comments

Comments
 (0)