@@ -27,8 +27,7 @@
 template <typename vtype, typename reg_t>
 X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm);
 
-template <typename vtype, typename reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm);
+struct avx512_32bit_swizzle_ops;
 
 template <>
 struct zmm_vector<int32_t> {
@@ -39,6 +38,8 @@ struct zmm_vector<int32_t> {
     static const uint8_t numlanes = 16;
     static constexpr int network_sort_threshold = 256;
     static constexpr int partition_unroll_factor = 2;
+
+    using swizzle_ops = avx512_32bit_swizzle_ops;
 
     static type_t type_max()
     {
@@ -138,14 +139,16 @@ struct zmm_vector<int32_t> {
         const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
         return permutexvar(rev_index, zmm);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
         return sort_zmm_32bit<zmm_vector<type_t>>(x);
     }
+    static reg_t cast_from(__m512i v){
+        return v;
+    }
+    static __m512i cast_to(reg_t v){
+        return v;
+    }
 };
 template <>
 struct zmm_vector<uint32_t> {
@@ -156,6 +159,8 @@ struct zmm_vector<uint32_t> {
     static const uint8_t numlanes = 16;
     static constexpr int network_sort_threshold = 256;
    static constexpr int partition_unroll_factor = 2;
+
+    using swizzle_ops = avx512_32bit_swizzle_ops;
 
     static type_t type_max()
     {
@@ -255,14 +260,16 @@ struct zmm_vector<uint32_t> {
         const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
         return permutexvar(rev_index, zmm);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
         return sort_zmm_32bit<zmm_vector<type_t>>(x);
     }
+    static reg_t cast_from(__m512i v){
+        return v;
+    }
+    static __m512i cast_to(reg_t v){
+        return v;
+    }
 };
 template <>
 struct zmm_vector<float> {
@@ -273,6 +280,8 @@ struct zmm_vector<float> {
     static const uint8_t numlanes = 16;
     static constexpr int network_sort_threshold = 256;
     static constexpr int partition_unroll_factor = 2;
+
+    using swizzle_ops = avx512_32bit_swizzle_ops;
 
     static type_t type_max()
     {
@@ -386,14 +395,16 @@ struct zmm_vector<float> {
         const auto rev_index = _mm512_set_epi32(NETWORK_32BIT_5);
         return permutexvar(rev_index, zmm);
     }
-    static reg_t bitonic_merge(reg_t x)
-    {
-        return bitonic_merge_zmm_32bit<zmm_vector<type_t>>(x);
-    }
     static reg_t sort_vec(reg_t x)
     {
         return sort_zmm_32bit<zmm_vector<type_t>>(x);
     }
+    static reg_t cast_from(__m512i v){
+        return _mm512_castsi512_ps(v);
+    }
+    static __m512i cast_to(reg_t v){
+        return _mm512_castps_si512(v);
+    }
 };
 
 /*
@@ -446,31 +457,66 @@ X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm)
     return zmm;
 }
 
-// Assumes zmm is bitonic and performs a recursive half cleaner
-template <typename vtype, typename reg_t = typename vtype::reg_t>
-X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_32bit(reg_t zmm)
-{
-    // 1) half_cleaner[16]: compare 1-9, 2-10, 3-11 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_7), zmm),
-            0xFF00);
-    // 2) half_cleaner[8]: compare 1-5, 2-6, 3-7 etc ..
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::permutexvar(_mm512_set_epi32(NETWORK_32BIT_6), zmm),
-            0xF0F0);
-    // 3) half_cleaner[4]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(zmm),
-            0xCCCC);
-    // 3) half_cleaner[1]
-    zmm = cmp_merge<vtype>(
-            zmm,
-            vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(zmm),
-            0xAAAA);
-    return zmm;
-}
+struct avx512_32bit_swizzle_ops {
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
+        __m512i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2){
+            v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001);
+        }else if constexpr (scale == 4){
+            v = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b01001110);
+        }else if constexpr (scale == 8){
+            v = _mm512_shuffle_i64x2(v, v, 0b10110001);
+        }else if constexpr (scale == 16){
+            v = _mm512_shuffle_i64x2(v, v, 0b01001110);
+        }else{
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(typename vtype::reg_t reg){
+        __m512i v = vtype::cast_to(reg);
+
+        if constexpr (scale == 2){
+            return swap_n<vtype, 2>(reg);
+        }else if constexpr (scale == 4){
+            __m512i mask = _mm512_set_epi32(12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3);
+            v = _mm512_permutexvar_epi32(mask, v);
+        }else if constexpr (scale == 8){
+            __m512i mask = _mm512_set_epi32(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7);
+            v = _mm512_permutexvar_epi32(mask, v);
+        }else if constexpr (scale == 16){
+            return vtype::reverse(reg);
+        }else{
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v);
+    }
+
+    template <typename vtype, int scale>
+    X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(typename vtype::reg_t reg, typename vtype::reg_t other){
+        __m512i v1 = vtype::cast_to(reg);
+        __m512i v2 = vtype::cast_to(other);
+
+        if constexpr (scale == 2){
+            v1 = _mm512_mask_blend_epi32(0b0101010101010101, v1, v2);
+        }else if constexpr (scale == 4){
+            v1 = _mm512_mask_blend_epi32(0b0011001100110011, v1, v2);
+        }else if constexpr (scale == 8){
+            v1 = _mm512_mask_blend_epi32(0b0000111100001111, v1, v2);
+        }else if constexpr (scale == 16){
+            v1 = _mm512_mask_blend_epi32(0b0000000011111111, v1, v2);
+        }else{
+            static_assert(scale == -1, "should not be reached");
+        }
+
+        return vtype::cast_from(v1);
+    }
+};
 
 #endif // AVX512_QSORT_32BIT
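
Note on the change above: the hand-written bitonic_merge_zmm_32bit() half cleaner is dropped, and each zmm_vector specialization instead publishes a swizzle_ops type (avx512_32bit_swizzle_ops) exposing swap_n, reverse_n and merge_n, plus cast_to/cast_from to bridge between the element register type (__m512 for float) and the __m512i that the integer shuffle intrinsics operate on, presumably so the merge networks elsewhere in the tree can be written once against this interface. The standalone driver below is only an illustration and not part of this commit: it uses plain AVX-512 intrinsics to show the lane movements that swap_n<2>, reverse_n<4> and merge_n<2> perform on a 16-lane int32 register (the main() harness and variable names are mine, not the library's).

    // swizzle_demo.cpp: compile with -mavx512f on an AVX-512 machine
    #include <immintrin.h>
    #include <cstdio>

    int main()
    {
        // Lane i holds the value i (_mm512_set_epi32 lists lane 15 first).
        __m512i v = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8,
                                     7, 6, 5, 4, 3, 2, 1, 0);

        // swap_n<2>: exchange adjacent pairs of lanes -> 1,0,3,2,...
        __m512i swapped = _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)0b10110001);

        // reverse_n<4>: reverse each group of 4 lanes -> 3,2,1,0,7,6,5,4,...
        __m512i idx = _mm512_set_epi32(12, 13, 14, 15, 8, 9, 10, 11,
                                       4, 5, 6, 7, 0, 1, 2, 3);
        __m512i reversed = _mm512_permutexvar_epi32(idx, v);
        (void)reversed;

        // merge_n<2>: blend with mask 0b01...01, so even lanes come from the
        // second operand (swapped) and odd lanes from the first (v),
        // giving 1,1,3,3,5,5,...,15,15.
        __m512i merged = _mm512_mask_blend_epi32(0b0101010101010101, v, swapped);

        alignas(64) int out[16];
        _mm512_store_si512(out, merged);
        for (int i = 0; i < 16; i++)
            printf("%d ", out[i]);
        printf("\n");
        return 0;
    }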