Skip to content

Commit 999fcbb

Browse files
author
Dan Holmes
committed
Add urDeviceGetSelected (loader only, no implementation yet)
1 parent a7f5097 commit 999fcbb

File tree

11 files changed

+399
-0
lines changed

11 files changed

+399
-0
lines changed

include/ur.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ class ur_function_v(IntEnum):
196196
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
197197
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
198198
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
199+
DEVICE_GET_SELECTED = 182 ## Enumerator for ::urDeviceGetSelected
199200

200201
class ur_function_t(c_int):
201202
def __str__(self):

include/ur_api.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ typedef enum ur_function_t {
205205
UR_FUNCTION_ADAPTER_RETAIN = 179, ///< Enumerator for ::urAdapterRetain
206206
UR_FUNCTION_ADAPTER_GET_LAST_ERROR = 180, ///< Enumerator for ::urAdapterGetLastError
207207
UR_FUNCTION_ADAPTER_GET_INFO = 181, ///< Enumerator for ::urAdapterGetInfo
208+
UR_FUNCTION_DEVICE_GET_SELECTED = 182, ///< Enumerator for ::urDeviceGetSelected
208209
/// @cond
209210
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
210211
/// @endcond
@@ -1302,6 +1303,46 @@ urDeviceGet(
13021303
///< pNumDevices will be updated with the total number of devices available.
13031304
);
13041305

1306+
///////////////////////////////////////////////////////////////////////////////
1307+
/// @brief Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR
1308+
///
1309+
/// @details
1310+
/// - Multiple calls to this function will return identical device handles,
1311+
/// in the same order.
1312+
/// - The number and order of handles returned from this function will be
1313+
/// affected by environment variables that filter or select which devices
1314+
/// are exposed through this API.
1315+
/// - A reference is taken for each returned device and must be released
1316+
/// with a subsequent call to ::urDeviceRelease.
1317+
/// - The application may call this function from simultaneous threads, the
1318+
/// implementation must be thread-safe.
1319+
///
1320+
/// @returns
1321+
/// - ::UR_RESULT_SUCCESS
1322+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
1323+
/// - ::UR_RESULT_ERROR_DEVICE_LOST
1324+
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
1325+
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
1326+
/// + `NULL == hPlatform`
1327+
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
1328+
/// + `::UR_DEVICE_TYPE_VPU < DeviceType`
1329+
/// - ::UR_RESULT_ERROR_INVALID_VALUE
1330+
UR_APIEXPORT ur_result_t UR_APICALL
1331+
urDeviceGetSelected(
1332+
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
1333+
ur_device_type_t DeviceType, ///< [in] the type of the devices.
1334+
uint32_t NumEntries, ///< [in] the number of devices to be added to phDevices.
1335+
///< If phDevices in not NULL then NumEntries should be greater than zero,
1336+
///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE,
1337+
///< will be returned.
1338+
ur_device_handle_t *phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices.
1339+
///< If NumEntries is less than the number of devices available, then only
1340+
///< that number of devices will be retrieved.
1341+
uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices.
1342+
///< pNumDevices will be updated with the total number of selected devices
1343+
///< available for the given platform.
1344+
);
1345+
13051346
///////////////////////////////////////////////////////////////////////////////
13061347
/// @brief Supported device info
13071348
typedef enum ur_device_info_t {
@@ -10179,6 +10220,18 @@ typedef struct ur_device_get_params_t {
1017910220
uint32_t **ppNumDevices;
1018010221
} ur_device_get_params_t;
1018110222

10223+
///////////////////////////////////////////////////////////////////////////////
10224+
/// @brief Function parameters for urDeviceGetSelected
10225+
/// @details Each entry is a pointer to the parameter passed to the function;
10226+
/// allowing the callback the ability to modify the parameter's value
10227+
typedef struct ur_device_get_selected_params_t {
10228+
ur_platform_handle_t *phPlatform;
10229+
ur_device_type_t *pDeviceType;
10230+
uint32_t *pNumEntries;
10231+
ur_device_handle_t **pphDevices;
10232+
uint32_t **ppNumDevices;
10233+
} ur_device_get_selected_params_t;
10234+
1018210235
///////////////////////////////////////////////////////////////////////////////
1018310236
/// @brief Function parameters for urDeviceGetInfo
1018410237
/// @details Each entry is a pointer to the parameter passed to the function;

scripts/core/device.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,45 @@ params:
141141
returns:
142142
- $X_RESULT_ERROR_INVALID_VALUE
143143
--- #--------------------------------------------------------------------------
144+
type: function
145+
desc: "Retrieves devices within a platform selected by ONEAPI_DEVICE_SELECTOR"
146+
class: $xDevice
147+
loader_only: True
148+
name: GetSelected
149+
decl: static
150+
ordinal: "0"
151+
details:
152+
- "Multiple calls to this function will return identical device handles, in the same order."
153+
- "The number and order of handles returned from this function will be affected by environment variables that filter or select which devices are exposed through this API."
154+
- "A reference is taken for each returned device and must be released with a subsequent call to $xDeviceRelease."
155+
- "The application may call this function from simultaneous threads, the implementation must be thread-safe."
156+
params:
157+
- type: $x_platform_handle_t
158+
name: hPlatform
159+
desc: "[in] handle of the platform instance"
160+
- type: "$x_device_type_t"
161+
name: DeviceType
162+
desc: |
163+
[in] the type of the devices.
164+
- type: "uint32_t"
165+
name: NumEntries
166+
desc: |
167+
[in] the number of devices to be added to phDevices.
168+
If phDevices in not NULL then NumEntries should be greater than zero, otherwise $X_RESULT_ERROR_INVALID_VALUE,
169+
will be returned.
170+
- type: "$x_device_handle_t*"
171+
name: phDevices
172+
desc: |
173+
[out][optional][range(0, NumEntries)] array of handle of devices.
174+
If NumEntries is less than the number of devices available, then only that number of devices will be retrieved.
175+
- type: "uint32_t*"
176+
name: pNumDevices
177+
desc: |
178+
[out][optional] pointer to the number of devices.
179+
pNumDevices will be updated with the total number of selected devices available for the given platform.
180+
returns:
181+
- $X_RESULT_ERROR_INVALID_VALUE
182+
--- #--------------------------------------------------------------------------
144183
type: enum
145184
desc: "Supported device info"
146185
class: $xDevice

scripts/core/registry.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ etors:
529529
- name: ADAPTER_GET_INFO
530530
desc: Enumerator for $xAdapterGetInfo
531531
value: '181'
532+
- name: DEVICE_GET_SELECTED
533+
desc: Enumerator for $xDeviceGetSelected
534+
value: '182'
532535
---
533536
type: enum
534537
desc: Defines structure types

source/adapters/null/ur_nullddi.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,44 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet(
393393
return exceptionToResult(std::current_exception());
394394
}
395395

396+
///////////////////////////////////////////////////////////////////////////////
397+
/// @brief Intercept function for urDeviceGetSelected
398+
__urdlllocal ur_result_t UR_APICALL urDeviceGetSelected(
399+
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
400+
ur_device_type_t DeviceType, ///< [in] the type of the devices.
401+
uint32_t
402+
NumEntries, ///< [in] the number of devices to be added to phDevices.
403+
///< If phDevices in not NULL then NumEntries should be greater than zero,
404+
///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE,
405+
///< will be returned.
406+
ur_device_handle_t *
407+
phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices.
408+
///< If NumEntries is less than the number of devices available, then only
409+
///< that number of devices will be retrieved.
410+
uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices.
411+
///< pNumDevices will be updated with the total number of selected devices
412+
///< available for the given platform.
413+
) try {
414+
ur_result_t result = UR_RESULT_SUCCESS;
415+
416+
// if the driver has created a custom function, then call it instead of using the generic path
417+
auto pfnGetSelected = d_context.urDdiTable.Device.pfnGetSelected;
418+
if (nullptr != pfnGetSelected) {
419+
result = pfnGetSelected(hPlatform, DeviceType, NumEntries, phDevices,
420+
pNumDevices);
421+
} else {
422+
// generic implementation
423+
for (size_t i = 0; (nullptr != phDevices) && (i < NumEntries); ++i) {
424+
phDevices[i] =
425+
reinterpret_cast<ur_device_handle_t>(d_context.get());
426+
}
427+
}
428+
429+
return result;
430+
} catch (...) {
431+
return exceptionToResult(std::current_exception());
432+
}
433+
396434
///////////////////////////////////////////////////////////////////////////////
397435
/// @brief Intercept function for urDeviceGetInfo
398436
__urdlllocal ur_result_t UR_APICALL urDeviceGetInfo(

source/common/ur_params.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,10 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) {
11121112
case UR_FUNCTION_ADAPTER_GET_INFO:
11131113
os << "UR_FUNCTION_ADAPTER_GET_INFO";
11141114
break;
1115+
1116+
case UR_FUNCTION_DEVICE_GET_SELECTED:
1117+
os << "UR_FUNCTION_DEVICE_GET_SELECTED";
1118+
break;
11151119
default:
11161120
os << "unknown enumerator";
11171121
break;
@@ -14951,6 +14955,44 @@ inline std::ostream &operator<<(std::ostream &os,
1495114955
return os;
1495214956
}
1495314957

14958+
inline std::ostream &
14959+
operator<<(std::ostream &os,
14960+
const struct ur_device_get_selected_params_t *params) {
14961+
14962+
os << ".hPlatform = ";
14963+
14964+
ur_params::serializePtr(os, *(params->phPlatform));
14965+
14966+
os << ", ";
14967+
os << ".DeviceType = ";
14968+
14969+
os << *(params->pDeviceType);
14970+
14971+
os << ", ";
14972+
os << ".NumEntries = ";
14973+
14974+
os << *(params->pNumEntries);
14975+
14976+
os << ", ";
14977+
os << ".phDevices = {";
14978+
for (size_t i = 0;
14979+
*(params->pphDevices) != NULL && i < *params->pNumEntries; ++i) {
14980+
if (i != 0) {
14981+
os << ", ";
14982+
}
14983+
14984+
ur_params::serializePtr(os, (*(params->pphDevices))[i]);
14985+
}
14986+
os << "}";
14987+
14988+
os << ", ";
14989+
os << ".pNumDevices = ";
14990+
14991+
ur_params::serializePtr(os, *(params->ppNumDevices));
14992+
14993+
return os;
14994+
}
14995+
1495414996
inline std::ostream &
1495514997
operator<<(std::ostream &os, const struct ur_device_get_info_params_t *params) {
1495614998

@@ -15688,6 +15730,9 @@ inline int serializeFunctionParams(std::ostream &os, uint32_t function,
1568815730
case UR_FUNCTION_DEVICE_GET: {
1568915731
os << (const struct ur_device_get_params_t *)params;
1569015732
} break;
15733+
case UR_FUNCTION_DEVICE_GET_SELECTED: {
15734+
os << (const struct ur_device_get_selected_params_t *)params;
15735+
} break;
1569115736
case UR_FUNCTION_DEVICE_GET_INFO: {
1569215737
os << (const struct ur_device_get_info_params_t *)params;
1569315738
} break;

source/loader/layers/tracing/ur_trcddi.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,44 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet(
435435
return result;
436436
}
437437

438+
///////////////////////////////////////////////////////////////////////////////
439+
/// @brief Intercept function for urDeviceGetSelected
440+
__urdlllocal ur_result_t UR_APICALL urDeviceGetSelected(
441+
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
442+
ur_device_type_t DeviceType, ///< [in] the type of the devices.
443+
uint32_t
444+
NumEntries, ///< [in] the number of devices to be added to phDevices.
445+
///< If phDevices in not NULL then NumEntries should be greater than zero,
446+
///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE,
447+
///< will be returned.
448+
ur_device_handle_t *
449+
phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices.
450+
///< If NumEntries is less than the number of devices available, then only
451+
///< that number of devices will be retrieved.
452+
uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices.
453+
///< pNumDevices will be updated with the total number of selected devices
454+
///< available for the given platform.
455+
) {
456+
auto pfnGetSelected = context.urDdiTable.Device.pfnGetSelected;
457+
458+
if (nullptr == pfnGetSelected) {
459+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
460+
}
461+
462+
ur_device_get_selected_params_t params = {
463+
&hPlatform, &DeviceType, &NumEntries, &phDevices, &pNumDevices};
464+
uint64_t instance = context.notify_begin(UR_FUNCTION_DEVICE_GET_SELECTED,
465+
"urDeviceGetSelected", &params);
466+
467+
ur_result_t result = pfnGetSelected(hPlatform, DeviceType, NumEntries,
468+
phDevices, pNumDevices);
469+
470+
context.notify_end(UR_FUNCTION_DEVICE_GET_SELECTED, "urDeviceGetSelected",
471+
&params, &result, instance);
472+
473+
return result;
474+
}
475+
438476
///////////////////////////////////////////////////////////////////////////////
439477
/// @brief Intercept function for urDeviceGetInfo
440478
__urdlllocal ur_result_t UR_APICALL urDeviceGetInfo(

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,46 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet(
477477
return result;
478478
}
479479

480+
///////////////////////////////////////////////////////////////////////////////
481+
/// @brief Intercept function for urDeviceGetSelected
482+
__urdlllocal ur_result_t UR_APICALL urDeviceGetSelected(
483+
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
484+
ur_device_type_t DeviceType, ///< [in] the type of the devices.
485+
uint32_t
486+
NumEntries, ///< [in] the number of devices to be added to phDevices.
487+
///< If phDevices in not NULL then NumEntries should be greater than zero,
488+
///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE,
489+
///< will be returned.
490+
ur_device_handle_t *
491+
phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices.
492+
///< If NumEntries is less than the number of devices available, then only
493+
///< that number of devices will be retrieved.
494+
uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices.
495+
///< pNumDevices will be updated with the total number of selected devices
496+
///< available for the given platform.
497+
) {
498+
auto pfnGetSelected = context.urDdiTable.Device.pfnGetSelected;
499+
500+
if (nullptr == pfnGetSelected) {
501+
return UR_RESULT_ERROR_UNINITIALIZED;
502+
}
503+
504+
if (context.enableParameterValidation) {
505+
if (NULL == hPlatform) {
506+
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
507+
}
508+
509+
if (UR_DEVICE_TYPE_VPU < DeviceType) {
510+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
511+
}
512+
}
513+
514+
ur_result_t result = pfnGetSelected(hPlatform, DeviceType, NumEntries,
515+
phDevices, pNumDevices);
516+
517+
return result;
518+
}
519+
480520
///////////////////////////////////////////////////////////////////////////////
481521
/// @brief Intercept function for urDeviceGetInfo
482522
__urdlllocal ur_result_t UR_APICALL urDeviceGetInfo(

source/loader/ur_ldrddi.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,58 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGet(
518518
return result;
519519
}
520520

521+
///////////////////////////////////////////////////////////////////////////////
522+
/// @brief Intercept function for urDeviceGetSelected
523+
__urdlllocal ur_result_t UR_APICALL urDeviceGetSelected(
524+
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
525+
ur_device_type_t DeviceType, ///< [in] the type of the devices.
526+
uint32_t
527+
NumEntries, ///< [in] the number of devices to be added to phDevices.
528+
///< If phDevices in not NULL then NumEntries should be greater than zero,
529+
///< otherwise ::UR_RESULT_ERROR_INVALID_VALUE,
530+
///< will be returned.
531+
ur_device_handle_t *
532+
phDevices, ///< [out][optional][range(0, NumEntries)] array of handle of devices.
533+
///< If NumEntries is less than the number of devices available, then only
534+
///< that number of devices will be retrieved.
535+
uint32_t *pNumDevices ///< [out][optional] pointer to the number of devices.
536+
///< pNumDevices will be updated with the total number of selected devices
537+
///< available for the given platform.
538+
) {
539+
ur_result_t result = UR_RESULT_SUCCESS;
540+
541+
// extract platform's function pointer table
542+
auto dditable =
543+
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
544+
auto pfnGetSelected = dditable->ur.Device.pfnGetSelected;
545+
if (nullptr == pfnGetSelected) {
546+
return UR_RESULT_ERROR_UNINITIALIZED;
547+
}
548+
549+
// convert loader handle to platform handle
550+
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
551+
552+
// forward to device-platform
553+
result = pfnGetSelected(hPlatform, DeviceType, NumEntries, phDevices,
554+
pNumDevices);
555+
556+
if (UR_RESULT_SUCCESS != result) {
557+
return result;
558+
}
559+
560+
try {
561+
// convert platform handles to loader handles
562+
for (size_t i = 0; (nullptr != phDevices) && (i < NumEntries); ++i) {
563+
phDevices[i] = reinterpret_cast<ur_device_handle_t>(
564+
ur_device_factory.getInstance(phDevices[i], dditable));
565+
}
566+
} catch (std::bad_alloc &) {
567+
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
568+
}
569+
570+
return result;
571+
}
572+
521573
///////////////////////////////////////////////////////////////////////////////
522574
/// @brief Intercept function for urDeviceGetInfo
523575
__urdlllocal ur_result_t UR_APICALL urDeviceGetInfo(

0 commit comments

Comments
 (0)