diff --git a/scripts/templates/valddi.cpp.mako b/scripts/templates/valddi.cpp.mako index 8cc4a9dc0f..7a18860ba9 100644 --- a/scripts/templates/valddi.cpp.mako +++ b/scripts/templates/valddi.cpp.mako @@ -94,6 +94,19 @@ namespace ur_validation_layer %endif %endfor + %for tp in tracked_params: + <% + tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None) + is_handle_to_adapter = ("_adapter_handle_t" in tp['type']) + %> + %if func_name in tp_handle_funcs['release']: + if( getContext()->enableLeakChecking ) + { + getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); + } + %endif + %endfor + ${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} ); %for tp in tracked_params: @@ -114,15 +127,10 @@ namespace ur_validation_layer } } %elif func_name in tp_handle_funcs['retain']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) + if( getContext()->enableLeakChecking ) { getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); } - %elif func_name in tp_handle_funcs['release']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) - { - getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); - } %endif %endfor diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index b3969de10f..26173a6d14 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( } } - ur_result_t result = pfnAdapterRelease(hAdapter); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hAdapter, true); } + ur_result_t result = pfnAdapterRelease(hAdapter); + return result; } @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( ur_result_t result = pfnAdapterRetain(hAdapter); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hAdapter, true); } @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_result_t result = pfnRetain(hDevice); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hDevice, false); } @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( } } - ur_result_t result = pfnRelease(hDevice); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hDevice, false); } + ur_result_t result = pfnRelease(hDevice); + return result; } @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_result_t result = pfnRetain(hContext); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hContext, false); } @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( } } - ur_result_t result = pfnRelease(hContext); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hContext, false); } + ur_result_t result = pfnRelease(hContext); + return result; } @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( ur_result_t result = pfnRetain(hMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hMem, false); } @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( } } - ur_result_t result = pfnRelease(hMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hMem, false); } + ur_result_t result = pfnRelease(hMem); + return result; } @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_result_t result = pfnRetain(hSampler); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hSampler, false); } @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( } } - ur_result_t result = pfnRelease(hSampler); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hSampler, false); } + ur_result_t result = pfnRelease(hSampler); + return result; } @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( ur_result_t result = pfnPoolRetain(pPool); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(pPool, false); } @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( } } - ur_result_t result = pfnPoolRelease(pPool); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(pPool, false); } + ur_result_t result = pfnPoolRelease(pPool); + return result; } @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_result_t result = pfnRetain(hPhysicalMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hPhysicalMem, false); } @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( } } - ur_result_t result = pfnRelease(hPhysicalMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hPhysicalMem, false); } + ur_result_t result = pfnRelease(hPhysicalMem); + return result; } @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( ur_result_t result = pfnRetain(hProgram); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hProgram, false); } @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( } } - ur_result_t result = pfnRelease(hProgram); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hProgram, false); } + ur_result_t result = pfnRelease(hProgram); + return result; } @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( ur_result_t result = pfnRetain(hKernel); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hKernel, false); } @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( } } - ur_result_t result = pfnRelease(hKernel); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hKernel, false); } + ur_result_t result = pfnRelease(hKernel); + return result; } @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( ur_result_t result = pfnRetain(hQueue); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hQueue, false); } @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( } } - ur_result_t result = pfnRelease(hQueue); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hQueue, false); } + ur_result_t result = pfnRelease(hQueue); + return result; } @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( ur_result_t result = pfnRetain(hEvent); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hEvent, false); } @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( } } - ur_result_t result = pfnRelease(hEvent); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hEvent, false); } + ur_result_t result = pfnRelease(hEvent); + return result; } diff --git a/test/conformance/adapter/urAdapterRelease.cpp b/test/conformance/adapter/urAdapterRelease.cpp index 8b29fa0f2c..0b28287aa7 100644 --- a/test/conformance/adapter/urAdapterRelease.cpp +++ b/test/conformance/adapter/urAdapterRelease.cpp @@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest { TEST_F(urAdapterReleaseTest, Success) { uint32_t referenceCountBefore = 0; + ASSERT_SUCCESS(urAdapterRetain(adapter)); ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT, sizeof(referenceCountBefore), diff --git a/test/conformance/device/urDeviceRelease.cpp b/test/conformance/device/urDeviceRelease.cpp index a8f6a3bc9d..dd5510394f 100644 --- a/test/conformance/device/urDeviceRelease.cpp +++ b/test/conformance/device/urDeviceRelease.cpp @@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {}; TEST_F(urDeviceReleaseTest, Success) { for (auto device : devices) { + ASSERT_SUCCESS(urDeviceRetain(device)); + uint32_t prevRefCount = 0; ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount)); diff --git a/test/conformance/testing/include/uur/fixtures.h b/test/conformance/testing/include/uur/fixtures.h index 1900568292..b1c90883d8 100644 --- a/test/conformance/testing/include/uur/fixtures.h +++ b/test/conformance/testing/include/uur/fixtures.h @@ -95,6 +95,7 @@ struct urDeviceTest : urPlatformTest, void SetUp() override { UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::SetUp()); device = GetParam(); + EXPECT_SUCCESS(urDeviceRetain(device)); } void TearDown() override {