Skip to content

Commit 98a339d

Browse files
authored
[SYCL][LIBCLC] Add support for non-uniform [I/F]Mul op in ptx-nvidiacl (#4217)
1 parent 545a243 commit 98a339d

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

libclc/ptx-nvidiacl/libspirv/group/collectives.cl

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
151151
// Larger of x and y. Both arguments are fully parenthesised so that compound
// expressions compare correctly; the selected argument is evaluated twice.
#define __CLC_MAX(x, y) (((x) > (y)) ? (x) : (y))
152152
// Bitwise OR of x and y. Arguments are parenthesised so that operands built
// from lower-precedence operators (&&, ||, ?:) are combined as written.
#define __CLC_OR(x, y) ((x) | (y))
153153
// Bitwise AND of x and y. Arguments are parenthesised so that operands built
// from lower-precedence operators (|, ^, &&, ...) are combined as written.
#define __CLC_AND(x, y) ((x) & (y))
154+
// Product of x and y. Arguments are parenthesised so that, for example,
// __CLC_MUL(a + b, c) multiplies the whole sum rather than only b.
#define __CLC_MUL(x, y) ((x) * (y))
154155

155156
#define __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
156157
uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
@@ -210,6 +211,18 @@ __CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
210211
__CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, float, 0)
211212
__CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, double, 0)
212213

214+
// Sub-group multiply collectives. The last argument is the IDENTITY of the
// operation (0 for FAdd, CHAR_MAX for SMin, ...); for multiplication the
// neutral element is 1, not 0.
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, char, 1)
215+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, uchar, 1)
216+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, short, 1)
217+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, ushort, 1)
218+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, int, 1)
219+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, uint, 1)
220+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, long, 1)
221+
__CLC_SUBGROUP_COLLECTIVE(IMul, __CLC_MUL, ulong, 1)
222+
__CLC_SUBGROUP_COLLECTIVE(FMul, __CLC_MUL, half, 1)
223+
__CLC_SUBGROUP_COLLECTIVE(FMul, __CLC_MUL, float, 1)
224+
__CLC_SUBGROUP_COLLECTIVE(FMul, __CLC_MUL, double, 1)
225+
213226
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
214227
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uchar, UCHAR_MAX)
215228
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, short, SHRT_MAX)
@@ -238,12 +251,12 @@ __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
238251
#undef __CLC_SUBGROUP_COLLECTIVE
239252
#undef __CLC_SUBGROUP_COLLECTIVE_REDUX
240253

241-
#define __CLC_GROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY) \
254+
#define __CLC_GROUP_COLLECTIVE_INNER(SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY) \
242255
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
243-
__spirv_Group, NAME)(uint scope, uint op, TYPE x) { \
256+
__spirv_Group, SPIRV_NAME)(uint scope, uint op, TYPE x) { \
244257
TYPE carry = IDENTITY; \
245258
/* Perform GroupOperation within sub-group */ \
246-
TYPE sg_x = __CLC_APPEND(__clc__Subgroup, NAME)(op, x, &carry); \
259+
TYPE sg_x = __CLC_APPEND(__clc__Subgroup, CLC_NAME)(op, x, &carry); \
247260
if (scope == Subgroup) { \
248261
return sg_x; \
249262
} \
@@ -283,6 +296,18 @@ __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
283296
return result; \
284297
}
285298

