Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <CL/sycl/item.hpp>
#include <CL/sycl/kernel.hpp>
#include <CL/sycl/kernel_bundle.hpp>
#include <CL/sycl/kernel_handler.hpp>
#include <CL/sycl/marray.hpp>
#include <CL/sycl/multi_ptr.hpp>
#include <CL/sycl/nd_item.hpp>
Expand All @@ -47,6 +48,7 @@
#include <CL/sycl/range.hpp>
#include <CL/sycl/reduction.hpp>
#include <CL/sycl/sampler.hpp>
#include <CL/sycl/specialization_id.hpp>
#include <CL/sycl/stream.hpp>
#include <CL/sycl/types.hpp>
#include <CL/sycl/usm.hpp>
Expand Down
132 changes: 104 additions & 28 deletions sycl/include/CL/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <CL/sycl/interop_handle.hpp>
#include <CL/sycl/interop_handler.hpp>
#include <CL/sycl/kernel.hpp>
#include <CL/sycl/kernel_handler.hpp>
#include <CL/sycl/nd_item.hpp>
#include <CL/sycl/range.hpp>

Expand Down Expand Up @@ -122,6 +123,81 @@ class NDRDescT {
size_t Dims;
};

// Primary template of check_fn_signature. It is selected whenever the second
// template argument is not a function type of the form RetT(Args...), and
// instantiating it is always a compile-time error.
template <typename, typename T> struct check_fn_signature {
  // The condition has to depend on T so that it is evaluated only when this
  // primary template is actually instantiated. !std::is_same<T, T> is a valid
  // (and false) expression for every T, so the message below is what the user
  // sees. Spelling the dependent false as std::integral_constant<T, false>
  // is itself ill-formed for any T that cannot be a non-type template
  // parameter (class types, void, ...), which hides the message behind an
  // unrelated error.
  static_assert(!std::is_same<T, T>::value,
                "Second template parameter is required to be of function type");
};

// Partial specialization used when the second template argument is a function
// type RetT(Args...). Its public member `value` is true when an object of type
// F has an operator() that can be called with arguments of types Args... and
// the type of that call expression is exactly RetT; otherwise it is false.
template <typename F, typename RetT, typename... Args>
struct check_fn_signature<F, RetT(Args...)> {
Copy link
Contributor Author

@dm-vodopyanov dm-vodopyanov Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from handler.hpp as is, except check_kernel_lambda_takes_args - this is a new type trait

private:
// Candidate that takes part in overload resolution only if
// T::operator()(Args...) is a well-formed call (SFINAE on the trailing return
// type). Its return type is std::true_type when the call yields exactly RetT
// and std::false_type otherwise. Declared only, it is never called at run time.
template <typename T>
static constexpr auto check(T *) -> typename std::is_same<
decltype(std::declval<T>().operator()(std::declval<Args>()...)),
RetT>::type;

// Fallback candidate: the ellipsis parameter makes it the worst match, so it
// is picked only when the overload above has been removed by SFINAE.
template <typename> static constexpr std::false_type check(...);

// Result of overload resolution for F; the literal 0 is the null pointer
// constant that binds to the T* parameter of the first overload.
using type = decltype(check<F>(0));

public:
static constexpr bool value = type::value;
};

// Reports whether an object of type F (with any reference qualifier removed)
// has an operator() that can be called with arguments of types Args... and
// returns void, as determined by check_fn_signature.
//
// Deliberately not declared `static`: this header-defined helper is called
// from other (external-linkage) function templates in this header, and an
// internal-linkage function would be a distinct entity in every translation
// unit that includes it.
template <typename F, typename... Args>
constexpr bool check_kernel_lambda_takes_args() {
  using CallableT = std::remove_reference_t<F>;
  return check_fn_signature<CallableT, void(Args...)>::value;
}

// Type traits to find out whether the kernel lambda takes a kernel_handler argument

// Overload chosen when LambdaArgType is void (also its default), i.e. the
// kernel has no regular argument. Tells whether KernelType can be invoked with
// a single kernel_handler argument.
template <typename KernelType, typename LambdaArgType = void,
          std::enable_if_t<std::is_same<LambdaArgType, void>::value> * =
              nullptr>
constexpr bool isKernelLambdaCallableWithKernelHandler() {
  constexpr bool TakesHandlerOnly =
      check_kernel_lambda_takes_args<KernelType, kernel_handler>();
  return TakesHandlerOnly;
}

