Skip to content

Commit ba42634

Browse files
committed
Skip adapters with no platforms in urAdapterGet
1 parent 30fa2d8 commit ba42634

File tree

6 files changed

+111
-137
lines changed

6 files changed

+111
-137
lines changed

unified-runtime/scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,21 @@ namespace ur_loader
5151
%if func_basename == "AdapterGet":
5252
auto context = getContext();
5353

54-
size_t adapterIndex = 0;
55-
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
56-
{
57-
for( auto& platform : context->platforms )
58-
{
59-
if(platform.initStatus != ${X}_RESULT_SUCCESS)
60-
continue;
61-
platform.dditable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, &${obj['params'][1]['name']}[adapterIndex], nullptr );
62-
adapterIndex++;
63-
if (adapterIndex == NumEntries) {
64-
break;
65-
}
66-
}
54+
uint32_t numAdapters = 0;
55+
for (auto &platform : context->platforms) {
56+
if (platform.initStatus != ${X}_RESULT_SUCCESS)
57+
continue;
58+
59+
uint32_t adapter;
60+
ur_adapter_handle_t *adapterHandle = numAdapters < NumEntries ? &${obj['params'][1]['name']}[numAdapters] : nullptr;
61+
platform.dditable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, adapterHandle, &adapter );
62+
63+
numAdapters += adapter;
6764
}
6865

6966
if( ${obj['params'][2]['name']} != nullptr )
7067
{
71-
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
68+
*${obj['params'][2]['name']} = numAdapters;
7269
}
7370

7471
return ${X}_RESULT_SUCCESS;

unified-runtime/source/adapters/level_zero/adapter.cpp

Lines changed: 81 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,21 @@ class ur_legacy_sink : public logger::Sink {
5454
};
5555

