From 031fcd0f4ca113f27192fd1b85a128199f681785 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Fri, 18 Nov 2022 13:39:17 -0800 Subject: [PATCH] [SYCL][Reduction] Support range version with multiple reductions --- sycl/include/sycl/handler.hpp | 34 ++++++------- sycl/include/sycl/reduction.hpp | 66 +++++++++++++++++-------- sycl/include/sycl/reduction_forward.hpp | 8 ++- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 07e7bcba6ddd2..f679f64db82fa 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -2029,25 +2029,24 @@ class __SYCL_EXPORT handler { /// Reductions @{ - template + template std::enable_if_t< - detail::IsReduction::value && + (sizeof...(RestT) > 1) && + detail::AreAllButLastReductions::value && ext::oneapi::experimental::is_property_list::value> - parallel_for(range Range, PropertiesT Properties, Reduction Redu, - _KERNELFUNCPARAM(KernelFunc)) { - detail::reduction_parallel_for(*this, Range, Properties, Redu, - std::move(KernelFunc)); + parallel_for(range Range, PropertiesT Properties, RestT &&...Rest) { + detail::reduction_parallel_for(*this, Range, Properties, + std::forward(Rest)...); } - template - std::enable_if_t::value> - parallel_for(range Range, Reduction Redu, - _KERNELFUNCPARAM(KernelFunc)) { + template + std::enable_if_t::value> + parallel_for(range Range, RestT &&...Rest) { parallel_for( - Range, ext::oneapi::experimental::detail::empty_properties_t{}, Redu, - std::move(KernelFunc)); + Range, ext::oneapi::experimental::detail::empty_properties_t{}, + std::forward(Rest)...); } template - friend void detail::reduction_parallel_for(handler &CGH, range Range, + typename PropertiesT, typename... RestT> + friend void detail::reduction_parallel_for(handler &CGH, range NDRange, PropertiesT Properties, - Reduction Redu, - KernelType KernelFunc); + RestT... Rest); template diff --git a/sycl/include/sycl/reduction.hpp b/sycl/include/sycl/reduction.hpp index 3754d77da18c8..e2bf27cef89e3 100644 --- a/sycl/include/sycl/reduction.hpp +++ b/sycl/include/sycl/reduction.hpp @@ -2302,16 +2302,29 @@ __SYCL_EXPORT uint32_t reduGetMaxNumConcurrentWorkGroups(std::shared_ptr Queue); template + typename PropertiesT, typename... RestT> void reduction_parallel_for(handler &CGH, range Range, - PropertiesT Properties, Reduction Redu, - KernelType KernelFunc) { + PropertiesT Properties, RestT... Rest) { + std::tuple ArgsTuple(Rest...); + constexpr size_t NumArgs = sizeof...(RestT); + static_assert(NumArgs > 1, "No reduction!"); + auto KernelFunc = std::get(ArgsTuple); + auto ReduIndices = std::make_index_sequence(); + auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices); + // Before running the kernels, check that device has enough local memory // to hold local arrays required for the tree-reduction algorithm. - constexpr bool IsTreeReduction = - !Reduction::has_fast_reduce && !Reduction::has_fast_atomics; - size_t OneElemSize = - IsTreeReduction ? sizeof(typename Reduction::result_type) : 0; + size_t OneElemSize = [&]() { + if constexpr (NumArgs == 2) { + using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>; + constexpr bool IsTreeReduction = + !Reduction::has_fast_reduce && !Reduction::has_fast_atomics; + return IsTreeReduction ? sizeof(typename Reduction::result_type) : 0; + } else { + return reduGetMemPerWorkItem(ReduTuple, ReduIndices); + } + }(); + uint32_t NumConcurrentWorkGroups = #ifdef __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS; @@ -2341,7 +2354,7 @@ void reduction_parallel_for(handler &CGH, range Range, // stride equal to 1. For each of the index the given the original KernelFunc // is called and the reduction value hold in \p Reducer is accumulated in // those calls. - auto UpdatedKernelFunc = [=](auto NDId, auto &Reducer) { + auto UpdatedKernelFunc = [=](auto NDId, auto &...Reducers) { // Divide into contiguous chunks and assign each chunk to a Group // Rely on precomputed division to avoid repeating expensive operations // TODO: Some devices may prefer alternative remainder handling @@ -2357,23 +2370,34 @@ void reduction_parallel_for(handler &CGH, range Range, size_t End = GroupEnd; size_t Stride = NDId.get_local_range(0); for (size_t I = Start; I < End; I += Stride) - KernelFunc(getDelinearizedId(Range, I), Reducer); + KernelFunc(getDelinearizedId(Range, I), Reducers...); }; + if constexpr (NumArgs == 2) { + using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>; + auto &Redu = std::get<0>(ReduTuple); - constexpr auto StrategyToUse = [&]() { - if constexpr (Strategy != reduction::strategy::auto_select) - return Strategy; + constexpr auto StrategyToUse = [&]() { + if constexpr (Strategy != reduction::strategy::auto_select) + return Strategy; - if constexpr (Reduction::has_fast_reduce) - return reduction::strategy::group_reduce_and_last_wg_detection; - else if constexpr (Reduction::has_fast_atomics) - return reduction::strategy::local_atomic_and_atomic_cross_wg; - else - return reduction::strategy::range_basic; - }(); + if constexpr (Reduction::has_fast_reduce) + return reduction::strategy::group_reduce_and_last_wg_detection; + else if constexpr (Reduction::has_fast_atomics) + return reduction::strategy::local_atomic_and_atomic_cross_wg; + else + return reduction::strategy::range_basic; + }(); - reduction_parallel_for(CGH, NDRange, Properties, - Redu, UpdatedKernelFunc); + reduction_parallel_for(CGH, NDRange, Properties, + Redu, UpdatedKernelFunc); + } else { + return std::apply( + [&](auto &...Reds) { + return reduction_parallel_for( + CGH, NDRange, Properties, Reds..., UpdatedKernelFunc); + }, + ReduTuple); + } } } // namespace detail diff --git a/sycl/include/sycl/reduction_forward.hpp b/sycl/include/sycl/reduction_forward.hpp index 8fd7a8c3b6423..17bd1dfb65bca 100644 --- a/sycl/include/sycl/reduction_forward.hpp +++ b/sycl/include/sycl/reduction_forward.hpp @@ -46,11 +46,9 @@ template void withAuxHandler(handler &CGH, FunctorTy Func); template -void reduction_parallel_for(handler &CGH, range Range, - PropertiesT Properties, Reduction Redu, - KernelType KernelFunc); + int Dims, typename PropertiesT, typename... RestT> +void reduction_parallel_for(handler &CGH, range NDRange, + PropertiesT Properties, RestT... Rest); template