Skip to content

Commit 57313f4

Browse files
authored
[SYCL] Improve checks of kernel execution range in INT_MAX limit. (#2423)
The improved check covers the following use cases. When there are many small work-groups, the total number of work-items may still exceed the INT_MAX limit. Also, the sum of a range and an offset may exceed the limit even though each value is within the INT_MAX limit on its own. Signed-off-by: Sergey Kanaev <[email protected]>
1 parent f6418b4 commit 57313f4

File tree

3 files changed

+175
-47
lines changed

3 files changed

+175
-47
lines changed

sycl/include/CL/sycl/detail/defines.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@
3838
#define SYCL_EXTERNAL
3939
#endif
4040

41-
#if defined(__SYCL_ID_QUERIES_FIT_IN_INT__) && __has_builtin(__builtin_assume)
41+
#ifndef __SYCL_ID_QUERIES_FIT_IN_INT__
42+
#define __SYCL_ID_QUERIES_FIT_IN_INT__ 0
43+
#endif
44+
45+
#if __SYCL_ID_QUERIES_FIT_IN_INT__ && __has_builtin(__builtin_assume)
4246
#define __SYCL_ASSUME_INT(x) __builtin_assume((x) <= INT_MAX)
4347
#else
4448
#define __SYCL_ASSUME_INT(x)
45-
#if defined(__SYCL_ID_QUERIES_FIT_IN_INT__) && !__has_builtin(__builtin_assume)
49+
#if __SYCL_ID_QUERIES_FIT_IN_INT__ && !__has_builtin(__builtin_assume)
4650
#warning "No assumptions will be emitted due to no __builtin_assume available"
4751
#endif
4852
#endif

sycl/include/CL/sycl/handler.hpp

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ struct check_fn_signature<F, RetT(Args...)> {
143143

144144
__SYCL_EXPORT device getDeviceFromHandler(handler &);
145145

146-
#if defined(__SYCL_ID_QUERIES_FIT_IN_INT__)
146+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
147147
template <typename T> struct NotIntMsg;
148148

149149
template <int Dims> struct NotIntMsg<range<Dims>> {
@@ -159,16 +159,65 @@ template <int Dims> struct NotIntMsg<id<Dims>> {
159159
};
160160
#endif
161161

162+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
163+
template <typename T, typename ValT>
164+
typename std::enable_if<std::is_same<ValT, size_t>::value ||
165+
std::is_same<ValT, unsigned long long>::value>::type
166+
checkValueRangeImpl(ValT V) {
167+
static constexpr size_t Limit =
168+
static_cast<size_t>((std::numeric_limits<int>::max)());
169+
if (V > Limit)
170+
throw runtime_error(NotIntMsg<T>::Msg, PI_INVALID_VALUE);
171+
}
172+
#endif
173+
162174
template <int Dims, typename T>
163175
typename std::enable_if<std::is_same<T, range<Dims>>::value ||
164176
std::is_same<T, id<Dims>>::value>::type
165177
checkValueRange(const T &V) {
166-
#if defined(__SYCL_ID_QUERIES_FIT_IN_INT__)
167-
static constexpr size_t Limit =
168-
static_cast<size_t>((std::numeric_limits<int>::max)());
178+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
169179
for (size_t Dim = 0; Dim < Dims; ++Dim)
170-
if (V[Dim] > Limit)
171-
throw runtime_error(NotIntMsg<T>::Msg, PI_INVALID_VALUE);
180+
checkValueRangeImpl<T>(V[Dim]);
181+
182+
{
183+
unsigned long long Product = 1;
184+
for (size_t Dim = 0; Dim < Dims; ++Dim) {
185+
Product *= V[Dim];
186+
// check value now to prevent product overflow in the end
187+
checkValueRangeImpl<T>(Product);
188+
}
189+
}
190+
#else
191+
(void)V;
192+
#endif
193+
}
194+
195+
template <int Dims>
196+
void checkValueRange(const range<Dims> &R, const id<Dims> &O) {
197+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
198+
checkValueRange<Dims>(R);
199+
checkValueRange<Dims>(O);
200+
201+
for (size_t Dim = 0; Dim < Dims; ++Dim) {
202+
unsigned long long Sum = R[Dim] + O[Dim];
203+
204+
checkValueRangeImpl<range<Dims>>(Sum);
205+
}
206+
#else
207+
(void)R;
208+
(void)O;
209+
#endif
210+
}
211+
212+
template <int Dims, typename T>
213+
typename std::enable_if<std::is_same<T, nd_range<Dims>>::value>::type
214+
checkValueRange(const T &V) {
215+
#if __SYCL_ID_QUERIES_FIT_IN_INT__
216+
checkValueRange<Dims>(V.get_global_range());
217+
checkValueRange<Dims>(V.get_local_range());
218+
checkValueRange<Dims>(V.get_offset());
219+
220+
checkValueRange<Dims>(V.get_global_range(), V.get_offset());
172221
#else
173222
(void)V;
174223
#endif
@@ -982,8 +1031,7 @@ class __SYCL_EXPORT handler {
9821031
(void)WorkItemOffset;
9831032
kernel_parallel_for<NameT, LambdaArgType>(KernelFunc);
9841033
#else
985-
detail::checkValueRange<Dims>(NumWorkItems);
986-
detail::checkValueRange<Dims>(WorkItemOffset);
1034+
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
9871035
MNDRDesc.set(std::move(NumWorkItems), std::move(WorkItemOffset));
9881036
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
9891037
MCGType = detail::CG::KERNEL;
@@ -1015,9 +1063,7 @@ class __SYCL_EXPORT handler {
10151063
(void)ExecutionRange;
10161064
kernel_parallel_for<NameT, LambdaArgType>(KernelFunc);
10171065
#else
1018-
detail::checkValueRange<Dims>(ExecutionRange.get_global_range());
1019-
detail::checkValueRange<Dims>(ExecutionRange.get_local_range());
1020-
detail::checkValueRange<Dims>(ExecutionRange.get_offset());
1066+
detail::checkValueRange<Dims>(ExecutionRange);
10211067
MNDRDesc.set(std::move(ExecutionRange));
10221068
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
10231069
MCGType = detail::CG::KERNEL;
@@ -1225,9 +1271,7 @@ class __SYCL_EXPORT handler {
12251271
#else
12261272
nd_range<Dims> ExecRange =
12271273
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
1228-
detail::checkValueRange<Dims>(ExecRange.get_global_range());
1229-
detail::checkValueRange<Dims>(ExecRange.get_local_range());
1230-
detail::checkValueRange<Dims>(ExecRange.get_offset());
1274+
detail::checkValueRange<Dims>(ExecRange);
12311275
MNDRDesc.set(std::move(ExecRange));
12321276
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
12331277
MCGType = detail::CG::KERNEL;
@@ -1278,8 +1322,7 @@ class __SYCL_EXPORT handler {
12781322
throwIfActionIsCreated();
12791323
verifyKernelInvoc(Kernel);
12801324
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1281-
detail::checkValueRange<Dims>(NumWorkItems);
1282-
detail::checkValueRange<Dims>(WorkItemOffset);
1325+
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
12831326
MNDRDesc.set(std::move(NumWorkItems), std::move(WorkItemOffset));
12841327
MCGType = detail::CG::KERNEL;
12851328
extractArgsAndReqs();
@@ -1298,9 +1341,7 @@ class __SYCL_EXPORT handler {
12981341
throwIfActionIsCreated();
12991342
verifyKernelInvoc(Kernel);
13001343
MKernel = detail::getSyclObjImpl(std::move(Kernel));
1301-
detail::checkValueRange<Dims>(NDRange.get_global_range());
1302-
detail::checkValueRange<Dims>(NDRange.get_local_range());
1303-
detail::checkValueRange<Dims>(NDRange.get_offset());
1344+
detail::checkValueRange<Dims>(NDRange);
13041345
MNDRDesc.set(std::move(NDRange));
13051346
MCGType = detail::CG::KERNEL;
13061347
extractArgsAndReqs();
@@ -1400,8 +1441,7 @@ class __SYCL_EXPORT handler {
14001441
(void)WorkItemOffset;
14011442
kernel_parallel_for<NameT, LambdaArgType>(KernelFunc);
14021443
#else
1403-
detail::checkValueRange<Dims>(NumWorkItems);
1404-
detail::checkValueRange<Dims>(WorkItemOffset);
1444+
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
14051445
MNDRDesc.set(std::move(NumWorkItems), std::move(WorkItemOffset));
14061446
MKernel = detail::getSyclObjImpl(std::move(Kernel));
14071447
MCGType = detail::CG::KERNEL;
@@ -1437,9 +1477,7 @@ class __SYCL_EXPORT handler {
14371477
(void)NDRange;
14381478
kernel_parallel_for<NameT, LambdaArgType>(KernelFunc);
14391479
#else
1440-
detail::checkValueRange<Dims>(NDRange.get_global_range());
1441-
detail::checkValueRange<Dims>(NDRange.get_local_range());
1442-
detail::checkValueRange<Dims>(NDRange.get_offset());
1480+
detail::checkValueRange<Dims>(NDRange);
14431481
MNDRDesc.set(std::move(NDRange));
14441482
MKernel = detail::getSyclObjImpl(std::move(Kernel));
14451483
MCGType = detail::CG::KERNEL;
@@ -1520,9 +1558,7 @@ class __SYCL_EXPORT handler {
15201558
#else
15211559
nd_range<Dims> ExecRange =
15221560
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
1523-
detail::checkValueRange<Dims>(ExecRange.get_global_range());
1524-
detail::checkValueRange<Dims>(ExecRange.get_local_range());
1525-
detail::checkValueRange<Dims>(ExecRange.get_offset());
1561+
detail::checkValueRange<Dims>(ExecRange);
15261562
MNDRDesc.set(std::move(ExecRange));
15271563
MKernel = detail::getSyclObjImpl(std::move(Kernel));
15281564
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));

0 commit comments

Comments
 (0)