// Overload chosen when LambdaArgType is anything other than void. Tells
// whether KernelType can be invoked with a LambdaArgType argument followed by
// a kernel_handler argument.
template <typename KernelType, typename LambdaArgType,
          std::enable_if_t<!std::is_same<LambdaArgType, void>::value> * =
              nullptr>
constexpr bool isKernelLambdaCallableWithKernelHandler() {
  constexpr bool TakesArgAndHandler =
      check_kernel_lambda_takes_args<KernelType, LambdaArgType,
                                     kernel_handler>();
  return TakesArgAndHandler;
}

// Helpers for running kernel lambda on the host device

template <typename KernelType,
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
KernelType>()> * = nullptr>
constexpr void runKernelWithoutArg(KernelType KernelName) {
kernel_handler KH;
KernelName(KH);
}

template <typename KernelType,
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
KernelType>()> * = nullptr>
constexpr void runKernelWithoutArg(KernelType KernelName) {
KernelName();
}

template <typename ArgType, typename KernelType,
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
KernelType, ArgType>()> * = nullptr>
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
kernel_handler KH;
KernelName(Arg, KH);
}

template <typename ArgType, typename KernelType,
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
KernelType, ArgType>()> * = nullptr>
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
KernelName(Arg);
}

// The pure virtual class aimed to store lambda/functors of any type.
class HostKernelBase {
public:
Expand Down Expand Up @@ -197,7 +273,7 @@ class HostKernel : public HostKernelBase {
template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, void>::value>
runOnHost(const NDRDescT &) {
MKernel();
runKernelWithoutArg(MKernel);
}

template <class ArgT = KernelArgType>
Expand All @@ -218,18 +294,18 @@ class HostKernel : public HostKernelBase {
UpperBound[I] = Range[I] + Offset[I];
}

detail::NDLoop<Dims>::iterate(/*LowerBound=*/Offset, Stride, UpperBound,
[&](const sycl::id<Dims> &ID) {
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(
Range, ID, Offset);

if (StoreLocation) {
store_id(&ID);
store_item(&Item);
}
MKernel(ID);
});
detail::NDLoop<Dims>::iterate(
Copy link
Contributor Author

@dm-vodopyanov dm-vodopyanov Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format, the only change - runKernelWithArg... line

/*LowerBound=*/Offset, Stride, UpperBound,
[&](const sycl::id<Dims> &ID) {
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(Range, ID, Offset);

if (StoreLocation) {
store_id(&ID);
store_item(&Item);
}
runKernelWithArg<const sycl::id<Dims> &>(MKernel, ID);
});
}

template <class ArgT = KernelArgType>
Expand All @@ -253,7 +329,7 @@ class HostKernel : public HostKernelBase {
store_id(&ID);
store_item(&ItemWithOffset);
}
MKernel(Item);
runKernelWithArg<sycl::item<Dims, /*Offset=*/false>>(MKernel, Item);
});
}

Expand All @@ -276,18 +352,18 @@ class HostKernel : public HostKernelBase {
UpperBound[I] = Range[I] + Offset[I];
}

detail::NDLoop<Dims>::iterate(/*LowerBound=*/Offset, Stride, UpperBound,
[&](const sycl::id<Dims> &ID) {
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(
Range, ID, Offset);

if (StoreLocation) {
store_id(&ID);
store_item(&Item);
}
MKernel(Item);
});
detail::NDLoop<Dims>::iterate(
/*LowerBound=*/Offset, Stride, UpperBound,
[&](const sycl::id<Dims> &ID) {
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(Range, ID, Offset);

if (StoreLocation) {
store_id(&ID);
store_item(&Item);
}
runKernelWithArg<sycl::item<Dims, /*Offset=*/true>>(MKernel, Item);
});
}

template <class ArgT = KernelArgType>
Expand Down Expand Up @@ -336,7 +412,7 @@ class HostKernel : public HostKernelBase {
auto g = NDItem.get_group();
store_group(&g);
}
MKernel(NDItem);
runKernelWithArg<const sycl::nd_item<Dims>>(MKernel, NDItem);
});
});
}
Expand Down Expand Up @@ -364,7 +440,7 @@ class HostKernel : public HostKernelBase {
detail::NDLoop<Dims>::iterate(NGroups, [&](const id<Dims> &GroupID) {
sycl::group<Dims> Group =
IDBuilder::createGroup<Dims>(GlobalSize, LocalSize, NGroups, GroupID);
MKernel(Group);
runKernelWithArg<sycl::group<Dims>>(MKernel, Group);
});
}

Expand Down
Loading