@@ -151,6 +151,7 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
// Binary operators plugged into the collective generators as OP.
// Every argument is parenthesized so that compound expressions passed as
// x or y keep their own precedence after expansion.
// NOTE: __CLC_MAX evaluates one of its arguments twice; do not pass
// expressions with side effects.
#define __CLC_MAX(x, y) (((x) > (y)) ? (x) : (y))
#define __CLC_OR(x, y) ((x) | (y))
#define __CLC_AND(x, y) ((x) & (y))
#define __CLC_MUL(x, y) ((x) * (y))
154155
155156#define __CLC_SUBGROUP_COLLECTIVE_BODY (OP , TYPE , IDENTITY ) \
156157 uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
@@ -210,6 +211,18 @@ __CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
210211__CLC_SUBGROUP_COLLECTIVE (FAdd , __CLC_ADD , float , 0 )
211212__CLC_SUBGROUP_COLLECTIVE (FAdd , __CLC_ADD , double , 0 )
212213
214+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , char , 0 )
215+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , uchar , 0 )
216+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , short , 0 )
217+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , ushort , 0 )
218+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , int , 0 )
219+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , uint , 0 )
220+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , long , 0 )
221+ __CLC_SUBGROUP_COLLECTIVE (IMul , __CLC_MUL , ulong , 0 )
222+ __CLC_SUBGROUP_COLLECTIVE (FMul , __CLC_MUL , half , 0 )
223+ __CLC_SUBGROUP_COLLECTIVE (FMul , __CLC_MUL , float , 0 )
224+ __CLC_SUBGROUP_COLLECTIVE (FMul , __CLC_MUL , double , 0 )
225+
213226__CLC_SUBGROUP_COLLECTIVE (SMin , __CLC_MIN , char , CHAR_MAX )
214227__CLC_SUBGROUP_COLLECTIVE (UMin , __CLC_MIN , uchar , UCHAR_MAX )
215228__CLC_SUBGROUP_COLLECTIVE (SMin , __CLC_MIN , short , SHRT_MAX )
@@ -238,12 +251,12 @@ __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
238251#undef __CLC_SUBGROUP_COLLECTIVE
239252#undef __CLC_SUBGROUP_COLLECTIVE_REDUX
240253
241- #define __CLC_GROUP_COLLECTIVE ( NAME , OP , TYPE , IDENTITY ) \
254+ #define __CLC_GROUP_COLLECTIVE_INNER ( SPIRV_NAME , CLC_NAME , OP , TYPE , IDENTITY ) \
242255 _CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
243- __spirv_Group, NAME )(uint scope, uint op, TYPE x) { \
256+ __spirv_Group, SPIRV_NAME )(uint scope, uint op, TYPE x) { \
244257 TYPE carry = IDENTITY; \
245258 /* Perform GroupOperation within sub-group */ \
246- TYPE sg_x = __CLC_APPEND (__clc__Subgroup , NAME )(op , x , & carry ); \
259+ TYPE sg_x = __CLC_APPEND (__clc__Subgroup , CLC_NAME )(op , x , & carry ); \
247260 if (scope == Subgroup ) { \
248261 return sg_x ; \
249262 } \
@@ -283,6 +296,18 @@ __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
283296 return result ; \
284297 }
285298
// __CLC_GROUP_COLLECTIVE is overloaded on its argument count:
//   4 args (NAME, OP, TYPE, IDENTITY):
//     NAME is used for both the __spirv_Group* entry point and the
//     __clc__Subgroup* helper.
//   5 args (SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY):
//     the __spirv_Group* entry point and the __clc__Subgroup* helper are
//     named independently.
#define __CLC_GROUP_COLLECTIVE_4(NAME, OP, TYPE, IDENTITY)                     \
  __CLC_GROUP_COLLECTIVE_INNER(NAME, NAME, OP, TYPE, IDENTITY)
#define __CLC_GROUP_COLLECTIVE_5(SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY)     \
  __CLC_GROUP_COLLECTIVE_INNER(SPIRV_NAME, CLC_NAME, OP, TYPE, IDENTITY)

// Selects the sixth argument. The caller appends the candidate macros after
// its own arguments, so the argument count decides which candidate lands in
// the NAME slot.
#define DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO(_1, _2, _3, _4, _5, NAME, ...)  \
  NAME
// The trailing 0 is a sentinel: it guarantees that "..." receives at least
// one argument in the 4-argument case, which is required before C23.
#define __CLC_GROUP_COLLECTIVE(...)                                            \
  DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO(                                      \
      __VA_ARGS__, __CLC_GROUP_COLLECTIVE_5, __CLC_GROUP_COLLECTIVE_4, 0)      \
  (__VA_ARGS__)
310+
286311__CLC_GROUP_COLLECTIVE (BitwiseOr , __CLC_OR , bool , false);
287312__CLC_GROUP_COLLECTIVE (BitwiseAny , __CLC_AND , bool , true);
288313_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT bool __spirv_GroupAny (uint scope ,
@@ -306,6 +331,19 @@ __CLC_GROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
306331__CLC_GROUP_COLLECTIVE (FAdd , __CLC_ADD , float , 0 )
307332__CLC_GROUP_COLLECTIVE (FAdd , __CLC_ADD , double , 0 )
308333
334+ // There is no Mul group op in SPIR-V, use non-uniform variant instead.
335+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , char , 0 )
336+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , uchar , 0 )
337+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , short , 0 )
338+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , ushort , 0 )
339+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , int , 0 )
340+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , uint , 0 )
341+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , long , 0 )
342+ __CLC_GROUP_COLLECTIVE (NonUniformIMul , IMul , __CLC_MUL , ulong , 0 )
343+ __CLC_GROUP_COLLECTIVE (NonUniformFMul , FMul , __CLC_MUL , half , 0 )
344+ __CLC_GROUP_COLLECTIVE (NonUniformFMul , FMul , __CLC_MUL , float , 0 )
345+ __CLC_GROUP_COLLECTIVE (NonUniformFMul , FMul , __CLC_MUL , double , 0 )
346+
309347__CLC_GROUP_COLLECTIVE (SMin , __CLC_MIN , char , CHAR_MAX )
310348__CLC_GROUP_COLLECTIVE (UMin , __CLC_MIN , uchar , UCHAR_MAX )
311349__CLC_GROUP_COLLECTIVE (SMin , __CLC_MIN , short , SHRT_MAX )
@@ -344,13 +382,17 @@ _CLC_DECL _CLC_CONVERGENT half _Z17__spirv_GroupFMaxjjDF16_(uint scope, uint op,
344382 return __spirv_GroupFMax (scope , op , x );
345383}
346384
// Tear down the group-collective generator together with every helper macro
// it was built from, so none of them leak into the rest of the file.
#undef __CLC_GROUP_COLLECTIVE_4
#undef __CLC_GROUP_COLLECTIVE_5
#undef DISPATCH_TO_CLC_GROUP_COLLECTIVE_MACRO
#undef __CLC_GROUP_COLLECTIVE_INNER
#undef __CLC_GROUP_COLLECTIVE

// Binary operator helpers used as the OP argument of the generators.
#undef __CLC_AND
#undef __CLC_OR
#undef __CLC_MAX
#undef __CLC_MIN
#undef __CLC_ADD
#undef __CLC_MUL
354396
355397long __clc__get_linear_local_id () {
356398 size_t id_x = __spirv_LocalInvocationId_x ();
0 commit comments