@@ -139,11 +139,18 @@ struct sub_group {
139
139
return BinaryOperation::template calc<T, cl::__spirv::InclusiveScan>(x);
140
140
}
141
141
142
+ template <typename T>
143
+ using EnableIfIsArithmeticOrHalf = typename std::enable_if<
144
+ (std::is_arithmetic<T>::value ||
145
+ std::is_same<typename std::remove_const<T>::type, half>::value),
146
+ T>::type;
147
+
148
+
142
149
/* --- one - input shuffles --- */
143
150
/* indices in [0 , sub - group size ) */
144
151
145
152
template <typename T>
146
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
153
+ EnableIfIsArithmeticOrHalf<T>
147
154
shuffle (T x, id<1 > local_id) {
148
155
return cl::__spirv::OpSubgroupShuffleINTEL (x, local_id.get (0 ));
149
156
}
@@ -156,7 +163,7 @@ struct sub_group {
156
163
}
157
164
158
165
template <typename T>
159
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
166
+ EnableIfIsArithmeticOrHalf<T>
160
167
shuffle_down (T x, uint32_t delta) {
161
168
return shuffle_down (x, x, delta);
162
169
}
@@ -168,7 +175,7 @@ struct sub_group {
168
175
}
169
176
170
177
template <typename T>
171
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
178
+ EnableIfIsArithmeticOrHalf<T>
172
179
shuffle_up (T x, uint32_t delta) {
173
180
return shuffle_up (x, x, delta);
174
181
}
@@ -180,7 +187,7 @@ struct sub_group {
180
187
}
181
188
182
189
template <typename T>
183
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
190
+ EnableIfIsArithmeticOrHalf<T>
184
191
shuffle_xor (T x, id<1 > value) {
185
192
return cl::__spirv::OpSubgroupShuffleXorINTEL (x, (uint32_t )value.get (0 ));
186
193
}
@@ -195,7 +202,7 @@ struct sub_group {
195
202
/* --- two - input shuffles --- */
196
203
/* indices in [0 , 2* sub - group size ) */
197
204
template <typename T>
198
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
205
+ EnableIfIsArithmeticOrHalf<T>
199
206
shuffle (T x, T y, id<1 > local_id) {
200
207
return cl::__spirv::OpSubgroupShuffleDownINTEL (
201
208
x, y, local_id.get (0 ) - get_local_id ().get (0 ));
@@ -210,7 +217,7 @@ struct sub_group {
210
217
}
211
218
212
219
template <typename T>
213
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
220
+ EnableIfIsArithmeticOrHalf<T>
214
221
shuffle_down (T current, T next, uint32_t delta) {
215
222
return cl::__spirv::OpSubgroupShuffleDownINTEL (current, next, delta);
216
223
}
@@ -223,7 +230,7 @@ struct sub_group {
223
230
}
224
231
225
232
template <typename T>
226
- typename std::enable_if<std::is_arithmetic<T>::value, T>::type
233
+ EnableIfIsArithmeticOrHalf<T>
227
234
shuffle_up (T previous, T current, uint32_t delta) {
228
235
return cl::__spirv::OpSubgroupShuffleUpINTEL (previous, current, delta);
229
236
}
0 commit comments