@@ -214,57 +214,56 @@ template <class Reducer> class combiner {
214214 : memory_scope::device;
215215 }
216216
217+ template <access::address_space Space, class T , class AtomicFunctor >
218+ void atomic_combine_impl (T *ReduVarPtr, AtomicFunctor Functor) const {
219+ auto reducer = static_cast <const Reducer *>(this );
220+ for (size_t E = 0 ; E < Extent; ++E) {
221+ auto AtomicRef =
222+ atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
223+ multi_ptr<T, Space>(ReduVarPtr)[E]);
224+ Functor (AtomicRef, reducer->getElement (E));
225+ }
226+ }
227+
228+ template <class _T , access::address_space Space, class BinaryOperation >
229+ static inline constexpr bool BasicCheck =
230+ std::is_same<typename remove_AS<_T>::type, T>::value &&
231+ (Space == access::address_space::global_space ||
232+ Space == access::address_space::local_space);
233+
217234public:
218235 // / Atomic ADD operation: *ReduVarPtr += MValue;
219236 template <access::address_space Space = access::address_space::global_space,
220237 typename _T = T, class _BinaryOperation = BinaryOperation>
221- enable_if_t <std::is_same< typename remove_AS<_T>::type, T>::value &&
238+ enable_if_t <BasicCheck<_T, Space, _BinaryOperation> &&
222239 (IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value ||
223240 IsReduOptForAtomic64Add<T, _BinaryOperation>::value) &&
224- sycl::detail::IsPlus<T, _BinaryOperation>::value &&
225- (Space == access::address_space::global_space ||
226- Space == access::address_space::local_space)>
241+ sycl::detail::IsPlus<T, _BinaryOperation>::value>
227242 atomic_combine (_T *ReduVarPtr) const {
228- auto reducer = static_cast <const Reducer *>(this );
229- for (size_t E = 0 ; E < Extent; ++E) {
230- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
231- multi_ptr<T, Space>(ReduVarPtr)[E])
232- .fetch_add (reducer->getElement (E));
233- }
243+ atomic_combine_impl<Space>(
244+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_add (Val); });
234245 }
235246
236247 // / Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
237248 template <access::address_space Space = access::address_space::global_space,
238249 typename _T = T, class _BinaryOperation = BinaryOperation>
239- enable_if_t <std::is_same< typename remove_AS<_T>::type, T>::value &&
250+ enable_if_t <BasicCheck<_T, Space, _BinaryOperation> &&
240251 IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
241- sycl::detail::IsBitOR<T, _BinaryOperation>::value &&
242- (Space == access::address_space::global_space ||
243- Space == access::address_space::local_space)>
252+ sycl::detail::IsBitOR<T, _BinaryOperation>::value>
244253 atomic_combine (_T *ReduVarPtr) const {
245- auto reducer = static_cast <const Reducer *>(this );
246- for (size_t E = 0 ; E < Extent; ++E) {
247- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
248- multi_ptr<T, Space>(ReduVarPtr)[E])
249- .fetch_or (reducer->getElement (E));
250- }
254+ atomic_combine_impl<Space>(
255+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_or (Val); });
251256 }
252257
253258 // / Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
254259 template <access::address_space Space = access::address_space::global_space,
255260 typename _T = T, class _BinaryOperation = BinaryOperation>
256- enable_if_t <std::is_same< typename remove_AS<_T>::type, T>::value &&
261+ enable_if_t <BasicCheck<_T, Space, _BinaryOperation> &&
257262 IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
258- sycl::detail::IsBitXOR<T, _BinaryOperation>::value &&
259- (Space == access::address_space::global_space ||
260- Space == access::address_space::local_space)>
263+ sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
261264 atomic_combine (_T *ReduVarPtr) const {
262- auto reducer = static_cast <const Reducer *>(this );
263- for (size_t E = 0 ; E < Extent; ++E) {
264- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
265- multi_ptr<T, Space>(ReduVarPtr)[E])
266- .fetch_xor (reducer->getElement (E));
267- }
265+ atomic_combine_impl<Space>(
266+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_xor (Val); });
268267 }
269268
270269 // / Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
@@ -276,46 +275,30 @@ template <class Reducer> class combiner {
276275 (Space == access::address_space::global_space ||
277276 Space == access::address_space::local_space)>
278277 atomic_combine (_T *ReduVarPtr) const {
279- auto reducer = static_cast <const Reducer *>(this );
280- for (size_t E = 0 ; E < Extent; ++E) {
281- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
282- multi_ptr<T, Space>(ReduVarPtr)[E])
283- .fetch_and (reducer->getElement (E));
284- }
278+ atomic_combine_impl<Space>(
279+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_and (Val); });
285280 }
286281
287282 // / Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
288283 template <access::address_space Space = access::address_space::global_space,
289284 typename _T = T, class _BinaryOperation = BinaryOperation>
290- enable_if_t <std::is_same< typename remove_AS<_T>::type, T>::value &&
285+ enable_if_t <BasicCheck<_T, Space, _BinaryOperation> &&
291286 IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
292- sycl::detail::IsMinimum<T, _BinaryOperation>::value &&
293- (Space == access::address_space::global_space ||
294- Space == access::address_space::local_space)>
287+ sycl::detail::IsMinimum<T, _BinaryOperation>::value>
295288 atomic_combine (_T *ReduVarPtr) const {
296- auto reducer = static_cast <const Reducer *>(this );
297- for (size_t E = 0 ; E < Extent; ++E) {
298- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
299- multi_ptr<T, Space>(ReduVarPtr)[E])
300- .fetch_min (reducer->getElement (E));
301- }
289+ atomic_combine_impl<Space>(
290+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_min (Val); });
302291 }
303292
304293 // / Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
305294 template <access::address_space Space = access::address_space::global_space,
306295 typename _T = T, class _BinaryOperation = BinaryOperation>
307- enable_if_t <std::is_same< typename remove_AS<_T>::type, T>::value &&
296+ enable_if_t <BasicCheck<_T, Space, _BinaryOperation> &&
308297 IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
309- sycl::detail::IsMaximum<T, _BinaryOperation>::value &&
310- (Space == access::address_space::global_space ||
311- Space == access::address_space::local_space)>
298+ sycl::detail::IsMaximum<T, _BinaryOperation>::value>
312299 atomic_combine (_T *ReduVarPtr) const {
313- auto reducer = static_cast <const Reducer *>(this );
314- for (size_t E = 0 ; E < Extent; ++E) {
315- atomic_ref<T, memory_order::relaxed, getMemoryScope<Space>(), Space>(
316- multi_ptr<T, Space>(ReduVarPtr)[E])
317- .fetch_max (reducer->getElement (E));
318- }
300+ atomic_combine_impl<Space>(
301+ ReduVarPtr, [](auto Ref, auto Val) { return Ref.fetch_max (Val); });
319302 }
320303};
321304
@@ -415,8 +398,6 @@ class reducer<T, BinaryOperation, Dims, Extent, Algorithm, View,
415398 reducer (const T &Identity, BinaryOperation BOp)
416399 : MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
417400
418- // SYCL 2020 revision 4 says this should be const, but this is a bug
419- // see https://github.com/KhronosGroup/SYCL-Docs/pull/252
420401 reducer<T, BinaryOperation, Dims - 1 , Extent, Algorithm, true >
421402 operator [](size_t Index) {
422403 return {MValue[Index], MBinaryOp};
0 commit comments