From 2f64dc9d2e9d602a891ce9bf3c4aa5df25cc382b Mon Sep 17 00:00:00 2001 From: "Larsen, Steffen" Date: Wed, 23 Apr 2025 05:46:59 -0700 Subject: [PATCH] [SYCL] Prioritize set kernels over lookup The current implementation of SYCL kernel launches prioritizes looking up kernels through the kernel bundles rather than using the set kernel. These changes instead prioritize using the set kernel, which not only saves the look-up overhead but also fixes a kernel implementation lifetime issue caused by #17380. Signed-off-by: Larsen, Steffen --- sycl/source/detail/graph_impl.cpp | 16 +++---- sycl/source/detail/helpers.cpp | 14 +++--- sycl/source/detail/scheduler/commands.cpp | 54 +++++++++++------------ 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index b5f1f53e32d7f..8cb8647c49ee9 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1474,18 +1474,18 @@ void exec_graph_impl::populateURKernelUpdateStructs( ur_kernel_handle_t UrKernel = nullptr; auto Kernel = ExecCG.MSyclKernel; auto KernelBundleImplPtr = ExecCG.MKernelBundle; - std::shared_ptr SyclKernelImpl = nullptr; const sycl::detail::KernelArgMask *EliminatedArgMask = nullptr; - if (auto SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - ExecCG.MKernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { - UrKernel = SyclKernelImpl->getHandleRef(); - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - } else if (Kernel != nullptr) { + if (Kernel != nullptr) { UrKernel = Kernel->getHandleRef(); EliminatedArgMask = Kernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr + ? 
KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName, + KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { + UrKernel = SyclKernelImpl->getHandleRef(); + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); } else { ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = diff --git a/sycl/source/detail/helpers.cpp b/sycl/source/detail/helpers.cpp index 8a50f61070909..51fa452c3ccf2 100644 --- a/sycl/source/detail/helpers.cpp +++ b/sycl/source/detail/helpers.cpp @@ -72,16 +72,16 @@ retrieveKernelBinary(const QueueImplPtr &Queue, const char *KernelName, const RTDeviceBinaryImage *DeviceImage = nullptr; ur_program_handle_t Program = nullptr; auto KernelBundleImpl = KernelCG->getKernelBundle(); - if (auto SyclKernelImpl = - KernelBundleImpl - ? KernelBundleImpl->tryGetKernel(KernelName, KernelBundleImpl) - : std::shared_ptr{nullptr}) { + if (KernelCG->MSyclKernel != nullptr) { + DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref(); + Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref(); + } else if (auto SyclKernelImpl = + KernelBundleImpl ? KernelBundleImpl->tryGetKernel( + KernelName, KernelBundleImpl) + : std::shared_ptr{nullptr}) { // Retrieve the device image from the kernel bundle. 
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref(); Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); - } else if (KernelCG->MSyclKernel != nullptr) { - DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref(); - Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref(); } else { auto ContextImpl = Queue->getContextImplPtr(); auto DeviceImpl = Queue->getDeviceImplPtr(); diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index b8b07e0053881..a4b6e1d115f68 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1997,16 +1997,16 @@ void instrumentationAddExtraKernelMetadata( std::mutex *KernelMutex = nullptr; const KernelArgMask *EliminatedArgMask = nullptr; - if (auto SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); - } else if (nullptr != SyclKernel) { + if (nullptr != SyclKernel) { Program = SyclKernel->getProgramRef(); if (!SyclKernel->isCreatedFromSource()) EliminatedArgMask = SyclKernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel( + KernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); + Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); } else if (Queue) { // NOTE: Queue can be null when kernel is directly enqueued to a command // buffer @@ -2521,17 +2521,17 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, const KernelArgMask *EliminatedArgMask = nullptr; auto &KernelBundleImplPtr = CommandGroup.MKernelBundle; - if (auto SyclKernelImpl = - KernelBundleImplPtr - ? 
KernelBundleImplPtr->tryGetKernel(CommandGroup.MKernelName, - KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { + if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) { + UrKernel = Kernel->getHandleRef(); + EliminatedArgMask = Kernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel( + CommandGroup.MKernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { UrKernel = SyclKernelImpl->getHandleRef(); DeviceImageImpl = SyclKernelImpl->getDeviceImage(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - } else if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) { - UrKernel = Kernel->getHandleRef(); - EliminatedArgMask = Kernel->getKernelArgMask(); } else { ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = @@ -2678,18 +2678,7 @@ void enqueueImpKernel( std::shared_ptr SyclKernelImpl; std::shared_ptr DeviceImageImpl; - if ((SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr})) { - Kernel = SyclKernelImpl->getHandleRef(); - DeviceImageImpl = SyclKernelImpl->getDeviceImage(); - - Program = DeviceImageImpl->get_ur_program_ref(); - - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - KernelMutex = SyclKernelImpl->getCacheMutex(); - } else if (nullptr != MSyclKernel) { + if (nullptr != MSyclKernel) { assert(MSyclKernel->get_info() == Queue->get_context()); Kernel = MSyclKernel->getHandleRef(); @@ -2703,6 +2692,17 @@ void enqueueImpKernel( // their duplication in such cases. KernelMutex = &MSyclKernel->getNoncacheableEnqueueMutex(); EliminatedArgMask = MSyclKernel->getKernelArgMask(); + } else if ((SyclKernelImpl = KernelBundleImplPtr + ? 
KernelBundleImplPtr->tryGetKernel( + KernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr})) { + Kernel = SyclKernelImpl->getHandleRef(); + DeviceImageImpl = SyclKernelImpl->getDeviceImage(); + + Program = DeviceImageImpl->get_ur_program_ref(); + + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); + KernelMutex = SyclKernelImpl->getCacheMutex(); } else { std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) = detail::ProgramManager::getInstance().getOrCreateKernel(