@@ -164,6 +164,126 @@ struct slice_area {
164164  };
165165};
166166
167+ 
168+ //  a way to get value_type from both accessors and USM that is needed for transform_init
169+ template  <typename  Unknown>
170+ struct  accessor_traits 
171+ {
172+ };
173+ 
174+ template  <typename  T, int  Dim, sycl::access::mode AccMode, sycl::access::target AccTarget,
175+           sycl::access::placeholder Placeholder>
176+ struct  accessor_traits <sycl::accessor<T, Dim, AccMode, AccTarget, Placeholder>>
177+ {
178+     using  value_type = typename  sycl::accessor<T, Dim, AccMode, AccTarget, Placeholder>::value_type;
179+ };
180+ 
181+ template  <typename  RawArrayValueType>
182+ struct  accessor_traits <RawArrayValueType*>
183+ {
184+     using  value_type = RawArrayValueType;
185+ };
186+ 
187+ //  calculate shift where we should start processing on current item
188+ template  <typename  NDItemId, typename  GlobalIdx, typename  SizeNIter, typename  SizeN>
189+ SizeN
190+ calc_shift (const  NDItemId item_id, const  GlobalIdx global_idx, SizeNIter& n_iter, const  SizeN n)
191+ {
192+     auto  global_range_size = item_id.get_global_range ().size ();
193+ 
194+     auto  start = n_iter * global_idx;
195+     auto  global_shift = global_idx + n_iter * global_range_size;
196+     if  (n_iter > 0  && global_shift > n)
197+     {
198+         start += n % global_range_size - global_idx;
199+     }
200+     else  if  (global_shift < n)
201+     {
202+         n_iter++;
203+     }
204+     return  start;
205+ }
206+ 
207+ 
208+ template  <typename  ExecutionPolicy, typename  Operation1, typename  Operation2>
209+ struct  transform_init 
210+ {
211+     Operation1 binary_op;
212+     Operation2 unary_op;
213+ 
214+     template  <typename  NDItemId, typename  GlobalIdx, typename  Size, typename  AccLocal, typename ... Acc>
215+     void 
216+     operator ()(const  NDItemId item_id, const  GlobalIdx global_idx, Size n, AccLocal& local_mem,
217+                const  Acc&... acc)
218+     {
219+         auto  local_idx = item_id.get_local_id (0 );
220+         auto  global_range_size = item_id.get_global_range ().size ();
221+         auto  n_iter = n / global_range_size;
222+         auto  start = calc_shift (item_id, global_idx, n_iter, n);
223+         auto  shifted_global_idx = global_idx + start;
224+ 
225+         typename  accessor_traits<AccLocal>::value_type res;
226+         if  (global_idx < n)
227+         {
228+             res = unary_op (shifted_global_idx, acc...);
229+         }
230+         //  Add neighbour to the current local_mem
231+         for  (decltype (n_iter) i = 1 ; i < n_iter; ++i)
232+         {
233+             res = binary_op (res, unary_op (shifted_global_idx + i, acc...));
234+         }
235+         if  (global_idx < n)
236+         {
237+             local_mem[local_idx] = res;
238+         }
239+     }
240+ };
241+ 
242+ 
243+ //  Reduce on local memory
244+ template  <typename  ExecutionPolicy, typename  BinaryOperation1, typename  Tp>
245+ struct  reduce 
246+ {
247+     BinaryOperation1 bin_op1;
248+ 
249+     template  <typename  NDItemId, typename  GlobalIdx, typename  Size, typename  AccLocal>
250+     Tp
251+     operator ()(const  NDItemId item_id, const  GlobalIdx global_idx, const  Size n, AccLocal& local_mem)
252+     {
253+         auto  local_idx = item_id.get_local_id (0 );
254+         auto  group_size = item_id.get_local_range ().size ();
255+ 
256+         auto  k = 1 ;
257+         do 
258+         {
259+             item_id.barrier (sycl::access::fence_space::local_space);
260+             if  (local_idx % (2  * k) == 0  && local_idx + k < group_size && global_idx < n &&
261+                 global_idx + k < n)
262+             {
263+                 local_mem[local_idx] = bin_op1 (local_mem[local_idx], local_mem[local_idx + k]);
264+             }
265+             k *= 2 ;
266+         } while  (k < group_size);
267+         return  local_mem[local_idx];
268+     }
269+ };
270+ 
271+ 
272+ //  walk through the data
273+ template  <typename  ExecutionPolicy, typename  F>
274+ struct  walk_n 
275+ {
276+     F f;
277+ 
278+     template  <typename  ItemId, typename ... Ranges>
279+     auto 
280+     operator ()(const  ItemId idx, Ranges&&... rngs) -> decltype (f(rngs[idx]...))
281+     {
282+         return  f (rngs[idx]...);
283+     }
284+ };
285+ 
286+ 
167287//  This option uses a parallel for to fill the buffer and then
168288//  uses a tranform_init with plus/no_op and then
169289//  a local reduction then global reduction.
@@ -189,21 +309,18 @@ float calc_pi_dpstd_native3(size_t num_steps, int groups, Policy&& policy) {
189309  auto  calc_begin = oneapi::dpl::begin (buf);
190310  auto  calc_end = oneapi::dpl::end (buf);
191311
192-   using  Functor = oneapi::dpl::unseq_backend:: walk_n<Policy, my_no_op>;
312+   using  Functor = walk_n<Policy, my_no_op>;
193313  float  result;
194314
195315  //  Functor will do nothing for tranform_init and will use plus for reduce.
196316  //  In this example we have done the calculation and filled the buffer above
197317  //  The way transform_init works is that you need to have the value already
198318  //  populated in the buffer.
199-   auto  tf_init =
200-       oneapi::dpl::unseq_backend::transform_init<Policy, std::plus<float >,
201-                                                  Functor>{std::plus<float >(),
202-                                                           Functor{my_no_op ()}};
319+   auto  tf_init = transform_init<Policy, std::plus<float >,
320+                    Functor>{std::plus<float >(), Functor{my_no_op ()}};
203321
204322  auto  combine = std::plus<float >();
205-   auto  brick_reduce =
206-       oneapi::dpl::unseq_backend::reduce<Policy, std::plus<float >, float >{
323+   auto  brick_reduce = reduce<Policy, std::plus<float >, float >{
207324          std::plus<float >()};
208325  auto  workgroup_size =
209326      policy.queue ()
@@ -295,19 +412,17 @@ float calc_pi_dpstd_native4(size_t num_steps, int groups, Policy&& policy) {
295412  auto  calc_begin = oneapi::dpl::begin (buf2);
296413  auto  calc_end = oneapi::dpl::end (buf2);
297414
298-   using  Functor2 = oneapi::dpl::unseq_backend:: walk_n<Policy, slice_area>;
415+   using  Functor2 = walk_n<Policy, slice_area>;
299416
300417  //  The buffer has 1...num it at and now we will use that as an input
301418  //  to the slice structue which will calculate the area of each
302419  //  rectangle.
303-   auto  tf_init =
304-       oneapi::dpl::unseq_backend::transform_init<Policy, std::plus<float >,
420+   auto  tf_init = transform_init<Policy, std::plus<float >,
305421                                                 Functor2>{
306422          std::plus<float >(), Functor2{slice_area (num_steps)}};
307423
308424  auto  combine = std::plus<float >();
309-   auto  brick_reduce =
310-       oneapi::dpl::unseq_backend::reduce<Policy, std::plus<float >, float >{
425+   auto  brick_reduce = reduce<Policy, std::plus<float >, float >{
311426          std::plus<float >()};
312427
313428  //  get workgroup_size from the device
0 commit comments