diff --git a/SYCL/Reduction/reduction_range_3d_rw.cpp b/SYCL/Reduction/reduction_range_3d_rw.cpp index 522669e310..f29c828bea 100644 --- a/SYCL/Reduction/reduction_range_3d_rw.cpp +++ b/SYCL/Reduction/reduction_range_3d_rw.cpp @@ -67,14 +67,12 @@ int main() { tests(Q, 99, PlusWithoutIdentity{}, range<3>{2, 2, 2}); tests(Q, 99, PlusWithoutIdentity{}, range<3>{2, 3, 4}); - /* Temporarily disabled tests(Q, 99, PlusWithoutIdentity{}, range<3>{1, 1, MaxWGSize + 1}); tests(Q, 99, PlusWithoutIdentity{}, range<3>{1, MaxWGSize + 1, 1}); tests(Q, 99, PlusWithoutIdentity{}, range<3>{MaxWGSize + 1, 1, 1}); - */ tests(Q, 99, PlusWithoutIdentity{}, range<3>{2, 5, MaxWGSize * 2}); @@ -83,6 +81,27 @@ int main() { tests(Q, 99, PlusWithoutIdentity{}, range<3>{MaxWGSize * 3, 8, 4}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{1, 1, 1}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{2, 2, 2}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{2, 3, 4}); + + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{1, 1, MaxWGSize + 1}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{1, MaxWGSize + 1, 1}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{MaxWGSize + 1, 1, 1}); + + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{2, 5, MaxWGSize * 2}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{3, MaxWGSize * 3, 2}); + tests(Q, 99, MultipliesWithoutIdentity{}, + range<3>{MaxWGSize * 3, 8, 4}); + printFinalStatus(NumErrors); return NumErrors; } diff --git a/SYCL/Reduction/reduction_utils.hpp b/SYCL/Reduction/reduction_utils.hpp index d5875fbbba..78a097bba9 100644 --- a/SYCL/Reduction/reduction_utils.hpp +++ b/SYCL/Reduction/reduction_utils.hpp @@ -141,6 +141,10 @@ template struct PlusWithoutIdentity { T operator()(const T &A, const T &B) const { return A + B; } }; +template struct MultipliesWithoutIdentity { + T operator()(const T &A, const T &B) const { return A * B; } +}; + template T getMinimumFPValue() { return std::numeric_limits::has_infinity ? static_cast(-std::numeric_limits::infinity())