diff --git a/sycl/include/sycl/ext/oneapi/reduction.hpp b/sycl/include/sycl/ext/oneapi/reduction.hpp index 83f02301a7624..3005dc3417944 100644 --- a/sycl/include/sycl/ext/oneapi/reduction.hpp +++ b/sycl/include/sycl/ext/oneapi/reduction.hpp @@ -30,6 +30,14 @@ namespace oneapi { namespace detail { +template +event withAuxHandler(std::shared_ptr Queue, bool IsHost, + FunctorTy Func) { + handler AuxHandler(Queue, IsHost); + Func(AuxHandler); + return AuxHandler.finalize(); +} + using cl::sycl::detail::bool_constant; using cl::sycl::detail::enable_if_t; using cl::sycl::detail::queue_impl; @@ -2434,6 +2442,7 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize, bool Pow2WG = (WGSize & (WGSize - 1)) == 0; bool IsOneWG = NWorkGroups == 1; + bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems); // Like reduCGFuncImpl, we also have to split out scalar and array reductions IsScalarReduction ScalarPredicate; @@ -2442,28 +2451,27 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize, IsArrayReduction ArrayPredicate; auto ArrayIs = filterSequence(ArrayPredicate, ReduIndices); + size_t LocalAccSize = WGSize + (HasUniformWG ? 0 : 1); + auto LocalAccsTuple = + createReduLocalAccs(LocalAccSize, CGH, ReduIndices); + auto InAccsTuple = + getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices); + + auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices); + auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices); + auto InitToIdentityProps = + getInitToIdentityProperties(ReduTuple, ReduIndices); + // Predicate/OutAccsTuple below have different type depending on us having // just a single WG or multiple WGs. Use this lambda to avoid code // duplication. auto Rest = [&](auto Predicate, auto OutAccsTuple) { auto AccReduIndices = filterSequence(Predicate, ReduIndices); associateReduAccsWithHandler(CGH, ReduTuple, AccReduIndices); - - size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1); - auto LocalAccsTuple = - createReduLocalAccs(LocalAccSize, CGH, ReduIndices); - auto InAccsTuple = - getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices); - - auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices); - auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices); - auto InitToIdentityProps = - getInitToIdentityProperties(ReduTuple, ReduIndices); - using Name = __sycl_reduction_kernel; // TODO: Opportunity to parallelize across number of elements - range<1> GlobalRange = {Pow2WG ? NWorkItems : NWorkGroups * WGSize}; + range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize}; nd_range<1> Range{GlobalRange, range<1>(WGSize)}; CGH.parallel_for(Range, [=](nd_item<1> NDIt) { size_t WGSize = NDIt.get_local_range().size(); @@ -2472,12 +2480,12 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize, // Handle scalar and array reductions reduAuxCGFuncImplScalar( - Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple, - InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple, + HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, + LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple, InitToIdentityProps, ScalarIs); reduAuxCGFuncImplArray( - Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple, - InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple, + HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, + LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple, InitToIdentityProps, ArrayIs); }); }; @@ -2504,7 +2512,7 @@ void reduSaveFinalResultToUserMemHelper( if constexpr (!Reduction::is_usm) { if (Redu.hasUserDiscardWriteAccessor()) { event CopyEvent = - handler::withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) { + withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) { auto InAcc = Redu.getReadAccToPreviousPartialReds(CopyHandler); auto OutAcc = Redu.getUserDiscardWriteAccessor(); Redu.associateWithHandler(CopyHandler); diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index b682a9ca50d9f..c6ee4bc17c675 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -315,6 +315,9 @@ tuple_select_elements(TupleT Tuple, std::index_sequence); template struct AreAllButLastReductions; +template +event withAuxHandler(std::shared_ptr Queue, bool IsHost, + FunctorTy Func); } // namespace detail } // namespace oneapi } // namespace ext @@ -476,12 +479,9 @@ class __SYCL_EXPORT handler { } template - static event withAuxHandler(std::shared_ptr Queue, - bool IsHost, FunctorTy Func) { - handler AuxHandler(Queue, IsHost); - Func(AuxHandler); - return AuxHandler.finalize(); - } + friend event + ext::oneapi::detail::withAuxHandler(std::shared_ptr Queue, + bool IsHost, FunctorTy Func); /// }@ /// Saves buffers created by handling reduction feature in handler.