diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index b5f1f53e32d7f..8cb8647c49ee9 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1474,18 +1474,18 @@ void exec_graph_impl::populateURKernelUpdateStructs( ur_kernel_handle_t UrKernel = nullptr; auto Kernel = ExecCG.MSyclKernel; auto KernelBundleImplPtr = ExecCG.MKernelBundle; - std::shared_ptr SyclKernelImpl = nullptr; const sycl::detail::KernelArgMask *EliminatedArgMask = nullptr; - if (auto SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - ExecCG.MKernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { - UrKernel = SyclKernelImpl->getHandleRef(); - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - } else if (Kernel != nullptr) { + if (Kernel != nullptr) { UrKernel = Kernel->getHandleRef(); EliminatedArgMask = Kernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName, + KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { + UrKernel = SyclKernelImpl->getHandleRef(); + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); } else { ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = diff --git a/sycl/source/detail/helpers.cpp b/sycl/source/detail/helpers.cpp index 8a50f61070909..51fa452c3ccf2 100644 --- a/sycl/source/detail/helpers.cpp +++ b/sycl/source/detail/helpers.cpp @@ -72,16 +72,16 @@ retrieveKernelBinary(const QueueImplPtr &Queue, const char *KernelName, const RTDeviceBinaryImage *DeviceImage = nullptr; ur_program_handle_t Program = nullptr; auto KernelBundleImpl = KernelCG->getKernelBundle(); - if (auto SyclKernelImpl = - KernelBundleImpl - ? KernelBundleImpl->tryGetKernel(KernelName, KernelBundleImpl) - : std::shared_ptr{nullptr}) { + if (KernelCG->MSyclKernel != nullptr) { + DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref(); + Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref(); + } else if (auto SyclKernelImpl = + KernelBundleImpl ? KernelBundleImpl->tryGetKernel( + KernelName, KernelBundleImpl) + : std::shared_ptr{nullptr}) { // Retrieve the device image from the kernel bundle. DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref(); Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); - } else if (KernelCG->MSyclKernel != nullptr) { - DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref(); - Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref(); } else { auto ContextImpl = Queue->getContextImplPtr(); auto DeviceImpl = Queue->getDeviceImplPtr(); diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index b8b07e0053881..a4b6e1d115f68 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1997,16 +1997,16 @@ void instrumentationAddExtraKernelMetadata( std::mutex *KernelMutex = nullptr; const KernelArgMask *EliminatedArgMask = nullptr; - if (auto SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); - } else if (nullptr != SyclKernel) { + if (nullptr != SyclKernel) { Program = SyclKernel->getProgramRef(); if (!SyclKernel->isCreatedFromSource()) EliminatedArgMask = SyclKernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel( + KernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); + Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref(); } else if (Queue) { // NOTE: Queue can be null when kernel is directly enqueued to a command // buffer @@ -2521,17 +2521,17 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, const KernelArgMask *EliminatedArgMask = nullptr; auto &KernelBundleImplPtr = CommandGroup.MKernelBundle; - if (auto SyclKernelImpl = - KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel(CommandGroup.MKernelName, - KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { + if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) { + UrKernel = Kernel->getHandleRef(); + EliminatedArgMask = Kernel->getKernelArgMask(); + } else if (auto SyclKernelImpl = + KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel( + CommandGroup.MKernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr}) { UrKernel = SyclKernelImpl->getHandleRef(); DeviceImageImpl = SyclKernelImpl->getDeviceImage(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - } else if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) { - UrKernel = Kernel->getHandleRef(); - EliminatedArgMask = Kernel->getKernelArgMask(); } else { ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = @@ -2678,18 +2678,7 @@ void enqueueImpKernel( std::shared_ptr SyclKernelImpl; std::shared_ptr DeviceImageImpl; - if ((SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr})) { - Kernel = SyclKernelImpl->getHandleRef(); - DeviceImageImpl = SyclKernelImpl->getDeviceImage(); - - Program = DeviceImageImpl->get_ur_program_ref(); - - EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); - KernelMutex = SyclKernelImpl->getCacheMutex(); - } else if (nullptr != MSyclKernel) { + if (nullptr != MSyclKernel) { assert(MSyclKernel->get_info() == Queue->get_context()); Kernel = MSyclKernel->getHandleRef(); @@ -2703,6 +2692,17 @@ void enqueueImpKernel( // their duplication in such cases. KernelMutex = &MSyclKernel->getNoncacheableEnqueueMutex(); EliminatedArgMask = MSyclKernel->getKernelArgMask(); + } else if ((SyclKernelImpl = KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel( + KernelName, KernelBundleImplPtr) + : std::shared_ptr{nullptr})) { + Kernel = SyclKernelImpl->getHandleRef(); + DeviceImageImpl = SyclKernelImpl->getDeviceImage(); + + Program = DeviceImageImpl->get_ur_program_ref(); + + EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); + KernelMutex = SyclKernelImpl->getCacheMutex(); } else { std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) = detail::ProgramManager::getInstance().getOrCreateKernel(