From 6ca949055f70a44d90e0666082b97a0a20659af7 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Tue, 24 Jun 2025 15:22:19 -0700 Subject: [PATCH] [NFC][SYCL] Raw `context_impl` in `getInteropContext` and `queue_impl` ctor Splitting into two PRs would result in unnecessary temporarily adjustments and merge conflicts later on between these changes, so perform in a single PR. They are both small enough anyway. Continuation of the refactoring in https://github.com/intel/llvm/pull/18795 https://github.com/intel/llvm/pull/18877 https://github.com/intel/llvm/pull/18966 https://github.com/intel/llvm/pull/18979 https://github.com/intel/llvm/pull/18980 https://github.com/intel/llvm/pull/18981 https://github.com/intel/llvm/pull/19007 https://github.com/intel/llvm/pull/19030 https://github.com/intel/llvm/pull/19123 --- sycl/source/backend.cpp | 4 +- sycl/source/detail/graph_impl.cpp | 2 +- sycl/source/detail/queue_impl.hpp | 39 +++++++++---------- .../source/detail/scheduler/graph_builder.cpp | 6 +-- sycl/source/detail/sycl_mem_obj_i.hpp | 3 +- sycl/source/detail/sycl_mem_obj_t.hpp | 7 ++-- sycl/source/queue.cpp | 6 +-- .../scheduler/HostTaskAndBarrier.cpp | 8 ++-- .../scheduler/LinkedAllocaDependencies.cpp | 3 +- 9 files changed, 37 insertions(+), 41 deletions(-) diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index d262256634d61..98c5ceb06f7fc 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -126,7 +126,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle, ur_device_handle_t UrDevice = Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr; const auto &Adapter = getAdapter(Backend); - const auto &ContextImpl = getSyclObjImpl(Context); + context_impl &ContextImpl = *getSyclObjImpl(Context); if (PropList.has_property()) { throw sycl::exception( @@ -156,7 +156,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle, ur_queue_handle_t UrQueue = nullptr; Adapter->call( - NativeHandle, ContextImpl->getHandleRef(), UrDevice, &NativeProperties, + NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties, &UrQueue); // Construct the SYCL queue from UR queue. return detail::createSyclObjFromImpl( diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index bf4beb5ae83d8..65aaafc80af49 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -996,7 +996,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, : MSchedule(), MGraphImpl(GraphImpl), MSyncPoints(), MQueueImpl(sycl::detail::queue_impl::create( *sycl::detail::getSyclObjImpl(GraphImpl->getDevice()), - sycl::detail::getSyclObjImpl(Context), sycl::async_handler{}, + *sycl::detail::getSyclObjImpl(Context), sycl::async_handler{}, sycl::property_list{})), MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(), MSchedulerDependencies(), diff --git a/sycl/source/detail/queue_impl.hpp b/sycl/source/detail/queue_impl.hpp index c0ac158b61d1a..65ee8b673fada 100644 --- a/sycl/source/detail/queue_impl.hpp +++ b/sycl/source/detail/queue_impl.hpp @@ -117,11 +117,11 @@ class queue_impl : public std::enable_shared_from_this { /// constructed. /// \param AsyncHandler is a SYCL asynchronous exception handler. /// \param PropList is a list of properties to use for queue construction. - queue_impl(device_impl &Device, const ContextImplPtr &Context, + queue_impl(device_impl &Device, std::shared_ptr &&Context, const async_handler &AsyncHandler, const property_list &PropList, private_tag) - : MDevice(Device), MContext(Context), MAsyncHandler(AsyncHandler), - MPropList(PropList), + : MDevice(Device), MContext(std::move(Context)), + MAsyncHandler(AsyncHandler), MPropList(PropList), MIsInorder(has_property()), MIsProfilingEnabled(has_property()), MQueueID{ @@ -146,8 +146,8 @@ class queue_impl : public std::enable_shared_from_this { "Queue compute index must be a non-negative number less than " "device's number of available compute queue indices."); } - if (!Context->isDeviceValid(Device)) { - if (Context->getBackend() == backend::opencl) + if (!MContext->isDeviceValid(Device)) { + if (MContext->getBackend() == backend::opencl) throw sycl::exception( make_error_code(errc::invalid), "Queue cannot be constructed with the given context and device " @@ -177,17 +177,13 @@ class queue_impl : public std::enable_shared_from_this { trySwitchingToNoEventsMode(); } - sycl::detail::optional getLastEvent(); + queue_impl(device_impl &Device, context_impl &Context, + const async_handler &AsyncHandler, const property_list &PropList, + private_tag Tag) + : queue_impl(Device, Context.shared_from_this(), AsyncHandler, PropList, + Tag) {} - /// Constructs a SYCL queue from adapter interoperability handle. - /// - /// \param UrQueue is a raw UR queue handle. - /// \param Context is a SYCL context to associate with the queue being - /// constructed. - /// \param AsyncHandler is a SYCL asynchronous exception handler. - queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context, - const async_handler &AsyncHandler, private_tag tag) - : queue_impl(UrQueue, Context, AsyncHandler, {}, tag) {} + sycl::detail::optional getLastEvent(); /// Constructs a SYCL queue from adapter interoperability handle. /// @@ -196,18 +192,18 @@ class queue_impl : public std::enable_shared_from_this { /// constructed. /// \param AsyncHandler is a SYCL asynchronous exception handler. /// \param PropList is the queue properties. - queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context, + queue_impl(ur_queue_handle_t UrQueue, context_impl &Context, const async_handler &AsyncHandler, const property_list &PropList, private_tag) : MDevice([&]() -> device_impl & { ur_device_handle_t DeviceUr{}; - const AdapterPtr &Adapter = Context->getAdapter(); + const AdapterPtr &Adapter = Context.getAdapter(); // TODO catch an exception and put it to list of asynchronous // exceptions Adapter->call( UrQueue, UR_QUEUE_INFO_DEVICE, sizeof(DeviceUr), &DeviceUr, nullptr); - device_impl *Device = Context->findMatchingDeviceImpl(DeviceUr); + device_impl *Device = Context.findMatchingDeviceImpl(DeviceUr); if (Device == nullptr) { throw sycl::exception( make_error_code(errc::invalid), @@ -215,8 +211,9 @@ class queue_impl : public std::enable_shared_from_this { } return *Device; }()), - MContext(Context), MAsyncHandler(AsyncHandler), MPropList(PropList), - MQueue(UrQueue), MIsInorder(has_property()), + MContext(Context.shared_from_this()), MAsyncHandler(AsyncHandler), + MPropList(PropList), MQueue(UrQueue), + MIsInorder(has_property()), MIsProfilingEnabled(has_property()), MQueueID{ MNextAvailableQueueID.fetch_add(1, std::memory_order_relaxed)} { @@ -988,7 +985,7 @@ class queue_impl : public std::enable_shared_from_this { mutable std::mutex MMutex; device_impl &MDevice; - const ContextImplPtr MContext; + const std::shared_ptr MContext; /// These events are tracked, but not owned, by the queue. std::vector> MEventsWeak; diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 8f7bfb29fb109..b56c1654ca21c 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -210,7 +210,7 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue, cleanupCommand(Cmd); }; - const ContextImplPtr &InteropCtxPtr = Req->MSYCLMemObj->getInteropContext(); + context_impl *InteropCtxPtr = Req->MSYCLMemObj->getInteropContext(); if (InteropCtxPtr) { // The memory object has been constructed using interoperability constructor // which means that there is already an allocation(cl_mem) in some context. @@ -225,10 +225,10 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue, // here, we need to create a dummy queue bound to the context and one of the // devices from the context. std::shared_ptr InteropQueuePtr = queue_impl::create( - Dev, InteropCtxPtr, async_handler{}, property_list{}); + Dev, *InteropCtxPtr, async_handler{}, property_list{}); MemObject->MRecord.reset( - new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency}); + new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency}); std::vector ToEnqueue; getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr, ToEnqueue); diff --git a/sycl/source/detail/sycl_mem_obj_i.hpp b/sycl/source/detail/sycl_mem_obj_i.hpp index 68c8de30cfd21..4ea7840c6d179 100644 --- a/sycl/source/detail/sycl_mem_obj_i.hpp +++ b/sycl/source/detail/sycl_mem_obj_i.hpp @@ -22,7 +22,6 @@ class context_impl; struct MemObjRecord; using EventImplPtr = std::shared_ptr; -using ContextImplPtr = std::shared_ptr; // The class serves as an interface in the scheduler for all SYCL memory // objects. @@ -72,7 +71,7 @@ class SYCLMemObjI { // Returns the context which is passed if a memory object is created using // interoperability constructor, nullptr otherwise. - virtual ContextImplPtr getInteropContext() const = 0; + virtual detail::context_impl *getInteropContext() const = 0; protected: // Pointer to the record that contains the memory commands. This is managed diff --git a/sycl/source/detail/sycl_mem_obj_t.hpp b/sycl/source/detail/sycl_mem_obj_t.hpp index a5fbdcdd7d9c1..fc302e3750743 100644 --- a/sycl/source/detail/sycl_mem_obj_t.hpp +++ b/sycl/source/detail/sycl_mem_obj_t.hpp @@ -36,7 +36,6 @@ class event_impl; class Adapter; using AdapterPtr = std::shared_ptr; -using ContextImplPtr = std::shared_ptr; using EventImplPtr = std::shared_ptr; // The class serves as a base for all SYCL memory objects. @@ -281,7 +280,9 @@ class SYCLMemObjT : public SYCLMemObjI { MemObjType getType() const override { return MemObjType::Undefined; } - ContextImplPtr getInteropContext() const override { return MInteropContext; } + context_impl *getInteropContext() const override { + return MInteropContext.get(); + } bool isInterop() const override; @@ -339,7 +340,7 @@ class SYCLMemObjT : public SYCLMemObjI { // Should wait on this event before start working with such memory object. EventImplPtr MInteropEvent; // Context passed by user to interoperability constructor. - ContextImplPtr MInteropContext; + std::shared_ptr MInteropContext; // Native backend memory object handle passed by user to interoperability // constructor. ur_mem_handle_t MInteropMemObject; diff --git a/sycl/source/queue.cpp b/sycl/source/queue.cpp index d0e8078aacfdd..f92febe6bcece 100644 --- a/sycl/source/queue.cpp +++ b/sycl/source/queue.cpp @@ -65,14 +65,14 @@ queue::queue(const context &SyclContext, const device_selector &DeviceSelector, const device &SyclDevice = *std::max_element(Devs.begin(), Devs.end(), Comp); impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice), - detail::getSyclObjImpl(SyclContext), + *detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList); } queue::queue(const context &SyclContext, const device &SyclDevice, const async_handler &AsyncHandler, const property_list &PropList) { impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice), - detail::getSyclObjImpl(SyclContext), + *detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList); } @@ -100,7 +100,7 @@ queue::queue(cl_command_queue clQueue, const context &SyclContext, impl = detail::queue_impl::create( // TODO(pi2ur): Don't cast straight from cl_command_queue reinterpret_cast(clQueue), - detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList); + *detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList); } cl_command_queue queue::get() const { return impl->get(); } diff --git a/sycl/unittests/scheduler/HostTaskAndBarrier.cpp b/sycl/unittests/scheduler/HostTaskAndBarrier.cpp index 8c68af2c53b79..e5ab7b00a2a28 100644 --- a/sycl/unittests/scheduler/HostTaskAndBarrier.cpp +++ b/sycl/unittests/scheduler/HostTaskAndBarrier.cpp @@ -20,15 +20,15 @@ namespace { using namespace sycl; using EventImplPtr = std::shared_ptr; -using ContextImplPtr = std::shared_ptr; constexpr auto DisableCleanupName = "SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP"; class TestQueueImpl : public sycl::detail::queue_impl { public: - TestQueueImpl(ContextImplPtr SyclContext, sycl::detail::device_impl &Dev) + TestQueueImpl(sycl::detail::context_impl &SyclContext, + sycl::detail::device_impl &Dev) : sycl::detail::queue_impl(Dev, SyclContext, - SyclContext->get_async_handler(), {}, + SyclContext.get_async_handler(), {}, sycl::detail::queue_impl::private_tag{}) {} using sycl::detail::queue_impl::MDefaultGraphDeps; using sycl::detail::queue_impl::MExtGraphDeps; @@ -46,7 +46,7 @@ class BarrierHandlingWithHostTask : public ::testing::Test { sycl::device SyclDev = sycl::detail::select_device(sycl::default_selector_v, SyclContext); QueueDevImpl.reset( - new TestQueueImpl(sycl::detail::getSyclObjImpl(SyclContext), + new TestQueueImpl(*sycl::detail::getSyclObjImpl(SyclContext), *sycl::detail::getSyclObjImpl(SyclDev))); MainLock.lock(); diff --git a/sycl/unittests/scheduler/LinkedAllocaDependencies.cpp b/sycl/unittests/scheduler/LinkedAllocaDependencies.cpp index 96d688bfedd00..81dafa6b5ef43 100644 --- a/sycl/unittests/scheduler/LinkedAllocaDependencies.cpp +++ b/sycl/unittests/scheduler/LinkedAllocaDependencies.cpp @@ -38,8 +38,7 @@ class MemObjMock : public sycl::detail::SYCLMemObjI { bool isHostPointerReadOnly() const override { return false; } bool usesPinnedHostMemory() const override { return false; } - std::shared_ptr - getInteropContext() const override { + sycl::detail::context_impl *getInteropContext() const override { return nullptr; } };