Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 84 additions & 54 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ using namespace loader_driver_ddi;

namespace loader
{
__${x}dlllocal ze_result_t ${X}_APICALL
${n}loaderInitDriverDDITables(loader::driver_t *driver) {
ze_result_t result = ZE_RESULT_SUCCESS;
%for tbl in th.get_pfntables(specs, meta, n, tags):
result = ${tbl['export']['name']}FromDriver(driver);
if (result != ZE_RESULT_SUCCESS) {
return result;
}
%endfor
return result;
}
%for obj in th.extract_objs(specs, r"function"):
<%
ret_type = obj['return_type']
Expand Down Expand Up @@ -65,6 +76,17 @@ namespace loader
if(drv.initStatus != ZE_RESULT_SUCCESS)
continue;
%endif
if (!drv.handle || !drv.ddiInitialized) {
%if namespace != "zes":
bool sysmanInit = false;
%else:
bool sysmanInit = true;
%endif
auto res = loader::context->init_driver( drv, flags, nullptr, nullptr, nullptr, sysmanInit );
if (res != ZE_RESULT_SUCCESS) {
continue;
}
}
%if re.match(r"Init", obj['name']) and namespace == "zes":
if (!drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}) {
drv.initSysManStatus = ZE_RESULT_ERROR_UNINITIALIZED;
Expand All @@ -90,6 +112,13 @@ namespace loader

%elif re.match(r"\w+DriverGet$", th.make_func_name(n, tags, obj)) or re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)):
uint32_t total_driver_handle_count = 0;
%if re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)):
for( auto& drv : loader::context->zeDrivers ) {
if (!drv.handle || !drv.ddiInitialized) {
loader::context->init_driver( drv, 0, desc, nullptr, nullptr, false );
}
}
%endif

