From d7275f7e258ecbda759a2f8a0292928d104e5a74 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Wed, 23 Apr 2025 13:03:29 -0700 Subject: [PATCH 1/2] [SYCL][NFCI] Ensure `device_impl` is only created via `platform_impl::getOrMakeDeviceImpl` --- sycl/source/detail/device_impl.cpp | 3 ++- sycl/source/detail/device_impl.hpp | 9 ++++++- sycl/source/detail/platform_impl.cpp | 3 ++- sycl/unittests/program_manager/SubDevices.cpp | 6 ++--- sycl/unittests/queue/DeviceCheck.cpp | 27 +++++++++++++++---- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/sycl/source/detail/device_impl.cpp b/sycl/source/detail/device_impl.cpp index f2f2673562f82..7d279786157ca 100644 --- a/sycl/source/detail/device_impl.cpp +++ b/sycl/source/detail/device_impl.cpp @@ -21,7 +21,8 @@ namespace detail { /// Constructs a SYCL device instance using the provided /// UR device instance. -device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform) +device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform, + device_impl::private_tag) : MDevice(Device), MPlatform(Platform.shared_from_this()), MDeviceHostBaseTime(std::make_pair(0, 0)) { const AdapterPtr &Adapter = Platform.getAdapter(); diff --git a/sycl/source/detail/device_impl.hpp b/sycl/source/detail/device_impl.hpp index 2b678fe475f31..b11384c555d73 100644 --- a/sycl/source/detail/device_impl.hpp +++ b/sycl/source/detail/device_impl.hpp @@ -33,10 +33,17 @@ class platform_impl; // TODO: Make code thread-safe class device_impl { + struct private_tag {}; + friend class platform_impl; + public: /// Constructs a SYCL device instance using the provided /// UR device instance. - explicit device_impl(ur_device_handle_t Device, platform_impl &Platform); + // + // Must be called through `platform_impl::getOrMakeDeviceImpl` only. + // `private_tag` ensures that is true. + explicit device_impl(ur_device_handle_t Device, platform_impl &Platform, + private_tag); ~device_impl(); diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 8fefa49480134..3bf34a90492dc 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -304,7 +304,8 @@ platform_impl::getOrMakeDeviceImpl(ur_device_handle_t UrDevice) { return Result; // Otherwise make the impl - Result = std::make_shared(UrDevice, *this); + Result = std::make_shared(UrDevice, *this, + device_impl::private_tag{}); MDeviceCache.emplace_back(Result); return Result; diff --git a/sycl/unittests/program_manager/SubDevices.cpp b/sycl/unittests/program_manager/SubDevices.cpp index b65eb6e12cb8f..39163a15c8f91 100644 --- a/sycl/unittests/program_manager/SubDevices.cpp +++ b/sycl/unittests/program_manager/SubDevices.cpp @@ -106,10 +106,8 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) { rootDevice = sycl::detail::getSyclObjImpl(device)->getHandleRef(); // Initialize sub-devices sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt); - auto subDev1 = - std::make_shared(urSubDev1, PltImpl); - auto subDev2 = - std::make_shared(urSubDev2, PltImpl); + auto subDev1 = PltImpl.getOrMakeDeviceImpl(urSubDev1); + auto subDev2 = PltImpl.getOrMakeDeviceImpl(urSubDev2); sycl::context Ctx{ {device, sycl::detail::createSyclObjFromImpl(subDev1), sycl::detail::createSyclObjFromImpl(subDev2)}}; diff --git a/sycl/unittests/queue/DeviceCheck.cpp b/sycl/unittests/queue/DeviceCheck.cpp index 09e8be76e064c..49ff4fd64f79e 100644 --- a/sycl/unittests/queue/DeviceCheck.cpp +++ b/sycl/unittests/queue/DeviceCheck.cpp @@ -62,6 +62,18 @@ ur_result_t redefinedDevicePartitionAfter(void *pParams) { **params.ppNumDevicesRet = *params.pNumDevices; return UR_RESULT_SUCCESS; } +ur_result_t redefinedPlatformGet(void *pParams) { + auto params = reinterpret_cast(pParams); + if (*params->ppNumPlatforms) + **params->ppNumPlatforms = 2; + + if (*params->pphPlatforms && *params->pNumEntries > 0) { + (*params->pphPlatforms)[0] = reinterpret_cast(1); + (*params->pphPlatforms)[1] = reinterpret_cast(2); + } + + return UR_RESULT_SUCCESS; +} // Check that the device is verified to be either a member of the context or a // descendant of its member. @@ -71,6 +83,8 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) { detail::SYCLConfig::reset); sycl::unittest::UrMock<> Mock; + mock::getCallbacks().set_replace_callback("urPlatformGet", + &redefinedPlatformGet); sycl::platform Plt = sycl::platform(); UrPlatform = detail::getSyclObjImpl(Plt)->getHandleRef(); @@ -116,12 +130,15 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) { // Device is neither of the two. { ParentDevice = nullptr; - device Device = detail::createSyclObjFromImpl( - std::make_shared( - reinterpret_cast(0x01), - *detail::getSyclObjImpl(Plt))); + + auto Plts = sycl::platform::get_platforms(); + EXPECT_TRUE(Plts.size() == 2); + sycl::platform OtherPlt = Plts[1]; + + device Device = OtherPlt.get_devices()[0]; queue Q{Device}; - EXPECT_NE(Q.get_context(), DefaultCtx); + auto Ctx = Q.get_context(); + EXPECT_NE(Ctx, DefaultCtx); try { queue Q2{DefaultCtx, Device}; EXPECT_TRUE(false); From e3ed7c3f9fb6bd09b2e95431591db7913fd041f7 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Tue, 29 Apr 2025 14:27:58 -0700 Subject: [PATCH 2/2] Make `private_tag`'s ctor explicit --- sycl/source/detail/device_impl.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sycl/source/detail/device_impl.hpp b/sycl/source/detail/device_impl.hpp index b11384c555d73..48957e935107e 100644 --- a/sycl/source/detail/device_impl.hpp +++ b/sycl/source/detail/device_impl.hpp @@ -33,7 +33,9 @@ class platform_impl; // TODO: Make code thread-safe class device_impl { - struct private_tag {}; + struct private_tag { + explicit private_tag() = default; + }; friend class platform_impl; public: