From 6f6569a2ccb7c16c87800446dc5a5234997f75c6 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Fri, 10 Dec 2021 10:59:26 +0300 Subject: [PATCH 1/2] [SYCL] Fix support for classes implicitly converted from items in parallel_for --- sycl/include/CL/sycl/handler.hpp | 27 +++++++--- .../basic_tests/parallel_for_user_types.cpp | 54 +++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) create mode 100644 sycl/test/basic_tests/parallel_for_user_types.cpp diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index c3ddbc1930258..2d1f84d5f8469 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -935,10 +935,17 @@ class __SYCL_EXPORT handler { using LambdaArgType = sycl::detail::lambda_arg_type>; // If 1D kernel argument is an integral type, convert it to sycl::item<1> - using TransformedArgType = - typename std::conditional::value && - Dims == 1, - item, LambdaArgType>::type; + // If type is convertible to sycl::item/sycl::nd_item, convert it to + // sycl::item/sycl::nd_item + using TransformedArgType = typename std::conditional< + std::is_integral::value && Dims == 1, item, + typename std::conditional< + std::is_convertible, LambdaArgType>::value, + nd_item, + typename std::conditional< + std::is_convertible, LambdaArgType>::value, + item, LambdaArgType>::type>::type>::type; + using NameT = typename detail::get_kernel_name_t::name; @@ -1560,12 +1567,20 @@ class __SYCL_EXPORT handler { verifyUsedKernelBundle(detail::KernelInfo::getName()); using LambdaArgType = sycl::detail::lambda_arg_type>; + // If type is convertible to sycl::item/sycl::nd_item, convert it to + // sycl::item/sycl::nd_item + using TransformedArgType = typename std::conditional< + std::is_convertible, LambdaArgType>::value, nd_item, + typename std::conditional< + std::is_convertible, LambdaArgType>::value, item, + LambdaArgType>::type>::type; (void)ExecutionRange; - kernel_parallel_for_wrapper(KernelFunc); + kernel_parallel_for_wrapper(KernelFunc); #ifndef __SYCL_DEVICE_ONLY__ detail::checkValueRange(ExecutionRange); MNDRDesc.set(std::move(ExecutionRange)); - StoreLambda(std::move(KernelFunc)); + StoreLambda( + std::move(KernelFunc)); setType(detail::CG::Kernel); #endif } diff --git a/sycl/test/basic_tests/parallel_for_user_types.cpp b/sycl/test/basic_tests/parallel_for_user_types.cpp new file mode 100644 index 0000000000000..ff616fe73bd18 --- /dev/null +++ b/sycl/test/basic_tests/parallel_for_user_types.cpp @@ -0,0 +1,54 @@ +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %RUN_ON_HOST %t.out + +// This test performs basic check of supporting user defined class that are +// implicitly converted from sycl::item/sycl::nd_item in parallel_for. + +#include +#include + +template class item_wrapper { +public: + item_wrapper(sycl::item it) : m_item(it) {} + +private: + sycl::item m_item; +}; + +template class nd_item_wrapper { +public: + nd_item_wrapper(sycl::nd_item it) : m_item(it) {} + +private: + sycl::nd_item m_item; +}; + +template class item_wrapper2 { +public: + item_wrapper2(sycl::item it) : m_item(it), m_value(T()) {} + +private: + sycl::item m_item; + T m_value; +}; + +template class nd_item_wrapper2 { +public: + nd_item_wrapper2(sycl::nd_item it) : m_item(it), m_value(T()) {} + +private: + sycl::nd_item m_item; + T m_value; +}; + +int main() { + sycl::queue q; + + q.parallel_for(sycl::range<1>{1}, [=](item_wrapper<1> item) {}); + q.parallel_for(sycl::nd_range<1>{1, 1}, [=](nd_item_wrapper<1> item) {}); + q.parallel_for(sycl::range<1>{1}, [=](item_wrapper2<1, int> item) {}); + q.parallel_for(sycl::nd_range<1>{1, 1}, + [=](nd_item_wrapper2<1, int> item) {}); + + return 0; +} From dc0e75b09f0f0fe883ae5730110d0170c1fca0e1 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Fri, 10 Dec 2021 11:44:01 +0300 Subject: [PATCH 2/2] Fix comment and extract common pattern to separate struct --- sycl/include/CL/sycl/handler.hpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index 2d1f84d5f8469..156334d841174 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -917,6 +917,14 @@ class __SYCL_EXPORT handler { AccessMode == access::mode::discard_read_write; } + template struct TransformUserItemType { + using type = typename std::conditional< + std::is_convertible, LambdaArgType>::value, nd_item, + typename std::conditional< + std::is_convertible, LambdaArgType>::value, item, + LambdaArgType>::type>::type; + }; + /// Defines and invokes a SYCL kernel function for the specified range. /// /// The SYCL kernel function is defined as a lambda function or a named @@ -935,16 +943,11 @@ class __SYCL_EXPORT handler { using LambdaArgType = sycl::detail::lambda_arg_type>; // If 1D kernel argument is an integral type, convert it to sycl::item<1> - // If type is convertible to sycl::item/sycl::nd_item, convert it to - // sycl::item/sycl::nd_item + // If user type is convertible from sycl::item/sycl::nd_item, use + // sycl::item/sycl::nd_item to transport item information using TransformedArgType = typename std::conditional< std::is_integral::value && Dims == 1, item, - typename std::conditional< - std::is_convertible, LambdaArgType>::value, - nd_item, - typename std::conditional< - std::is_convertible, LambdaArgType>::value, - item, LambdaArgType>::type>::type>::type; + typename TransformUserItemType::type>::type; using NameT = typename detail::get_kernel_name_t::name; @@ -1567,13 +1570,10 @@ class __SYCL_EXPORT handler { verifyUsedKernelBundle(detail::KernelInfo::getName()); using LambdaArgType = sycl::detail::lambda_arg_type>; - // If type is convertible to sycl::item/sycl::nd_item, convert it to - // sycl::item/sycl::nd_item - using TransformedArgType = typename std::conditional< - std::is_convertible, LambdaArgType>::value, nd_item, - typename std::conditional< - std::is_convertible, LambdaArgType>::value, item, - LambdaArgType>::type>::type; + // If user type is convertible from sycl::item/sycl::nd_item, use + // sycl::item/sycl::nd_item to transport item information + using TransformedArgType = + typename TransformUserItemType::type; (void)ExecutionRange; kernel_parallel_for_wrapper(KernelFunc); #ifndef __SYCL_DEVICE_ONLY__