{
std::lock_guard<std::mutex> lock(loader::context->sortMutex);
Expand Down Expand Up @@ -124,15 +153,16 @@ namespace loader
%endif
{
%if not (re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj))) and namespace != "zes":
if(drv.initStatus != ZE_RESULT_SUCCESS)
if(drv.initStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized)
continue;
%elif namespace == "zes":
if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS)
if(drv.initStatus != ZE_RESULT_SUCCESS || drv.initSysManStatus != ZE_RESULT_SUCCESS || !drv.ddiInitialized)
continue;
%else:
if (!drv.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}) {
%if re.match(r"\w+InitDrivers$", th.make_func_name(n, tags, obj)):
drv.initDriversStatus = ${X}_RESULT_ERROR_UNINITIALIZED;
result = ${X}_RESULT_ERROR_UNINITIALIZED;
%else:
drv.initStatus = ${X}_RESULT_ERROR_UNINITIALIZED;
%endif
Expand Down Expand Up @@ -495,6 +525,52 @@ ${tbl['export']['name']}Legacy()

%endfor

%for tbl in th.get_pfntables(specs, meta, n, tags):
///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's ${tbl['name']} table
/// with current process' addresses
///
/// @returns
/// - ::${X}_RESULT_SUCCESS
/// - ::${X}_RESULT_ERROR_UNINITIALIZED
/// - ::${X}_RESULT_ERROR_INVALID_NULL_POINTER
/// - ::${X}_RESULT_ERROR_UNSUPPORTED_VERSION
__${x}dlllocal ${x}_result_t ${X}_APICALL
${tbl['export']['name']}FromDriver(loader::driver_t *driver)
{
${x}_result_t result = ${X}_RESULT_SUCCESS;
if(driver->initStatus != ZE_RESULT_SUCCESS)
return driver->initStatus;
auto getTable = reinterpret_cast<${tbl['pfn']}>(
GET_FUNCTION_PTR( driver->handle, "${tbl['export']['name']}") );
if(!getTable)
%if th.isNewProcTable(tbl['export']['name']) is True:
{
//It is valid to not have this proc addr table
return ${X}_RESULT_SUCCESS;
}
%else:
return driver->initStatus;
%endif
%if tbl['experimental'] is False: #//Experimental Tables may not be implemented in driver
auto getTableResult = getTable( loader::context->ddi_init_version, &driver->dditable.${n}.${tbl['name']});
if(getTableResult == ZE_RESULT_SUCCESS) {
loader::context->configured_version = loader::context->ddi_init_version;
} else
driver->initStatus = getTableResult;
%if namespace != "zes":
%if tbl['name'] == "Global":
if (driver->dditable.ze.Global.pfnInitDrivers) {
loader::context->initDriversSupport = true;
}
%endif
%endif
%else:
result = getTable( loader::context->ddi_init_version, &driver->dditable.${n}.${tbl['name']});
%endif
return result;
}
%endfor
%for tbl in th.get_pfntables(specs, meta, n, tags):
///////////////////////////////////////////////////////////////////////////////
/// @brief Exported function for filling application's ${tbl['name']} table
Expand Down Expand Up @@ -526,63 +602,17 @@ ${tbl['export']['name']}(
if( loader::context->version < version )
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;

${x}_result_t result = ${X}_RESULT_SUCCESS;

%if tbl['experimental'] is False: #//Experimental Tables may not be implemented in driver
bool atLeastOneDriverValid = false;
%endif
// Load the device-driver DDI tables
%if namespace != "zes":
for( auto& drv : loader::context->zeDrivers )
%else:
for( auto& drv : *loader::context->sysmanInstanceDrivers )
%endif
{
if(drv.initStatus != ZE_RESULT_SUCCESS)
continue;
auto getTable = reinterpret_cast<${tbl['pfn']}>(
GET_FUNCTION_PTR( drv.handle, "${tbl['export']['name']}") );
if(!getTable)
%if th.isNewProcTable(tbl['export']['name']) is True:
{
atLeastOneDriverValid = true;
//It is valid to not have this proc addr table
continue;
}
%else:
continue;
%endif
%if tbl['experimental'] is False: #//Experimental Tables may not be implemented in driver
auto getTableResult = getTable( version, &drv.dditable.${n}.${tbl['name']});
if(getTableResult == ZE_RESULT_SUCCESS) {
atLeastOneDriverValid = true;
loader::context->configured_version = version;
} else
drv.initStatus = getTableResult;
%if namespace != "zes":
%if tbl['name'] == "Global":
if (drv.dditable.ze.Global.pfnInitDrivers) {
loader::context->initDriversSupport = true;
}
%endif
%endif
%else:
result = getTable( version, &drv.dditable.${n}.${tbl['name']});
%endif
}
loader::context->ddi_init_version = version;

%if tbl['experimental'] is False: #//Experimental Tables may not be implemented in driver
if(!atLeastOneDriverValid)
result = ${X}_RESULT_ERROR_UNINITIALIZED;
else
result = ${X}_RESULT_SUCCESS;
%endif
${x}_result_t result = ${X}_RESULT_SUCCESS;

if( ${X}_RESULT_SUCCESS == result )
{
%if namespace != "zes":
%if tbl['name'] == "Global":
if( true )
%elif namespace != "zes":
if( ( loader::context->zeDrivers.size() > 1 ) || loader::context->forceIntercept )
%else:
%elif namespace == "zes":
if( ( loader::context->sysmanInstanceDrivers->size() > 1 ) || loader::context->forceIntercept )
%endif
{
Expand Down
8 changes: 8 additions & 0 deletions scripts/templates/ldrddi.h.mako
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ from templates import helper as th

namespace loader
{
///////////////////////////////////////////////////////////////////////////////
// Forward declaration for driver_t so this header can reference loader::driver_t*
// without requiring inclusion of ze_loader_internal.h (which includes this file).
struct driver_t;
///////////////////////////////////////////////////////////////////////////////
%for obj in th.extract_objs(specs, r"handle"):
%if 'class' in obj:
Expand All @@ -32,6 +36,8 @@ namespace loader

%endif
%endfor
__${x}dlllocal ze_result_t ${X}_APICALL
${n}loaderInitDriverDDITables(loader::driver_t *driver);
}

namespace loader_driver_ddi
Expand All @@ -57,6 +63,8 @@ extern "C" {
%for tbl in th.get_pfntables(specs, meta, n, tags):
__${x}dlllocal void ${X}_APICALL
${tbl['export']['name']}Legacy();
__${x}dlllocal ze_result_t ${X}_APICALL
${tbl['export']['name']}FromDriver(loader::driver_t *driver);
%endfor

#if defined(__cplusplus)
Expand Down
4 changes: 4 additions & 0 deletions scripts/templates/ze_loader_internal.h.mako
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ namespace loader
bool legacyInitAttempted = false;
bool driverDDIHandleSupportQueried = false;
ze_driver_handle_t zerDriverHandle = nullptr;
ze_api_version_t versionRequested = ZE_API_VERSION_CURRENT;
bool ddiInitialized = false;
};

using driver_vector_t = std::vector< driver_t >;
Expand Down Expand Up @@ -97,6 +99,7 @@ namespace loader
std::unordered_map<ze_sampler_object_t *, ze_sampler_handle_t> sampler_handle_map;
ze_api_version_t version = ZE_API_VERSION_CURRENT;
ze_api_version_t configured_version = ZE_API_VERSION_CURRENT;
ze_api_version_t ddi_init_version = ZE_API_VERSION_CURRENT;

driver_vector_t allDrivers;
driver_vector_t zeDrivers;
Expand Down Expand Up @@ -129,6 +132,7 @@ namespace loader
std::atomic<bool> sortingInProgress = {false};
std::mutex sortMutex;
bool instrumentationEnabled = false;
bool pciOrderingRequested = false;
dditable_t tracing_dditable = {};
std::shared_ptr<Logger> zel_logger;
ze_driver_handle_t* defaultZerDriverHandle = nullptr;
Expand Down
Loading
Loading