Skip to content

Commit 281f8e1

Browse files
authored
[SYCL] Refactor cmath wrapper templates (#19483)
Follow up from: #18706 * Include `<type_traits>`, this was working because `<cmath>` pulls it in but this is cleaner. * Use C++14 and C++17 `_v` and `_t` helpers. * Switch to variadic template for `__sycl_promote`, using C++17 fold expressions. In theory only needed for 1 to 3 types, but this is an internal helper anyway so it doesn't seem worth restricting. * Switch `typedef` to `using` With this patch the header now require C++17 support to compile. Which is fine since it's the minimum requirement for SYCL 2020, and this header is only compiled on the device side which is always handled by clang, so there shouldn't be any support issues.
1 parent fabd1cc commit 281f8e1

File tree

1 file changed

+29
-46
lines changed

1 file changed

+29
-46
lines changed

sycl/include/sycl/stl_wrappers/__sycl_cmath_wrapper_impl.hpp

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,27 @@
1919
#define __SYCL_DEVICE_C \
2020
extern "C" __attribute__((sycl_device_only, always_inline))
2121

22+
// For std::enable_if, std::is_integral, std::is_floating_point, std::is_same,
23+
// and std::conjunction
24+
#include <type_traits>
25+
2226
// Promotion templates: the C++ standard library provides overloads that allow
2327
// arguments of math functions to be promoted. Any floating-point argument is
2428
// allowed to accept any integer type, which should then be promoted to double.
2529
// When multiple floating point arguments are available passing arguments with
2630
// different precision should promote to the larger type. The template helpers
2731
// below provide the machinery to define these promoting overloads.
28-
template <typename T, bool = (std::is_integral<T>::value ||
29-
std::is_floating_point<T>::value)>
32+
template <typename T,
33+
bool = (std::is_integral_v<T> || std::is_floating_point_v<T>)>
3034
struct __sycl_promote {
3135
private:
3236
// Integer types are promoted to double.
3337
template <typename U>
34-
static typename std::enable_if<std::is_integral<U>::value, double>::type
35-
test();
38+
static std::enable_if_t<std::is_integral_v<U>, double> test();
3639

3740
// Floating point types are used as-is.
3841
template <typename U>
39-
static typename std::enable_if<std::is_floating_point<U>::value, U>::type
40-
test();
42+
static std::enable_if_t<std::is_floating_point_v<U>, U> test();
4143

4244
public:
4345
// We rely on dummy templated methods and decltype to select the right type
@@ -48,29 +50,17 @@ struct __sycl_promote {
4850
// Variant without ::type to allow SFINAE for non-promotable types.
4951
template <typename T> struct __sycl_promote<T, false> {};
5052

51-
// With a single paramter we only need to promote integers.
52-
template <typename T>
53-
using __sycl_promote_1 = std::enable_if<std::is_integral<T>::value, double>;
54-
5553
// With two or three parameters we need to promote integers and possibly
5654
// floating point types. We rely on operator+ with decltype to deduce the
5755
// overall promotion type. This is only needed if at least one of the parameter
5856
// is an integer, or if there's multiple different floating point types.
59-
template <typename T, typename U>
60-
using __sycl_promote_2 =
61-
std::enable_if<!std::is_same<T, U>::value || std::is_integral<T>::value ||
62-
std::is_integral<U>::value,
63-
decltype(typename __sycl_promote<T>::type(0) +
64-
typename __sycl_promote<U>::type(0))>;
65-
66-
template <typename T, typename U, typename V>
67-
using __sycl_promote_3 =
68-
std::enable_if<!(std::is_same<T, U>::value && std::is_same<U, V>::value) ||
69-
std::is_integral<T>::value ||
70-
std::is_integral<U>::value || std::is_integral<V>::value,
71-
decltype(typename __sycl_promote<T>::type(0) +
72-
typename __sycl_promote<U>::type(0) +
73-
typename __sycl_promote<V>::type(0))>;
57+
template <typename T, typename... Ts>
58+
using __sycl_promote_t =
59+
std::enable_if_t<!std::conjunction_v<std::is_same<T, Ts>...> ||
60+
std::is_integral_v<T> ||
61+
(std::is_integral_v<Ts> || ...),
62+
decltype((typename __sycl_promote<Ts>::type(0) + ... +
63+
typename __sycl_promote<T>::type(0)))>;
7464

7565
// For each math built-in we need to define float and double overloads, an
7666
// extern "C" float variant with the 'f' suffix, and a version that promotes
@@ -85,8 +75,7 @@ using __sycl_promote_3 =
8575
__SYCL_DEVICE_C float NAME##f(float x) { return __spirv_ocl_##NAME(x); } \
8676
__SYCL_DEVICE float NAME(float x) { return __spirv_ocl_##NAME(x); } \
8777
__SYCL_DEVICE double NAME(double x) { return __spirv_ocl_##NAME(x); } \
88-
template <typename T> \
89-
__SYCL_DEVICE typename __sycl_promote_1<T>::type NAME(T x) { \
78+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> NAME(T x) { \
9079
return __spirv_ocl_##NAME((double)x); \
9180
}
9281

@@ -101,8 +90,8 @@ using __sycl_promote_3 =
10190
return __spirv_ocl_##NAME(x, y); \
10291
} \
10392
template <typename T, typename U> \
104-
__SYCL_DEVICE __sycl_promote_2<T, U>::type NAME(T x, U y) { \
105-
typedef typename __sycl_promote_2<T, U>::type type; \
93+
__SYCL_DEVICE __sycl_promote_t<T, U> NAME(T x, U y) { \
94+
using type = __sycl_promote_t<T, U>; \
10695
return __spirv_ocl_##NAME((type)x, (type)y); \
10796
}
10897

@@ -127,8 +116,7 @@ __SYCL_DEVICE double abs(double x) { return x < 0 ? -x : x; }
127116
__SYCL_DEVICE float fabs(float x) { return x < 0 ? -x : x; }
128117
__SYCL_DEVICE_C float fabsf(float x) { return x < 0 ? -x : x; }
129118
__SYCL_DEVICE double fabs(double x) { return x < 0 ? -x : x; }
130-
template <typename T>
131-
__SYCL_DEVICE typename __sycl_promote_1<T>::type fabs(T x) {
119+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> fabs(T x) {
132120
return x < 0 ? -x : x;
133121
}
134122

@@ -145,8 +133,8 @@ __SYCL_DEVICE double remquo(double x, double y, int *q) {
145133
return __spirv_ocl_remquo(x, y, q);
146134
}
147135
template <typename T, typename U>
148-
__SYCL_DEVICE typename __sycl_promote_2<T, U>::type remquo(T x, U y, int *q) {
149-
typedef typename __sycl_promote_2<T, U>::type type;
136+
__SYCL_DEVICE __sycl_promote_t<T, U> remquo(T x, U y, int *q) {
137+
using type = __sycl_promote_t<T, U>;
150138
return __spirv_ocl_remquo((type)x, (type)y, q);
151139
}
152140

@@ -160,8 +148,8 @@ __SYCL_DEVICE double fma(double x, double y, double z) {
160148
return __spirv_ocl_fma(x, y, z);
161149
}
162150
template <typename T, typename U, typename V>
163-
__SYCL_DEVICE typename __sycl_promote_3<T, U, V>::type fma(T x, U y, V z) {
164-
typedef typename __sycl_promote_3<T, U, V>::type type;
151+
__SYCL_DEVICE __sycl_promote_t<T, U, V> fma(T x, U y, V z) {
152+
using type = __sycl_promote_t<T, U, V>;
165153
return __spirv_ocl_fma((type)x, (type)y, (type)z);
166154
}
167155

@@ -256,8 +244,7 @@ __SYCL_DEVICE float frexp(float x, int *exp) {
256244
__SYCL_DEVICE double frexp(double x, int *exp) {
257245
return __spirv_ocl_frexp(x, exp);
258246
}
259-
template <typename T>
260-
__SYCL_DEVICE typename __sycl_promote_1<T>::type frexp(T x, int *exp) {
247+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> frexp(T x, int *exp) {
261248
return __spirv_ocl_frexp((double)x, exp);
262249
}
263250

@@ -270,8 +257,7 @@ __SYCL_DEVICE float ldexp(float x, int exp) {
270257
__SYCL_DEVICE double ldexp(double x, int exp) {
271258
return __spirv_ocl_ldexp(x, exp);
272259
}
273-
template <typename T>
274-
__SYCL_DEVICE typename __sycl_promote_1<T>::type ldexp(T x, int exp) {
260+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> ldexp(T x, int exp) {
275261
return __spirv_ocl_ldexp((double)x, exp);
276262
}
277263

@@ -286,7 +272,7 @@ __SYCL_DEVICE double modf(double x, double *intpart) {
286272
}
287273
// modf only supports integer x when the intpart is double.
288274
template <typename T>
289-
__SYCL_DEVICE typename __sycl_promote_1<T>::type modf(T x, double *intpart) {
275+
__SYCL_DEVICE __sycl_promote_t<T> modf(T x, double *intpart) {
290276
return __spirv_ocl_modf((double)x, intpart);
291277
}
292278

@@ -299,8 +285,7 @@ __SYCL_DEVICE float scalbn(float x, int exp) {
299285
__SYCL_DEVICE double scalbn(double x, int exp) {
300286
return __spirv_ocl_ldexp(x, exp);
301287
}
302-
template <typename T>
303-
__SYCL_DEVICE typename __sycl_promote_1<T>::type scalbn(T x, int exp) {
288+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> scalbn(T x, int exp) {
304289
return __spirv_ocl_ldexp((double)x, exp);
305290
}
306291

@@ -313,8 +298,7 @@ __SYCL_DEVICE float scalbln(float x, long exp) {
313298
__SYCL_DEVICE double scalbln(double x, long exp) {
314299
return __spirv_ocl_ldexp(x, (int)exp);
315300
}
316-
template <typename T>
317-
__SYCL_DEVICE typename __sycl_promote_1<T>::type scalbln(T x, long exp) {
301+
template <typename T> __SYCL_DEVICE __sycl_promote_t<T> scalbln(T x, long exp) {
318302
return __spirv_ocl_ldexp((double)x, (int)exp);
319303
}
320304

@@ -323,8 +307,7 @@ __SYCL_DEVICE int ilogb(float x) { return __spirv_ocl_ilogb(x); }
323307
__SYCL_DEVICE int ilogb(double x) { return __spirv_ocl_ilogb(x); }
324308
// ilogb needs a special template since its signature doesn't include the
325309
// promoted type anywhere, so it needs to be specialized differently.
326-
template <typename T, typename std::enable_if<std::is_integral<T>::value,
327-
bool>::type = true>
310+
template <typename T, std::enable_if_t<std::is_integral_v<T>, bool> = true>
328311
__SYCL_DEVICE int ilogb(T x) {
329312
return __spirv_ocl_ilogb((double)x);
330313
}

0 commit comments

Comments
 (0)