5656
// Find the corresponding ZesDevice Handle for a given ZeDevice
57-
ur_result_t getZesDeviceHandle(zes_uuid_t coreDeviceUuid,
57+
ur_result_t getZesDeviceHandle(ur_adapter_handle_t_ *adapter,
58+
zes_uuid_t coreDeviceUuid,
5859
zes_device_handle_t *ZesDevice,
5960
uint32_t *SubDeviceId, ze_bool_t *SubDevice) {
6061
uint32_t ZesDriverCount = 0;
6162
std::vector<zes_driver_handle_t> ZesDrivers;
6263
std::vector<zes_device_handle_t> ZesDevices;
6364
ze_result_t ZesResult = ZE_RESULT_ERROR_INVALID_ARGUMENT;
64-
ZE2UR_CALL(GlobalAdapter->getSysManDriversFunctionPtr,
65-
(&ZesDriverCount, nullptr));
65+
ZE2UR_CALL(adapter->getSysManDriversFunctionPtr, (&ZesDriverCount, nullptr));
6666
ZesDrivers.resize(ZesDriverCount);
67-
ZE2UR_CALL(GlobalAdapter->getSysManDriversFunctionPtr,
67+
ZE2UR_CALL(adapter->getSysManDriversFunctionPtr,
6868
(&ZesDriverCount, ZesDrivers.data()));
6969
for (uint32_t I = 0; I < ZesDriverCount; ++I) {
7070
ZesResult = ZE_CALL_NOCHECK(
71-
GlobalAdapter->getDeviceByUUIdFunctionPtr,
71+
adapter->getDeviceByUUIdFunctionPtr,
7272
(ZesDrivers[I], coreDeviceUuid, ZesDevice, SubDevice, SubDeviceId));
7373
if (ZesResult == ZE_RESULT_SUCCESS) {
7474
return UR_RESULT_SUCCESS;
@@ -147,7 +147,7 @@ ur_result_t checkDeviceIntelGPUIpVersionOrNewer(uint32_t ipVersion) {
147147
* for the devices into the platform.
148148
* 10. The function handles exceptions and returns the appropriate result.
149149
*/
150-
ur_result_t initPlatforms(PlatformVec &platforms,
150+
ur_result_t initPlatforms(ur_adapter_handle_t_ *adapter, PlatformVec &platforms,
151151
ze_result_t ZesResult) noexcept try {
152152
std::vector<ze_driver_handle_t> ZeDrivers;
153153
std::vector<ze_driver_handle_t> ZeDriverGetHandles;
@@ -162,20 +162,20 @@ ur_result_t initPlatforms(PlatformVec &platforms,
162162
ZeDriverGetHandles.resize(ZeDriverGetCount);
163163
ZE2UR_CALL(zeDriverGet, (&ZeDriverGetCount, ZeDriverGetHandles.data()));
164164
}
165-
if (ZeDriverGetCount == 0 && GlobalAdapter->ZeInitDriversCount == 0) {
165+
if (ZeDriverGetCount == 0 && adapter->ZeInitDriversCount == 0) {
166166
UR_LOG(ERR, "\nNo Valid L0 Drivers found.\n");
167167
return UR_RESULT_SUCCESS;
168168
}
169169

170-
if (GlobalAdapter->InitDriversSupported) {
171-
ZeInitDriversHandles.resize(GlobalAdapter->ZeInitDriversCount);
172-
ZeDrivers.resize(GlobalAdapter->ZeInitDriversCount);
173-
ZE2UR_CALL(GlobalAdapter->initDriversFunctionPtr,
174-
(&GlobalAdapter->ZeInitDriversCount, ZeInitDriversHandles.data(),
175-
&GlobalAdapter->InitDriversDesc));
170+
if (adapter->InitDriversSupported) {
171+
ZeInitDriversHandles.resize(adapter->ZeInitDriversCount);
172+
ZeDrivers.resize(adapter->ZeInitDriversCount);
173+
ZE2UR_CALL(adapter->initDriversFunctionPtr,
174+
(&adapter->ZeInitDriversCount, ZeInitDriversHandles.data(),
175+
&adapter->InitDriversDesc));
176176
ZeDrivers.assign(ZeInitDriversHandles.begin(), ZeInitDriversHandles.end());
177-
if (ZeDriverGetCount > 0 && GlobalAdapter->ZeInitDriversCount > 0) {
178-
for (uint32_t X = 0; X < GlobalAdapter->ZeInitDriversCount; ++X) {
177+
if (ZeDriverGetCount > 0 && adapter->ZeInitDriversCount > 0) {
178+
for (uint32_t X = 0; X < adapter->ZeInitDriversCount; ++X) {
179179
for (uint32_t Y = 0; Y < ZeDriverGetCount; ++Y) {
180180
ZeStruct<ze_driver_properties_t> ZeDriverGetProperties;
181181
ZeStruct<ze_driver_properties_t> ZeInitDriverProperties;
@@ -234,9 +234,10 @@ ur_result_t initPlatforms(PlatformVec &platforms,
234234
ur_zes_device_handle_data_t ZesDeviceData;
235235
zes_uuid_t ZesUUID;
236236
std::memcpy(&ZesUUID, &device_properties.uuid, sizeof(zes_uuid_t));
237-
if (getZesDeviceHandle(
238-
ZesUUID, &ZesDeviceData.ZesDevice, &ZesDeviceData.SubDeviceId,
239-
&ZesDeviceData.SubDevice) == UR_RESULT_SUCCESS) {
237+
if (getZesDeviceHandle(adapter, ZesUUID, &ZesDeviceData.ZesDevice,
238+
&ZesDeviceData.SubDeviceId,
239+
&ZesDeviceData.SubDevice) ==
240+
UR_RESULT_SUCCESS) {
240241
platforms.back()->ZedeviceToZesDeviceMap.insert(
241242
std::make_pair(ZeDevices[D], std::move(ZesDeviceData)));
242243
}
@@ -342,9 +343,9 @@ Behavior Summary:
342343
*/
343344
ur_adapter_handle_t_::ur_adapter_handle_t_()
344345
: handle_base(), logger(logger::get_logger("level_zero")), RefCount(0) {
345-
ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED;
346-
ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
347-
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
346+
auto ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED;
347+
auto ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
348+
auto ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
348349

349350
#ifdef UR_STATIC_LEVEL_ZERO
350351
// Given static linking of the L0 Loader, we must delay the loader's
@@ -370,7 +371,6 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
370371
setEnvVar("ZEL_ENABLE_BASIC_LEAK_CHECKER", "1");
371372
}
372373

373-
PlatformCache.Compute = [](Result<PlatformVec> &result) {
374374
uint32_t UserForcedSysManInit = 0;
375375
// Check if the user has disabled the default L0 Env initialization.
376376
const int UrSysManEnvInitEnabled = [&UserForcedSysManInit] {
@@ -386,14 +386,12 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
386386
// not exist in older loader runtimes.
387387
#ifndef UR_STATIC_LEVEL_ZERO
388388
#ifdef _WIN32
389-
GlobalAdapter->processHandle = GetModuleHandle(NULL);
389+
processHandle = GetModuleHandle(NULL);
390390
#else
391-
GlobalAdapter->processHandle = nullptr;
391+
processHandle = nullptr;
392392
#endif
393393
#endif
394394

395-
// initialize level zero only once.
396-
if (GlobalAdapter->ZeResult == std::nullopt) {
397395
// Setting these environment variables before running zeInit will enable
398396
// the validation layer in the Level Zero loader.
399397
if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
@@ -427,10 +425,10 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
427425
}
428426
UR_LOG(DEBUG, "\nzeInit with flags value of {}\n",
429427
static_cast<int>(L0InitFlags));
430-
GlobalAdapter->ZeInitResult = ZE_CALL_NOCHECK(zeInit, (L0InitFlags));
431-
if (GlobalAdapter->ZeInitResult != ZE_RESULT_SUCCESS) {
428+
ZeInitResult = ZE_CALL_NOCHECK(zeInit, (L0InitFlags));
429+
if (ZeInitResult != ZE_RESULT_SUCCESS) {
432430
const char *ErrorString = "Unknown";
433-
zeParseError(GlobalAdapter->ZeInitResult, ErrorString);
431+
zeParseError(ZeInitResult, ErrorString);
434432
UR_LOG(ERR, "\nzeInit failed with {}\n", ErrorString);
435433
}
436434

@@ -474,65 +472,47 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
474472

475473
if (useInitDrivers) {
476474
#ifdef UR_STATIC_LEVEL_ZERO
477-
GlobalAdapter->initDriversFunctionPtr = zeInitDrivers;
475+
initDriversFunctionPtr = zeInitDrivers;
478476
#else
479-
GlobalAdapter->initDriversFunctionPtr =
477+
initDriversFunctionPtr =
480478
(ze_pfnInitDrivers_t)ur_loader::LibLoader::getFunctionPtr(
481-
GlobalAdapter->processHandle, "zeInitDrivers");
479+
processHandle, "zeInitDrivers");
482480
#endif
483-
if (GlobalAdapter->initDriversFunctionPtr) {
481+
if (initDriversFunctionPtr) {
484482
UR_LOG(DEBUG, "\nzeInitDrivers with flags value of {}\n",
485-
static_cast<int>(GlobalAdapter->InitDriversDesc.flags));
486-
GlobalAdapter->ZeInitDriversResult =
487-
ZE_CALL_NOCHECK(GlobalAdapter->initDriversFunctionPtr,
488-
(&GlobalAdapter->ZeInitDriversCount, nullptr,
489-
&GlobalAdapter->InitDriversDesc));
490-
if (GlobalAdapter->ZeInitDriversResult == ZE_RESULT_SUCCESS) {
491-
GlobalAdapter->InitDriversSupported = true;
483+
static_cast<int>(InitDriversDesc.flags));
484+
ZeInitDriversResult =
485+
ZE_CALL_NOCHECK(initDriversFunctionPtr,
486+
(&ZeInitDriversCount, nullptr, &InitDriversDesc));
487+
if (ZeInitDriversResult == ZE_RESULT_SUCCESS) {
488+
InitDriversSupported = true;
492489
} else {
493490
const char *ErrorString = "Unknown";
494-
zeParseError(GlobalAdapter->ZeInitDriversResult, ErrorString);
491+
zeParseError(ZeInitDriversResult, ErrorString);
495492
UR_LOG(ERR, "\nzeInitDrivers failed with {}\n", ErrorString);
496493
}
497494
}
498495
}
499496

500-
if (GlobalAdapter->ZeInitResult == ZE_RESULT_SUCCESS ||
501-
GlobalAdapter->ZeInitDriversResult == ZE_RESULT_SUCCESS) {
502-
GlobalAdapter->ZeResult = ZE_RESULT_SUCCESS;
503-
} else {
504-
GlobalAdapter->ZeResult = ZE_RESULT_ERROR_UNINITIALIZED;
497+
if (ZeInitResult != ZE_RESULT_SUCCESS &&
498+
ZeInitDriversResult != ZE_RESULT_SUCCESS) {
499+
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
500+
UR_LOG(ERR, "Level Zero Uninitialized\n");
501+
return;
505502
}
506-
}
507-
assert(GlobalAdapter->ZeResult !=
508-
std::nullopt); // verify that level-zero is initialized
509-
PlatformVec platforms;
510-
511-
// Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
512-
if (*GlobalAdapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
513-
UR_LOG(ERR, "Level Zero Uninitialized\n");
514-
result = std::move(platforms);
515-
return;
516-
}
517-
if (*GlobalAdapter->ZeResult != ZE_RESULT_SUCCESS) {
518-
UR_LOG(ERR, "Level Zero initialization failure\n");
519-
result = ze2urResult(*GlobalAdapter->ZeResult);
520503

521-
return;
522-
}
504+
PlatformVec platforms;
523505

524506
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
525507
auto [useV2, reason] = shouldUseV2Adapter();
526508
if (!useV2) {
527509
UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason);
528-
result = std::move(platforms);
529510
return;
530511
}
531512
#else
532513
auto [useV1, reason] = shouldUseV1Adapter();
533514
if (!useV1) {
534515
UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason);
535-
result = std::move(platforms);
536516
return;
537517
}
538518
#endif
@@ -558,41 +538,37 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
558538
}
559539
if (ZesInitNeeded) {
560540
#ifdef UR_STATIC_LEVEL_ZERO
561-
GlobalAdapter->getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
562-
GlobalAdapter->getSysManDriversFunctionPtr = zesDriverGet;
563-
GlobalAdapter->sysManInitFunctionPtr = zesInit;
541+
getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
542+
getSysManDriversFunctionPtr = zesDriverGet;
543+
sysManInitFunctionPtr = zesInit;
564544
#else
565-
GlobalAdapter->getDeviceByUUIdFunctionPtr =
566-
(zes_pfnDriverGetDeviceByUuidExp_t)
567-
ur_loader::LibLoader::getFunctionPtr(
568-
GlobalAdapter->processHandle, "zesDriverGetDeviceByUuidExp");
569-
GlobalAdapter->getSysManDriversFunctionPtr =
545+
getDeviceByUUIdFunctionPtr = (zes_pfnDriverGetDeviceByUuidExp_t)
546+
ur_loader::LibLoader::getFunctionPtr(processHandle,
547+
"zesDriverGetDeviceByUuidExp");
548+
getSysManDriversFunctionPtr =
570549
(zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr(
571-
GlobalAdapter->processHandle, "zesDriverGet");
572-
GlobalAdapter->sysManInitFunctionPtr =
573-
(zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(
574-
GlobalAdapter->processHandle, "zesInit");
550+
processHandle, "zesDriverGet");
551+
sysManInitFunctionPtr =
552+
(zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(processHandle,
553+
"zesInit");
575554
#endif
576555
}
577-
if (GlobalAdapter->getDeviceByUUIdFunctionPtr &&
578-
GlobalAdapter->getSysManDriversFunctionPtr &&
579-
GlobalAdapter->sysManInitFunctionPtr) {
556+
if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr &&
557+
sysManInitFunctionPtr) {
580558
ze_init_flags_t L0ZesInitFlags = 0;
581559
UR_LOG(DEBUG, "\nzesInit with flags value of {}\n",
582560
static_cast<int>(L0ZesInitFlags));
583-
GlobalAdapter->ZesResult = ZE_CALL_NOCHECK(
584-
GlobalAdapter->sysManInitFunctionPtr, (L0ZesInitFlags));
561+
ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags));
585562
} else {
586-
GlobalAdapter->ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
563+
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
587564
}
588565

589-
ur_result_t err = initPlatforms(platforms, GlobalAdapter->ZesResult);
566+
ur_result_t err = initPlatforms(this, platforms, ZesResult);
590567
if (err == UR_RESULT_SUCCESS) {
591-
result = std::move(platforms);
568+
Platforms = std::move(platforms);
592569
} else {
593-
result = err;
570+
throw err;
594571
}
595-
};
596572
}
597573

598574
void globalAdapterOnDemandCleanup() {
@@ -623,15 +599,25 @@ ur_result_t urAdapterGet(
623599
/// ::urAdapterGet shall only retrieve that number of platforms.
624600
ur_adapter_handle_t *Adapters,
625601
/// [out][optional] returns the total number of adapters available.
626-
uint32_t *NumAdapters) {
602+
uint32_t *NumAdapters) try {
627603
static std::mutex AdapterConstructionMutex{};
628604

629-
if (NumEntries > 0 && Adapters) {
630-
std::lock_guard<std::mutex> Lock{AdapterConstructionMutex};
605+
// We need to initialize the adapter even if user only queries
606+
// the number of adapters to decided whether to use V1 or V2.
607+
std::lock_guard<std::mutex> Lock{AdapterConstructionMutex};
631608

632-
if (!GlobalAdapter) {
633-
GlobalAdapter = new ur_adapter_handle_t_();
609+
if (!GlobalAdapter) {
610+
GlobalAdapter = new ur_adapter_handle_t_();
611+
}
612+
613+
if (GlobalAdapter->Platforms.size() == 0) {
614+
if (NumAdapters) {
615+
*NumAdapters = 0;
634616
}
617+
return UR_RESULT_ERROR_UNSUPPORTED_VERSION;
618+
}
619+
620+
if (NumEntries && Adapters) {
635621
*Adapters = GlobalAdapter;
636622

637623
if (GlobalAdapter->RefCount.retain() == 0) {
@@ -644,6 +630,10 @@ ur_result_t urAdapterGet(
644630
}
645631

646632
return UR_RESULT_SUCCESS;
633+
} catch (ur_result_t result) {
634+
return result;
635+
} catch (...) {
636+
return UR_RESULT_ERROR_UNKNOWN;
647637
}
648638

649639
ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {

unified-runtime/source/adapters/level_zero/adapter.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
3737
uint32_t ZeInitDriversCount = 0;
3838
bool InitDriversSupported = false;
3939

40-
ze_result_t ZeInitDriversResult;
41-
ze_result_t ZeInitResult;
42-
ze_result_t ZesResult;
43-
std::optional<ze_result_t> ZeResult;
44-
ZeCache<Result<PlatformVec>> PlatformCache;
40+
PlatformVec Platforms;
4541
logger::Logger &logger;
4642
HMODULE processHandle = nullptr;
4743

unified-runtime/source/adapters/level_zero/device.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,12 +1583,8 @@ ur_result_t urDeviceCreateWithNativeHandle(
15831583
// a valid Level Zero device.
15841584

15851585
ur_device_handle_t Dev = nullptr;
1586-
if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) {
1587-
for (const auto &p : *platforms) {
1588-
Dev = p->getDeviceFromNativeHandle(ZeDevice);
1589-
}
1590-
} else {
1591-
return GlobalAdapter->PlatformCache->get_error();
1586+
for (const auto &p : GlobalAdapter->Platforms) {
1587+
Dev = p->getDeviceFromNativeHandle(ZeDevice);
15921588
}
15931589

15941590
if (Dev == nullptr)

0 commit comments

Comments
 (0)