299+
// 4-argument form: the same NAME is used both for the generated
// __spirv_Group<NAME> entry point and for the __clc__Subgroup<NAME> helper.
#define __CLC_GROUP_COLLECTIVE_4(NAME, OP, TYPE, IDENTITY) \
300+
__CLC_GROUP_COLLECTIVE_INNER(NAME, NAME, OP, TYPE, IDENTITY)
301+
// 5-argument form: the __spirv_Group<SPIRV_NAME> entry point and the
// __clc__Subgroup<CLC_NAME> helper it calls may have different names.
#define __CLC_GROUP_COLLECTIVE_5(SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY) \
302+
__CLC_GROUP_COLLECTIVE_INNER(SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY)
303+
304+
// Expands to its sixth argument. The caller passes its own arguments followed
// by the _5 and _4 macro names, so the argument count decides which of the two
// names lands in the NAME slot.
#define DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO(_1, _2, _3, _4, _5, NAME, ...) \
305+
NAME
306+
// Public entry point: accepts 4 or 5 arguments and forwards them to
// __CLC_GROUP_COLLECTIVE_4 or __CLC_GROUP_COLLECTIVE_5 respectively.
// NOTE(review): with 4 arguments the trailing `...` of the dispatcher receives
// no argument, which strict pre-C23 preprocessors reject - confirm the
// supported compilers accept it.
#define __CLC_GROUP_COLLECTIVE(...) \
307+
DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO( \
308+
__VA_ARGS__, __CLC_GROUP_COLLECTIVE_5, __CLC_GROUP_COLLECTIVE_4) \
309+
(__VA_ARGS__)
310+
286311
__CLC_GROUP_COLLECTIVE(BitwiseOr, __CLC_OR, bool, false);
287312
__CLC_GROUP_COLLECTIVE(BitwiseAny, __CLC_AND, bool, true);
288313
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT bool __spirv_GroupAny(uint scope,
@@ -306,6 +331,19 @@ __CLC_GROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
306331
__CLC_GROUP_COLLECTIVE(FAdd, __CLC_ADD, float, 0)
307332
__CLC_GROUP_COLLECTIVE(FAdd, __CLC_ADD, double, 0)
308333

334+
// SPIR-V has no Mul group operation, so the multiply collectives are exposed
// under the non-uniform names (__spirv_GroupNonUniformIMul / FMul) while
// reusing the __clc__SubgroupIMul / FMul helpers. The last argument is the
// IDENTITY used to seed the carry; for multiplication that is 1, not 0.
335+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, char, 1)
336+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, uchar, 1)
337+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, short, 1)
338+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, ushort, 1)
339+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, int, 1)
340+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, uint, 1)
341+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, long, 1)
342+
__CLC_GROUP_COLLECTIVE(NonUniformIMul, IMul, __CLC_MUL, ulong, 1)
343+
__CLC_GROUP_COLLECTIVE(NonUniformFMul, FMul, __CLC_MUL, half, 1)
344+
__CLC_GROUP_COLLECTIVE(NonUniformFMul, FMul, __CLC_MUL, float, 1)
345+
__CLC_GROUP_COLLECTIVE(NonUniformFMul, FMul, __CLC_MUL, double, 1)
346+
309347
__CLC_GROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
310348
__CLC_GROUP_COLLECTIVE(UMin, __CLC_MIN, uchar, UCHAR_MAX)
311349
__CLC_GROUP_COLLECTIVE(SMin, __CLC_MIN, short, SHRT_MAX)
@@ -344,13 +382,17 @@ _CLC_DECL _CLC_CONVERGENT half _Z17__spirv_GroupFMaxjjDF16_(uint scope, uint op,
344382
return __spirv_GroupFMax(scope, op, x);
345383
}
346384

385+
// Drop every helper macro of the group-collective generator, including the
// shared _INNER body, so that none of them leaks into the rest of the file.
#undef __CLC_GROUP_COLLECTIVE_INNER
#undef __CLC_GROUP_COLLECTIVE_4
386+
#undef __CLC_GROUP_COLLECTIVE_5
387+
#undef DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO
347388
#undef __CLC_GROUP_COLLECTIVE
348389

349390
#undef __CLC_AND
350391
#undef __CLC_OR
351392
#undef __CLC_MAX
352393
#undef __CLC_MIN
353394
#undef __CLC_ADD
395+
#undef __CLC_MUL
354396

355397
long __clc__get_linear_local_id() {
356398
size_t id_x = __spirv_LocalInvocationId_x();

0 commit comments

Comments
 (0)