Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions sycl/include/sycl/ext/oneapi/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ namespace oneapi {

namespace detail {

/// Runs \p Func against a freshly constructed auxiliary handler.
///
/// An auxiliary handler is built from \p Queue and \p IsHost, \p Func is
/// invoked with that handler as its only argument, and the handler is then
/// finalized.
///
/// \param Queue  queue implementation the auxiliary handler is created for.
/// \param IsHost forwarded unchanged to the handler constructor.
/// \param Func   callable invoked as Func(handler&) to populate the handler.
/// \return the event obtained from finalizing the auxiliary handler.
template <class FunctorTy>
event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
                     FunctorTy Func) {
  handler Aux(Queue, IsHost);
  // Let the caller record its command group into the auxiliary handler.
  Func(Aux);
  event Done = Aux.finalize();
  return Done;
}

using cl::sycl::detail::bool_constant;
using cl::sycl::detail::enable_if_t;
using cl::sycl::detail::queue_impl;
Expand Down Expand Up @@ -2434,6 +2442,7 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,

bool Pow2WG = (WGSize & (WGSize - 1)) == 0;
bool IsOneWG = NWorkGroups == 1;
bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);

// Like reduCGFuncImpl, we also have to split out scalar and array reductions
IsScalarReduction ScalarPredicate;
Expand All @@ -2442,28 +2451,27 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
IsArrayReduction ArrayPredicate;
auto ArrayIs = filterSequence<Reductions...>(ArrayPredicate, ReduIndices);

size_t LocalAccSize = WGSize + (HasUniformWG ? 0 : 1);
auto LocalAccsTuple =
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
auto InAccsTuple =
getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices);

auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices);
auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices);
auto InitToIdentityProps =
getInitToIdentityProperties(ReduTuple, ReduIndices);

// Predicate/OutAccsTuple below have different type depending on us having
// just a single WG or multiple WGs. Use this lambda to avoid code
// duplication.
auto Rest = [&](auto Predicate, auto OutAccsTuple) {
auto AccReduIndices = filterSequence<Reductions...>(Predicate, ReduIndices);
associateReduAccsWithHandler(CGH, ReduTuple, AccReduIndices);

size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1);
auto LocalAccsTuple =
createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
auto InAccsTuple =
getReadAccsToPreviousPartialReds(CGH, ReduTuple, ReduIndices);

auto IdentitiesTuple = getReduIdentities(ReduTuple, ReduIndices);
auto BOPsTuple = getReduBOPs(ReduTuple, ReduIndices);
auto InitToIdentityProps =
getInitToIdentityProperties(ReduTuple, ReduIndices);

using Name = __sycl_reduction_kernel<reduction::aux_krn::Multi, KernelName,
decltype(OutAccsTuple)>;
// TODO: Opportunity to parallelize across number of elements
range<1> GlobalRange = {Pow2WG ? NWorkItems : NWorkGroups * WGSize};
range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
CGH.parallel_for<Name>(Range, [=](nd_item<1> NDIt) {
size_t WGSize = NDIt.get_local_range().size();
Expand All @@ -2472,12 +2480,12 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,

// Handle scalar and array reductions
reduAuxCGFuncImplScalar<Reductions...>(
Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple,
InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
InitToIdentityProps, ScalarIs);
reduAuxCGFuncImplArray<Reductions...>(
Pow2WG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple,
InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
HasUniformWG, IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
InitToIdentityProps, ArrayIs);
});
};
Expand All @@ -2504,7 +2512,7 @@ void reduSaveFinalResultToUserMemHelper(
if constexpr (!Reduction::is_usm) {
if (Redu.hasUserDiscardWriteAccessor()) {
event CopyEvent =
handler::withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
auto InAcc = Redu.getReadAccToPreviousPartialReds(CopyHandler);
auto OutAcc = Redu.getUserDiscardWriteAccessor();
Redu.associateWithHandler(CopyHandler);
Expand Down
12 changes: 6 additions & 6 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>);

template <typename FirstT, typename... RestT> struct AreAllButLastReductions;

// Forward declaration only; the definition is provided in
// sycl/ext/oneapi/reduction.hpp. It is declared here so that the handler
// class can name this function template in a friend declaration.
template <class FunctorTy>
event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
FunctorTy Func);
} // namespace detail
} // namespace oneapi
} // namespace ext
Expand Down Expand Up @@ -476,12 +479,9 @@ class __SYCL_EXPORT handler {
}

template <class FunctorTy>
static event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue,
bool IsHost, FunctorTy Func) {
handler AuxHandler(Queue, IsHost);
Func(AuxHandler);
return AuxHandler.finalize();
}
friend event
ext::oneapi::detail::withAuxHandler(std::shared_ptr<detail::queue_impl> Queue,
bool IsHost, FunctorTy Func);
/// }@

/// Saves buffers created by handling reduction feature in handler.
Expand Down