Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ namespace ur_validation_layer
%endif
%endfor

%for tp in tracked_params:
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

%for tp in tracked_params:
Expand All @@ -114,15 +127,10 @@ namespace ur_validation_layer
}
}
%elif func_name in tp_handle_funcs['retain']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

Expand Down
88 changes: 44 additions & 44 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
}
}

ur_result_t result = pfnAdapterRelease(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hAdapter, true);
}

ur_result_t result = pfnAdapterRelease(hAdapter);

return result;
}

Expand All @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(

ur_result_t result = pfnAdapterRetain(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hAdapter, true);
}

Expand Down Expand Up @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(

ur_result_t result = pfnRetain(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hDevice, false);
}

Expand All @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
}
}

ur_result_t result = pfnRelease(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hDevice, false);
}

ur_result_t result = pfnRelease(hDevice);

return result;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(

ur_result_t result = pfnRetain(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hContext, false);
}

Expand All @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
}
}

ur_result_t result = pfnRelease(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hContext, false);
}

ur_result_t result = pfnRelease(hContext);

return result;
}

Expand Down Expand Up @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(

ur_result_t result = pfnRetain(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hMem, false);
}

Expand All @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
}
}

ur_result_t result = pfnRelease(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hMem, false);
}

ur_result_t result = pfnRelease(hMem);

return result;
}

Expand Down Expand Up @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(

ur_result_t result = pfnRetain(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hSampler, false);
}

Expand All @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
}
}

ur_result_t result = pfnRelease(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hSampler, false);
}

ur_result_t result = pfnRelease(hSampler);

return result;
}

Expand Down Expand Up @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(

ur_result_t result = pfnPoolRetain(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(pPool, false);
}

Expand All @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
}
}

ur_result_t result = pfnPoolRelease(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(pPool, false);
}

ur_result_t result = pfnPoolRelease(pPool);

return result;
}

Expand Down Expand Up @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(

ur_result_t result = pfnRetain(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
}

Expand All @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
}
}

ur_result_t result = pfnRelease(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
}

ur_result_t result = pfnRelease(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(

ur_result_t result = pfnRetain(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hProgram, false);
}

Expand All @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
}
}

ur_result_t result = pfnRelease(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hProgram, false);
}

ur_result_t result = pfnRelease(hProgram);

return result;
}

Expand Down Expand Up @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

ur_result_t result = pfnRetain(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hKernel, false);
}

Expand All @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
}
}

ur_result_t result = pfnRelease(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hKernel, false);
}

ur_result_t result = pfnRelease(hKernel);

return result;
}

Expand Down Expand Up @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(

ur_result_t result = pfnRetain(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hQueue, false);
}

Expand All @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
}
}

ur_result_t result = pfnRelease(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hQueue, false);
}

ur_result_t result = pfnRelease(hQueue);

return result;
}

Expand Down Expand Up @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(

ur_result_t result = pfnRetain(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hEvent, false);
}

Expand All @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
}
}

ur_result_t result = pfnRelease(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hEvent, false);
}

ur_result_t result = pfnRelease(hEvent);

return result;
}

Expand Down
1 change: 1 addition & 0 deletions test/conformance/adapter/urAdapterRelease.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest {

TEST_F(urAdapterReleaseTest, Success) {
uint32_t referenceCountBefore = 0;
ASSERT_SUCCESS(urAdapterRetain(adapter));

ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT,
sizeof(referenceCountBefore),
Expand Down
2 changes: 2 additions & 0 deletions test/conformance/device/urDeviceRelease.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {};

TEST_F(urDeviceReleaseTest, Success) {
for (auto device : devices) {
ASSERT_SUCCESS(urDeviceRetain(device));

uint32_t prevRefCount = 0;
ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount));

Expand Down
1 change: 1 addition & 0 deletions test/conformance/testing/include/uur/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ struct urDeviceTest : urPlatformTest,
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::SetUp());
device = GetParam();
EXPECT_SUCCESS(urDeviceRetain(device));
}

void TearDown() override {
Expand Down
Loading