Skip to content

Commit 55f3e87

Browse files
JackAKirknpmillerJackAKirksteffenlarsen
authored andcommitted
[SYCL][CUDA] Implement sycl_ext_oneapi_peer_access extension (#8303)
This implements the current extension doc from intel/llvm#6104 in the CUDA backend only. Fixes intel/llvm#7543. Fixes intel/llvm#6749. --------- Signed-off-by: JackAKirk <[email protected]> Co-authored-by: Nicolas Miller <[email protected]> Co-authored-by: JackAKirk <[email protected]> Co-authored-by: Steffen Larsen <[email protected]>
1 parent 0cdc873 commit 55f3e87

File tree

4 files changed

+98
-0
lines changed

4 files changed

+98
-0
lines changed

common.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ thread_local char ErrorMessage[MaxMessageSize];
9898
ErrorMessageCode = ErrorCode;
9999
}
100100

101+
void setPluginSpecificMessage(CUresult cu_res) {
102+
const char *error_string;
103+
const char *error_name;
104+
cuGetErrorName(cu_res, &error_name);
105+
cuGetErrorString(cu_res, &error_string);
106+
char *message = (char *)malloc(strlen(error_string) + strlen(error_name) + 2);
107+
strcpy(message, error_name);
108+
strcat(message, "\n");
109+
strcat(message, error_string);
110+
111+
setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
112+
free(message);
113+
}
114+
101115
// Returns plugin specific error and warning messages; common implementation
102116
// that can be shared between adapters
103117
ur_result_t urGetLastResult(ur_platform_handle_t, const char **ppMessage) {

common.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ extern thread_local char ErrorMessage[MaxMessageSize];
3636
[[maybe_unused]] void setErrorMessage(const char *pMessage,
3737
ur_result_t ErrorCode);
3838

39+
void setPluginSpecificMessage(CUresult cu_res);
40+
3941
/// ------ Error handling, matching OpenCL plugin semantics.
4042
namespace sycl {
4143
__SYCL_INLINE_VER_NAMESPACE(_V1) {

ur_interface_loader.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,19 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
288288
return retVal;
289289
}
290290

291+
UR_DLLEXPORT ur_result_t UR_APICALL urGetUsmP2PExpProcAddrTable(
292+
ur_api_version_t version, ur_usm_p2p_exp_dditable_t *pDdiTable) {
293+
auto retVal = validateProcInputs(version, pDdiTable);
294+
if (UR_RESULT_SUCCESS != retVal) {
295+
return retVal;
296+
}
297+
pDdiTable->pfnEnablePeerAccessExp = urUsmP2PEnablePeerAccessExp;
298+
pDdiTable->pfnDisablePeerAccessExp = urUsmP2PDisablePeerAccessExp;
299+
pDdiTable->pfnPeerAccessGetInfoExp = urUsmP2PPeerAccessGetInfoExp;
300+
301+
return retVal;
302+
}
303+
291304
#if defined(__cplusplus)
292305
} // extern "C"
293306
#endif

usm_p2p.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===--------- usm_p2p.cpp - CUDA Adapter---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------===//
8+
9+
#include "common.hpp"
10+
#include "context.hpp"
11+
12+
UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PEnablePeerAccessExp(
13+
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {
14+
15+
ur_result_t result = UR_RESULT_SUCCESS;
16+
try {
17+
ScopedContext active(commandDevice->getContext());
18+
UR_CHECK_ERROR(cuCtxEnablePeerAccess(peerDevice->getContext(), 0));
19+
} catch (ur_result_t err) {
20+
result = err;
21+
}
22+
return result;
23+
}
24+
25+
UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PDisablePeerAccessExp(
26+
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice) {
27+
28+
ur_result_t result = UR_RESULT_SUCCESS;
29+
try {
30+
ScopedContext active(commandDevice->getContext());
31+
UR_CHECK_ERROR(cuCtxDisablePeerAccess(peerDevice->getContext()));
32+
} catch (ur_result_t err) {
33+
result = err;
34+
}
35+
return result;
36+
}
37+
38+
UR_APIEXPORT ur_result_t UR_APICALL urUsmP2PPeerAccessGetInfoExp(
39+
ur_device_handle_t commandDevice, ur_device_handle_t peerDevice,
40+
ur_exp_peer_info_t propName, size_t propSize, void *pPropValue,
41+
size_t *pPropSizeRet) {
42+
43+
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
44+
45+
int value;
46+
CUdevice_P2PAttribute cu_attr;
47+
try {
48+
ScopedContext active(commandDevice->getContext());
49+
switch (propName) {
50+
case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORTED: {
51+
cu_attr = CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED;
52+
break;
53+
}
54+
case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORTED: {
55+
cu_attr = CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED;
56+
break;
57+
}
58+
default: {
59+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
60+
}
61+
}
62+
63+
UR_CHECK_ERROR(cuDeviceGetP2PAttribute(
64+
&value, cu_attr, commandDevice->get(), peerDevice->get()));
65+
} catch (ur_result_t err) {
66+
return err;
67+
}
68+
return ReturnValue(value);
69+
}

0 commit comments

Comments
 (0)