@@ -1571,48 +1571,6 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
15711571 }
15721572};
15731573
1574- // Auto-dispatch. Must be the last one.
1575- template <> struct NDRangeReduction <reduction::strategy::auto_select> {
1576- // Some readability aliases, to increase signal/noise ratio below.
1577- template <reduction::strategy Strategy>
1578- using Impl = NDRangeReduction<Strategy>;
1579- using S = reduction::strategy;
1580-
1581- template <typename KernelName, int Dims, typename PropertiesT,
1582- typename KernelType, typename Reduction>
1583- static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
1584- nd_range<Dims> NDRange, PropertiesT &Properties,
1585- Reduction &Redu, KernelType &KernelFunc) {
1586- auto Delegate = [&](auto Impl) {
1587- Impl.template run <KernelName>(CGH, Queue, NDRange, Properties, Redu,
1588- KernelFunc);
1589- };
1590-
1591- if constexpr (Reduction::has_float64_atomics) {
1592- if (getDeviceFromHandler (CGH).has (aspect::atomic64))
1593- return Delegate (Impl<S::group_reduce_and_atomic_cross_wg>{});
1594-
1595- if constexpr (Reduction::has_fast_reduce)
1596- return Delegate (Impl<S::group_reduce_and_multiple_kernels>{});
1597- else
1598- return Delegate (Impl<S::basic>{});
1599- } else if constexpr (Reduction::has_fast_atomics) {
1600- if constexpr (Reduction::has_fast_reduce) {
1601- return Delegate (Impl<S::group_reduce_and_atomic_cross_wg>{});
1602- } else {
1603- return Delegate (Impl<S::local_mem_tree_and_atomic_cross_wg>{});
1604- }
1605- } else {
1606- if constexpr (Reduction::has_fast_reduce)
1607- return Delegate (Impl<S::group_reduce_and_multiple_kernels>{});
1608- else
1609- return Delegate (Impl<S::basic>{});
1610- }
1611-
1612- assert (false && " Must be unreachable!" );
1613- }
1614- };
1615-
16161574// / For the given 'Reductions' types pack and indices enumerating them this
16171575// / function either creates new temporary accessors for partial sums (if IsOneWG
16181576// / is false) or returns user's accessor/USM-pointer if (IsOneWG is true).
@@ -2230,21 +2188,109 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>) {
22302188 return {std::get<Is>(std::move (Tuple))...};
22312189}
22322190
2191+ template <> struct NDRangeReduction <reduction::strategy::multi> {
2192+ template <typename KernelName, int Dims, typename PropertiesT,
2193+ typename ... RestT>
2194+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2195+ nd_range<Dims> NDRange, PropertiesT &Properties,
2196+ RestT... Rest) {
2197+ std::tuple<RestT...> ArgsTuple (Rest...);
2198+ constexpr size_t NumArgs = sizeof ...(RestT);
2199+ auto KernelFunc = std::get<NumArgs - 1 >(ArgsTuple);
2200+ auto ReduIndices = std::make_index_sequence<NumArgs - 1 >();
2201+ auto ReduTuple = detail::tuple_select_elements (ArgsTuple, ReduIndices);
2202+
2203+ size_t LocalMemPerWorkItem = reduGetMemPerWorkItem (ReduTuple, ReduIndices);
2204+ // TODO: currently the maximal work group size is determined for the given
2205+ // queue/device, while it is safer to use queries to the kernel compiled
2206+ // for the device.
2207+ size_t MaxWGSize = reduGetMaxWGSize (Queue, LocalMemPerWorkItem);
2208+ if (NDRange.get_local_range ().size () > MaxWGSize)
2209+ throw sycl::runtime_error (" The implementation handling parallel_for with"
2210+ " reduction requires work group size not bigger"
2211+ " than " +
2212+ std::to_string (MaxWGSize),
2213+ PI_ERROR_INVALID_WORK_GROUP_SIZE);
2214+
2215+ reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2216+ ReduIndices);
2217+ reduction::finalizeHandler (CGH);
2218+
2219+ size_t NWorkItems = NDRange.get_group_range ().size ();
2220+ while (NWorkItems > 1 ) {
2221+ reduction::withAuxHandler (CGH, [&](handler &AuxHandler) {
2222+ NWorkItems = reduAuxCGFunc<KernelName, decltype (KernelFunc)>(
2223+ AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2224+ });
2225+ } // end while (NWorkItems > 1)
2226+ }
2227+ };
2228+
2229+ // Auto-dispatch. Must be the last one.
2230+ template <> struct NDRangeReduction <reduction::strategy::auto_select> {
2231+ // Some readability aliases, to increase signal/noise ratio below.
2232+ template <reduction::strategy Strategy>
2233+ using Impl = NDRangeReduction<Strategy>;
2234+ using Strat = reduction::strategy;
2235+
2236+ template <typename KernelName, int Dims, typename PropertiesT,
2237+ typename KernelType, typename Reduction>
2238+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2239+ nd_range<Dims> NDRange, PropertiesT &Properties,
2240+ Reduction &Redu, KernelType &KernelFunc) {
2241+ auto Delegate = [&](auto Impl) {
2242+ Impl.template run <KernelName>(CGH, Queue, NDRange, Properties, Redu,
2243+ KernelFunc);
2244+ };
2245+
2246+ if constexpr (Reduction::has_float64_atomics) {
2247+ if (getDeviceFromHandler (CGH).has (aspect::atomic64))
2248+ return Delegate (Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2249+
2250+ if constexpr (Reduction::has_fast_reduce)
2251+ return Delegate (Impl<Strat::group_reduce_and_multiple_kernels>{});
2252+ else
2253+ return Delegate (Impl<Strat::basic>{});
2254+ } else if constexpr (Reduction::has_fast_atomics) {
2255+ if constexpr (Reduction::has_fast_reduce) {
2256+ return Delegate (Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2257+ } else {
2258+ return Delegate (Impl<Strat::local_mem_tree_and_atomic_cross_wg>{});
2259+ }
2260+ } else {
2261+ if constexpr (Reduction::has_fast_reduce)
2262+ return Delegate (Impl<Strat::group_reduce_and_multiple_kernels>{});
2263+ else
2264+ return Delegate (Impl<Strat::basic>{});
2265+ }
2266+
2267+ assert (false && " Must be unreachable!" );
2268+ }
2269+ template <typename KernelName, int Dims, typename PropertiesT,
2270+ typename ... RestT>
2271+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2272+ nd_range<Dims> NDRange, PropertiesT &Properties,
2273+ RestT... Rest) {
2274+ return Impl<Strat::multi>::run<KernelName>(CGH, Queue, NDRange, Properties,
2275+ Rest...);
2276+ }
2277+ };
2278+
22332279template <typename KernelName, reduction::strategy Strategy, int Dims,
2234- typename PropertiesT, typename KernelType, typename Reduction >
2280+ typename PropertiesT, typename ... RestT >
22352281void reduction_parallel_for (handler &CGH,
22362282 std::shared_ptr<detail::queue_impl> Queue,
22372283 nd_range<Dims> NDRange, PropertiesT Properties,
2238- Reduction Redu, KernelType KernelFunc ) {
2239- NDRangeReduction<Strategy>::template run<KernelName>(
2240- CGH, Queue, NDRange, Properties, Redu, KernelFunc );
2284+ RestT... Rest ) {
2285+ NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2286+ Properties, Rest... );
22412287}
22422288
22432289__SYCL_EXPORT uint32_t
22442290reduGetMaxNumConcurrentWorkGroups (std::shared_ptr<queue_impl> Queue);
22452291
2246- template <typename KernelName, int Dims, typename PropertiesT ,
2247- typename KernelType, typename Reduction>
2292+ template <typename KernelName, reduction::strategy Strategy, int Dims ,
2293+ typename PropertiesT, typename KernelType, typename Reduction>
22482294void reduction_parallel_for (handler &CGH,
22492295 std::shared_ptr<detail::queue_impl> Queue,
22502296 range<Dims> Range, PropertiesT Properties,
@@ -2303,7 +2349,10 @@ void reduction_parallel_for(handler &CGH,
23032349 KernelFunc (getDelinearizedId (Range, I), Reducer);
23042350 };
23052351
2306- constexpr auto Strategy = [&]() {
2352+ constexpr auto StrategyToUse = [&]() {
2353+ if constexpr (Strategy != reduction::strategy::auto_select)
2354+ return Strategy;
2355+
23072356 if constexpr (Reduction::has_fast_reduce)
23082357 return reduction::strategy::group_reduce_and_last_wg_detection;
23092358 else if constexpr (Reduction::has_fast_atomics)
@@ -2312,57 +2361,8 @@ void reduction_parallel_for(handler &CGH,
23122361 return reduction::strategy::range_basic;
23132362 }();
23142363
2315- reduction_parallel_for<KernelName, Strategy>(CGH, Queue, NDRange, Properties,
2316- Redu, UpdatedKernelFunc);
2317- }
2318-
2319- template <> struct NDRangeReduction <reduction::strategy::multi> {
2320- template <typename KernelName, int Dims, typename PropertiesT,
2321- typename ... RestT>
2322- static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2323- nd_range<Dims> NDRange, PropertiesT &Properties,
2324- RestT... Rest) {
2325- std::tuple<RestT...> ArgsTuple (Rest...);
2326- constexpr size_t NumArgs = sizeof ...(RestT);
2327- auto KernelFunc = std::get<NumArgs - 1 >(ArgsTuple);
2328- auto ReduIndices = std::make_index_sequence<NumArgs - 1 >();
2329- auto ReduTuple = detail::tuple_select_elements (ArgsTuple, ReduIndices);
2330-
2331- size_t LocalMemPerWorkItem = reduGetMemPerWorkItem (ReduTuple, ReduIndices);
2332- // TODO: currently the maximal work group size is determined for the given
2333- // queue/device, while it is safer to use queries to the kernel compiled
2334- // for the device.
2335- size_t MaxWGSize = reduGetMaxWGSize (Queue, LocalMemPerWorkItem);
2336- if (NDRange.get_local_range ().size () > MaxWGSize)
2337- throw sycl::runtime_error (" The implementation handling parallel_for with"
2338- " reduction requires work group size not bigger"
2339- " than " +
2340- std::to_string (MaxWGSize),
2341- PI_ERROR_INVALID_WORK_GROUP_SIZE);
2342-
2343- reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2344- ReduIndices);
2345- reduction::finalizeHandler (CGH);
2346-
2347- size_t NWorkItems = NDRange.get_group_range ().size ();
2348- while (NWorkItems > 1 ) {
2349- reduction::withAuxHandler (CGH, [&](handler &AuxHandler) {
2350- NWorkItems = reduAuxCGFunc<KernelName, decltype (KernelFunc)>(
2351- AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2352- });
2353- } // end while (NWorkItems > 1)
2354- }
2355- };
2356-
2357- template <typename KernelName, int Dims, typename PropertiesT,
2358- typename ... RestT>
2359- void reduction_parallel_for (handler &CGH,
2360- std::shared_ptr<detail::queue_impl> Queue,
2361- nd_range<Dims> NDRange, PropertiesT Properties,
2362- RestT... Rest) {
2363- constexpr auto Strategy = reduction::strategy::multi;
2364- NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2365- Properties, Rest...);
2364+ reduction_parallel_for<KernelName, StrategyToUse>(
2365+ CGH, Queue, NDRange, Properties, Redu, UpdatedKernelFunc);
23662366}
23672367} // namespace detail
23682368
0 commit comments