From 8412c8bdfbf8262d7ee6d953a85cf956bb50c887 Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Thu, 21 Nov 2024 16:12:45 +0000 Subject: [PATCH] Fix "use after release" issues In some cases, we use handles after releasing them, or incorrectly release handles we shouldn't. This doesn't cause any issues currently, but will when we start using reference counting in the loader. --- scripts/templates/valddi.cpp.mako | 20 +++-- source/loader/layers/validation/ur_valddi.cpp | 88 +++++++++---------- test/conformance/adapter/urAdapterRelease.cpp | 1 + test/conformance/device/urDeviceRelease.cpp | 2 + .../testing/include/uur/fixtures.h | 1 + 5 files changed, 62 insertions(+), 50 deletions(-) diff --git a/scripts/templates/valddi.cpp.mako b/scripts/templates/valddi.cpp.mako index 8cc4a9dc0f..7a18860ba9 100644 --- a/scripts/templates/valddi.cpp.mako +++ b/scripts/templates/valddi.cpp.mako @@ -94,6 +94,19 @@ namespace ur_validation_layer %endif %endfor + %for tp in tracked_params: + <% + tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None) + is_handle_to_adapter = ("_adapter_handle_t" in tp['type']) + %> + %if func_name in tp_handle_funcs['release']: + if( getContext()->enableLeakChecking ) + { + getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); + } + %endif + %endfor + ${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} ); %for tp in tracked_params: @@ -114,15 +127,10 @@ namespace ur_validation_layer } } %elif func_name in tp_handle_funcs['retain']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) + if( getContext()->enableLeakChecking ) { getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); } - %elif func_name in tp_handle_funcs['release']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) - { - getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); - } %endif %endfor diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index b3969de10f..26173a6d14 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( } } - ur_result_t result = pfnAdapterRelease(hAdapter); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hAdapter, true); } + ur_result_t result = pfnAdapterRelease(hAdapter); + return result; } @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( ur_result_t result = pfnAdapterRetain(hAdapter); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hAdapter, true); } @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_result_t result = pfnRetain(hDevice); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hDevice, false); } @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( } } - ur_result_t result = pfnRelease(hDevice); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hDevice, false); } + ur_result_t result = pfnRelease(hDevice); + return result; } @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_result_t result = pfnRetain(hContext); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hContext, false); } @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( } } - ur_result_t result = pfnRelease(hContext); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hContext, false); } + ur_result_t result = pfnRelease(hContext); + return result; } @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( ur_result_t result = pfnRetain(hMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hMem, false); } @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( } } - ur_result_t result = pfnRelease(hMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hMem, false); } + ur_result_t result = pfnRelease(hMem); + return result; } @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_result_t result = pfnRetain(hSampler); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hSampler, false); } @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( } } - ur_result_t result = pfnRelease(hSampler); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hSampler, false); } + ur_result_t result = pfnRelease(hSampler); + return result; } @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( ur_result_t result = pfnPoolRetain(pPool); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(pPool, false); } @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( } } - ur_result_t result = pfnPoolRelease(pPool); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(pPool, false); } + ur_result_t result = pfnPoolRelease(pPool); + return result; } @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_result_t result = pfnRetain(hPhysicalMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hPhysicalMem, false); } @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( } } - ur_result_t result = pfnRelease(hPhysicalMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hPhysicalMem, false); } + ur_result_t result = pfnRelease(hPhysicalMem); + return result; } @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( ur_result_t result = pfnRetain(hProgram); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hProgram, false); } @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( } } - ur_result_t result = pfnRelease(hProgram); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hProgram, false); } + ur_result_t result = pfnRelease(hProgram); + return result; } @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( ur_result_t result = pfnRetain(hKernel); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hKernel, false); } @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( } } - ur_result_t result = pfnRelease(hKernel); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hKernel, false); } + ur_result_t result = pfnRelease(hKernel); + return result; } @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( ur_result_t result = pfnRetain(hQueue); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hQueue, false); } @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( } } - ur_result_t result = pfnRelease(hQueue); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hQueue, false); } + ur_result_t result = pfnRelease(hQueue); + return result; } @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( ur_result_t result = pfnRetain(hEvent); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hEvent, false); } @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( } } - ur_result_t result = pfnRelease(hEvent); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hEvent, false); } + ur_result_t result = pfnRelease(hEvent); + return result; } diff --git a/test/conformance/adapter/urAdapterRelease.cpp b/test/conformance/adapter/urAdapterRelease.cpp index 8b29fa0f2c..0b28287aa7 100644 --- a/test/conformance/adapter/urAdapterRelease.cpp +++ b/test/conformance/adapter/urAdapterRelease.cpp @@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest { TEST_F(urAdapterReleaseTest, Success) { uint32_t referenceCountBefore = 0; + ASSERT_SUCCESS(urAdapterRetain(adapter)); ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT, sizeof(referenceCountBefore), diff --git a/test/conformance/device/urDeviceRelease.cpp b/test/conformance/device/urDeviceRelease.cpp index a8f6a3bc9d..dd5510394f 100644 --- a/test/conformance/device/urDeviceRelease.cpp +++ b/test/conformance/device/urDeviceRelease.cpp @@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {}; TEST_F(urDeviceReleaseTest, Success) { for (auto device : devices) { + ASSERT_SUCCESS(urDeviceRetain(device)); + uint32_t prevRefCount = 0; ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount)); diff --git a/test/conformance/testing/include/uur/fixtures.h b/test/conformance/testing/include/uur/fixtures.h index 1900568292..b1c90883d8 100644 --- a/test/conformance/testing/include/uur/fixtures.h +++ b/test/conformance/testing/include/uur/fixtures.h @@ -95,6 +95,7 @@ struct urDeviceTest : urPlatformTest, void SetUp() override { UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::SetUp()); device = GetParam(); + EXPECT_SUCCESS(urDeviceRetain(device)); } void TearDown() override {