diff --git a/sycl/include/sycl/detail/spirv.hpp b/sycl/include/sycl/detail/spirv.hpp index 4650d7a58bbbb..a20eec899d95a 100644 --- a/sycl/include/sycl/detail/spirv.hpp +++ b/sycl/include/sycl/detail/spirv.hpp @@ -537,6 +537,49 @@ using EnableIfVectorShuffle = std::enable_if_t::value, T>; #endif // ifndef __NVPTX__ +// Bitcast shuffles can be implemented using a single SubgroupShuffle +// intrinsic, but require type-punning via an appropriate integer type +#ifndef __NVPTX__ +template +using EnableIfBitcastShuffle = + std::enable_if_t::value && + (std::is_trivially_copyable_v && + (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || + sizeof(T) == 8)), + T>; +#else +template +using EnableIfBitcastShuffle = + std::enable_if_t && + (sizeof(T) <= sizeof(int32_t))) && + !detail::is_vector_arithmetic::value && + (std::is_trivially_copyable_v && + (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)), + T>; +#endif // ifndef __NVPTX__ + +// Generic shuffles may require multiple calls to SubgroupShuffle +// intrinsics, and should use the fewest shuffles possible: +// - Loop over 64-bit chunks until remaining bytes < 64-bit +// - At most one 32-bit, 16-bit and 8-bit chunk left over +#ifndef __NVPTX__ +template +using EnableIfGenericShuffle = + std::enable_if_t::value && + !(std::is_trivially_copyable_v && + (sizeof(T) == 1 || sizeof(T) == 2 || + sizeof(T) == 4 || sizeof(T) == 8)), + T>; +#else +template +using EnableIfGenericShuffle = std::enable_if_t< + !(std::is_integral::value && (sizeof(T) <= sizeof(int32_t))) && + !detail::is_vector_arithmetic::value && + !(std::is_trivially_copyable_v && + (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)), + T>; +#endif + #ifdef __NVPTX__ inline uint32_t membermask() { // use a full mask as sync operations are required to be convergent and exited @@ -545,6 +588,31 @@ inline uint32_t membermask() { } #endif +// Forward declarations for template overloadings +template +EnableIfBitcastShuffle SubgroupShuffle(T x, id<1> local_id); + +template +EnableIfBitcastShuffle SubgroupShuffleXor(T x, id<1> local_id); + +template +EnableIfBitcastShuffle SubgroupShuffleDown(T x, id<1> local_id); + +template +EnableIfBitcastShuffle SubgroupShuffleUp(T x, id<1> local_id); + +template +EnableIfGenericShuffle SubgroupShuffle(T x, id<1> local_id); + +template +EnableIfGenericShuffle SubgroupShuffleXor(T x, id<1> local_id); + +template +EnableIfGenericShuffle SubgroupShuffleDown(T x, id<1> local_id); + +template +EnableIfGenericShuffle SubgroupShuffleUp(T x, id<1> local_id); + template EnableIfNativeShuffle SubgroupShuffle(T x, id<1> local_id) { #ifndef __NVPTX__ @@ -623,26 +691,6 @@ EnableIfVectorShuffle SubgroupShuffleUp(T x, uint32_t delta) { return result; } -// Bitcast shuffles can be implemented using a single SubgroupShuffle -// intrinsic, but require type-punning via an appropriate integer type -#ifndef __NVPTX__ -template -using EnableIfBitcastShuffle = - detail::enable_if_t::value && - (std::is_trivially_copyable::value && - (sizeof(T) == 1 || sizeof(T) == 2 || - sizeof(T) == 4 || sizeof(T) == 8)), - T>; -#else -template -using EnableIfBitcastShuffle = detail::enable_if_t< - !(std::is_integral::value && (sizeof(T) <= sizeof(int32_t))) && - !detail::is_vector_arithmetic::value && - (std::is_trivially_copyable::value && - (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)), - T>; -#endif - template using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t; @@ -699,28 +747,6 @@ EnableIfBitcastShuffle SubgroupShuffleUp(T x, uint32_t delta) { return bit_cast(Result); } -// Generic shuffles may require multiple calls to SubgroupShuffle -// intrinsics, and should use the fewest shuffles possible: -// - Loop over 64-bit chunks until remaining bytes < 64-bit -// - At most one 32-bit, 16-bit and 8-bit chunk left over -#ifndef __NVPTX__ -template -using EnableIfGenericShuffle = - detail::enable_if_t::value && - !(std::is_trivially_copyable::value && - (sizeof(T) == 1 || sizeof(T) == 2 || - sizeof(T) == 4 || sizeof(T) == 8)), - T>; -#else -template -using EnableIfGenericShuffle = detail::enable_if_t< - !(std::is_integral::value && (sizeof(T) <= sizeof(int32_t))) && - !detail::is_vector_arithmetic::value && - !(std::is_trivially_copyable::value && - (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)), - T>; -#endif - template EnableIfGenericShuffle SubgroupShuffle(T x, id<1> local_id) { T Result;