From 12306b6ff5f0293449616d3dbf1080c3fb4f4420 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Thu, 26 Oct 2023 13:59:57 -0700 Subject: [PATCH 1/9] Genericized key-value and argsort networks --- src/avx2-64bit-qsort.hpp | 9 - src/avx512-64bit-argsort.hpp | 193 +-------- src/avx512-64bit-common.h | 46 +++ src/avx512-64bit-keyvaluesort.hpp | 348 ----------------- src/xss-network-keyvaluesort.hpp | 625 ++++++++++++------------------ 5 files changed, 306 insertions(+), 915 deletions(-) diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 6ffddbde..fd7f92af 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -11,15 +11,6 @@ #include "xss-common-qsort.h" #include "avx2-emu-funcs.hpp" -/* - * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ymm 3, 2, 1, 0 -#define NETWORK_64BIT_R 0, 1, 2, 3 -#define NETWORK_64BIT_1 1, 0, 3, 2 - /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index c4084c68..ad8ecd23 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -352,195 +352,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, return l_store; } -template -X86_SIMD_SORT_INLINE void -argsort_8_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; - argreg_t argzmm = argtype::maskz_loadu(load_mask, arg); - reg_t arrzmm = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm, arr); - arrzmm = sort_zmm_64bit(arrzmm, argzmm); - argtype::mask_storeu(arg, load_mask, argzmm); -} - -template -X86_SIMD_SORT_INLINE void -argsort_16_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 8) { - argsort_8_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - argreg_t argzmm1 = argtype::loadu(arg); - argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); - reg_t arrzmm1 = vtype::i64gather(arr, arg); - reg_t arrzmm2 = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask, argzmm2, arr); - arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); - arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); - bitonic_merge_two_zmm_64bit( - arrzmm1, arrzmm2, argzmm1, argzmm2); - argtype::storeu(arg, argzmm1); - argtype::mask_storeu(arg + 8, load_mask, argzmm2); -} - -template -X86_SIMD_SORT_INLINE void -argsort_32_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 16) { - argsort_16_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[4]; - argreg_t argzmm[4]; - - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - opmask_t load_mask[2] = {0xFF, 0xFF}; - X86_SIMD_SORT_UNROLL_LOOP(2) - for (int ii = 0; ii < 2; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii); - arrzmm[ii + 2] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], 
argzmm[ii + 2], arr); - arrzmm[ii + 2] = sort_zmm_64bit(arrzmm[ii + 2], - argzmm[ii + 2]); - } - - bitonic_merge_two_zmm_64bit( - arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]); - bitonic_merge_two_zmm_64bit( - arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]); - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - - argtype::storeu(arg, argzmm[0]); - argtype::storeu(arg + 8, argzmm[1]); - argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]); - argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -argsort_64_64bit(type_t *arr, arrsize_t *arg, int32_t N) -{ - if (N <= 32) { - argsort_32_64bit(arr, arg, N); - return; - } - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - reg_t arrzmm[8]; - argreg_t argzmm[8]; - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argzmm[ii] = argtype::loadu(arg + 8 * ii); - arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); - arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); - } - - opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF}; - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; - argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii); - arrzmm[ii + 4] = vtype::template mask_i64gather( - vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr); - arrzmm[ii + 4] = sort_zmm_64bit(arrzmm[ii + 4], - argzmm[ii + 4]); - } - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 8; ii = ii + 2) { - bitonic_merge_two_zmm_64bit( - arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); - } - bitonic_merge_four_zmm_64bit(arrzmm, argzmm); - bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); - bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); - - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::storeu(arg + 8 * ii, argzmm[ii]); - } - X86_SIMD_SORT_UNROLL_LOOP(4) - for (int ii = 0; ii < 4; ++ii) { - argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]); - } -} - -/* arsort 128 doesn't seem to make much of a difference to perf*/ -//template -//X86_SIMD_SORT_INLINE void -//argsort_128_64bit(type_t *arr, arrsize_t *arg, int32_t N) -//{ -// if (N <= 64) { -// argsort_64_64bit(arr, arg, N); -// return; -// } -// using reg_t = typename vtype::reg_t; -// using opmask_t = typename vtype::opmask_t; -// reg_t arrzmm[16]; -// argreg_t argzmm[16]; -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii] = argtype::loadu(arg + 8*ii); -// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr); -// arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); -// } -// -// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; -// if (N != 128) { -// uarrsize_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; -// } -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii); -// arrzmm[ii+8] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr); -// arrzmm[ii+8] = sort_zmm_64bit(arrzmm[ii+8], argzmm[ii+8]); -// } -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 16; ii = ii + 2) { -// bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); -// } -// bitonic_merge_four_zmm_64bit(arrzmm, argzmm); -// 
bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); -// bitonic_merge_four_zmm_64bit(arrzmm + 8, argzmm + 8); -// bitonic_merge_four_zmm_64bit(arrzmm + 12, argzmm + 12); -// bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); -// bitonic_merge_eight_zmm_64bit(arrzmm+8, argzmm+8); -// bitonic_merge_sixteen_zmm_64bit(arrzmm, argzmm); -// -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::storeu(arg + 8*ii, argzmm[ii]); -// } -//X86_SIMD_SORT_UNROLL_LOOP(8) -// for (int ii = 0; ii < 8; ++ii) { -// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]); -// } -//} - template X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, arrsize_t *arg, @@ -586,7 +397,7 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -619,7 +430,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 64) { - argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 909f3b2b..d0fed753 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -8,6 +8,7 @@ #define AVX512_64BIT_COMMON #include "xss-common-includes.h" +#include "avx2-32bit-qsort.hpp" /* * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic @@ -32,6 +33,8 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() { @@ -194,6 +197,19 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castps_si256(v); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -202,6 +218,8 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() { @@ -354,6 +372,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct ymm_vector { @@ -362,6 +393,8 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() { @@ -514,6 +547,19 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } }; template <> struct zmm_vector { diff --git 
a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 55f79bb1..d28ad61f 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -182,355 +182,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } -template -X86_SIMD_SORT_INLINE void -sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; - typename vtype1::reg_t key_zmm - = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); - - typename vtype2::reg_t index_zmm - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); - vtype1::mask_storeu(keys, - load_mask, - sort_zmm_64bit(key_zmm, index_zmm)); - vtype2::mask_storeu(indexes, load_mask, index_zmm); -} - -template -X86_SIMD_SORT_INLINE void -sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 8) { - sort_8_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - - typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; - - reg_t key_zmm1 = vtype1::loadu(keys); - reg_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); - - index_type index_zmm1 = vtype2::loadu(indexes); - index_type index_zmm2 - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); - - key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); - key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); - bitonic_merge_two_zmm_64bit( - key_zmm1, key_zmm2, index_zmm1, index_zmm2); - - vtype2::storeu(indexes, index_zmm1); - vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); - - vtype1::storeu(keys, key_zmm1); - vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); -} - -template -X86_SIMD_SORT_INLINE void -sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 16) { - sort_16_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype2::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[4]; - index_type index_zmm[4]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); - key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); - - index_zmm[2] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); - index_zmm[3] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); - vtype2::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::mask_storeu(keys + 16, load_mask1, 
key_zmm[2]); - vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); -} - -template -X86_SIMD_SORT_INLINE void -sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 32) { - sort_32_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using opmask_t = typename vtype1::opmask_t; - using index_type = typename vtype2::reg_t; - reg_t key_zmm[8]; - index_type index_zmm[8]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - // N-32 >= 1 - uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); - key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); - key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); - key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); - - index_zmm[4] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); - index_zmm[5] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); - index_zmm[6] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); - index_zmm[7] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); - vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); - vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); - vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); - vtype1::mask_storeu(keys + 40, load_mask2, key_zmm[5]); - vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); - vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); -} - -template 
-X86_SIMD_SORT_INLINE void -sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) -{ - if (N <= 64) { - sort_64_64bit(keys, indexes, N); - return; - } - using reg_t = typename vtype1::reg_t; - using index_type = typename vtype2::reg_t; - using opmask_t = typename vtype1::opmask_t; - reg_t key_zmm[16]; - index_type index_zmm[16]; - - key_zmm[0] = vtype1::loadu(keys); - key_zmm[1] = vtype1::loadu(keys + 8); - key_zmm[2] = vtype1::loadu(keys + 16); - key_zmm[3] = vtype1::loadu(keys + 24); - key_zmm[4] = vtype1::loadu(keys + 32); - key_zmm[5] = vtype1::loadu(keys + 40); - key_zmm[6] = vtype1::loadu(keys + 48); - key_zmm[7] = vtype1::loadu(keys + 56); - - index_zmm[0] = vtype2::loadu(indexes); - index_zmm[1] = vtype2::loadu(indexes + 8); - index_zmm[2] = vtype2::loadu(indexes + 16); - index_zmm[3] = vtype2::loadu(indexes + 24); - index_zmm[4] = vtype2::loadu(indexes + 32); - index_zmm[5] = vtype2::loadu(indexes + 40); - index_zmm[6] = vtype2::loadu(indexes + 48); - index_zmm[7] = vtype2::loadu(indexes + 56); - key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); - key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); - key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); - key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); - key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); - key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); - key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); - key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); - - opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; - opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; - opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; - opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; - if (N != 128) { - uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; - load_mask1 = (combined_mask)&0xFF; - load_mask2 = (combined_mask >> 8) & 0xFF; - load_mask3 = (combined_mask >> 16) & 0xFF; - load_mask4 = (combined_mask >> 24) & 0xFF; - load_mask5 = (combined_mask >> 32) & 0xFF; - load_mask6 = (combined_mask >> 40) & 0xFF; - load_mask7 = (combined_mask >> 48) & 0xFF; - load_mask8 = (combined_mask >> 56) & 0xFF; - } - key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); - key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); - key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 80); - key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); - key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); - key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); - key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); - key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); - - index_zmm[8] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); - index_zmm[9] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); - index_zmm[10] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); - index_zmm[11] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); - index_zmm[12] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); - index_zmm[13] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); - index_zmm[14] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); - index_zmm[15] - = vtype2::mask_loadu(vtype2::zmm_max(), load_mask8, indexes + 120); - key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); - key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); - key_zmm[10] = 
sort_zmm_64bit(key_zmm[10], index_zmm[10]); - key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); - key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); - key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); - key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); - key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); - - bitonic_merge_two_zmm_64bit( - key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); - bitonic_merge_two_zmm_64bit( - key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); - bitonic_merge_two_zmm_64bit( - key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); - bitonic_merge_two_zmm_64bit( - key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); - bitonic_merge_two_zmm_64bit( - key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); - bitonic_merge_two_zmm_64bit( - key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); - bitonic_merge_two_zmm_64bit( - key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); - bitonic_merge_two_zmm_64bit( - key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); - bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); - bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); - bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); - bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); - bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); - vtype2::storeu(indexes, index_zmm[0]); - vtype2::storeu(indexes + 8, index_zmm[1]); - vtype2::storeu(indexes + 16, index_zmm[2]); - vtype2::storeu(indexes + 24, index_zmm[3]); - vtype2::storeu(indexes + 32, index_zmm[4]); - vtype2::storeu(indexes + 40, index_zmm[5]); - vtype2::storeu(indexes + 48, index_zmm[6]); - vtype2::storeu(indexes + 56, index_zmm[7]); - vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); - vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); - vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); - vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); - vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); - vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); - vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); - vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); - - vtype1::storeu(keys, key_zmm[0]); - vtype1::storeu(keys + 8, key_zmm[1]); - vtype1::storeu(keys + 16, key_zmm[2]); - vtype1::storeu(keys + 24, key_zmm[3]); - vtype1::storeu(keys + 32, key_zmm[4]); - vtype1::storeu(keys + 40, key_zmm[5]); - vtype1::storeu(keys + 48, key_zmm[6]); - vtype1::storeu(keys + 56, key_zmm[7]); - vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); - vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); - vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); - vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); - vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); - vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); - vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); - vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); -} template struct index_64bit_vector_type; +template <> struct index_64bit_vector_type<8> { using type = zmm_vector;}; +template <> struct index_64bit_vector_type<4> { using type = avx2_vector;}; template -X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, - reg_t &key_zmm2, - index_type &index_zmm1, - index_type &index_zmm2) -{ - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - 
const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network: coex of zmm1 and zmm2 reversed - key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); - index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); - reg_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); - reg_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); +template +X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key, typename valueType::reg_t &value){ + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8){ + key = bitonic_merge_zmm_64bit(key, value); + }else{ + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1); +template +X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, typename valueType::reg_t &value){ + constexpr int numlanes = keyType::numlanes; + if constexpr (numlanes == 8){ + key = sort_zmm_64bit(key, value); + }else{ + static_assert(numlanes == -1, "should not reach here"); + UNUSED(key); + UNUSED(value); + } +} - index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1); - index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2); - /* need to reverse the lower registers to keep the correct order */ - key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4); - index_zmm4 = vtype2::permutexvar(rev_index2, index_zmm4); - // 2) Recursive half cleaner for each - key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); - index_zmm1 = index_zmm3; - index_zmm2 = index_zmm4; +template +X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) +{ + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int num = numVecs / 2; num >= 2; num /= 2) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int j = 0; j < numVecs; j += num) { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < num / 2; i++) { + arrsize_t index1 = i + j; + arrsize_t index2 = i + j + num / 2; + COEX(keys[index1], keys[index2], values[index1], values[index2]); + } + } + } } -// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive -// half cleaner -template -X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) + + +template +X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) +{ + // Do the reverse part + if constexpr (numVecs == 2) { + keys[1] = keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + COEX(keys[0], keys[1], values[0], values[1]); + keys[1] = keyType::reverse(keys[1]); + values[1] = valueType::reverse(values[1]); + } + else if constexpr (numVecs > 2) { + // Reverse upper half + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] = valueType::reverse(values[numVecs - i - 1]); + + COEX(keys[i], keys[numVecs - i - 1], values[i], values[numVecs - i - 1]); + + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); + values[numVecs - i - 1] = valueType::reverse(values[numVecs - i - 1]); + } + } + + // Call cleaner + bitonic_clean_n_vec(keys, values); + + // Now do bitonic_merge + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + bitonic_merge_dispatch(keys[i], values[i]); + } +} + +template +X86_SIMD_SORT_INLINE void 
bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - // 1) First step of a merging network - reg_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); - reg_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); - index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); - index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm3r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm2r); - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); - index_type index_reg_t2 - = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); - - // 2) Recursive half clearer: 16 - reg_t key_reg_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - reg_t key_zmm0 = vtype1::min(key_reg_t1, key_reg_t2); - reg_t key_zmm1 = vtype1::max(key_reg_t1, key_reg_t2); - reg_t key_zmm2 = vtype1::min(key_reg_t3, key_reg_t4); - reg_t key_zmm3 = vtype1::max(key_reg_t3, key_reg_t4); - - movmask1 = vtype1::eq(key_zmm0, key_reg_t1); - movmask2 = vtype1::eq(key_zmm2, key_reg_t3); - - index_type index_zmm0 - = vtype2::mask_mov(index_reg_t2, movmask1, index_reg_t1); - index_type index_zmm1 - = vtype2::mask_mov(index_reg_t1, movmask1, index_reg_t2); - index_type index_zmm2 - = vtype2::mask_mov(index_reg_t4, movmask2, index_reg_t3); - index_type index_zmm3 - = vtype2::mask_mov(index_reg_t3, movmask2, index_reg_t4); - - key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); - key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); - key_zmm[2] = bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); - key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); - - index_zmm[0] = index_zmm0; - index_zmm[1] = index_zmm1; - index_zmm[2] = index_zmm2; - index_zmm[3] = index_zmm3; + if constexpr (numPer > numVecs) { + UNUSED(keys); + UNUSED(values); + return; + } + else { + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / numPer; i++) { + bitonic_merge_n_vec(keys + i * numPer, values + i * numPer); + } + bitonic_fullmerge_n_vec(keys, values); + } } -template -X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, typename indexType::type_t *indices, int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); - reg_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); - reg_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); - reg_t key_zmm7r = vtype1::permutexvar(rev_index1, 
key_zmm[7]); - index_type index_zmm4r = vtype2::permutexvar(rev_index2, index_zmm[4]); - index_type index_zmm5r = vtype2::permutexvar(rev_index2, index_zmm[5]); - index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); - index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm7r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm6r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm5r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm4r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); - - typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); - typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); - typename vtype1::opmask_t movmask3 = vtype1::eq(key_reg_t3, key_zmm[2]); - typename vtype1::opmask_t movmask4 = vtype1::eq(key_reg_t4, key_zmm[3]); - - index_type index_reg_t1 - = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); - index_type index_zmm_m1 - = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); - index_type index_reg_t2 - = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); - index_type index_zmm_m2 - = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); - index_type index_reg_t3 - = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); - index_type index_zmm_m3 - = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); - index_type index_reg_t4 - = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); - index_type index_zmm_m4 - = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); - - reg_t key_reg_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t7 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - key_zmm[0] - = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = 
index_reg_t7; - index_zmm[7] = index_reg_t8; + using kreg_t = typename keyType::reg_t; + using ireg_t = typename indexType::reg_t; + + static_assert(numVecs > 0, "numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + argsort_n_vec(keys, indices, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + ireg_t indexVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + indexVecs[i] = indexType::loadu(indices + i * indexType::numlanes); + keyVecs[i] = keyType::i64gather(keys, indices + i * indexType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + indexVecs[i] = indexType::mask_loadu( + indexType::zmm_max(), ioMasks[j], indices + i * indexType::numlanes); + + keyVecs[i] = keyType::template mask_i64gather(keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], indexVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, indexVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + indexType::storeu(indices + i * indexType::numlanes, indexVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + indexType::mask_storeu(indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); + } } -template -X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(reg_t *key_zmm, - index_type *index_zmm) +template +X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, typename valueType::type_t *values, int N) { - const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); - const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); - reg_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); - reg_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); - reg_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); - reg_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); - reg_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); - reg_t key_zmm13r = vtype1::permutexvar(rev_index1, key_zmm[13]); - reg_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); - reg_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); - - index_type index_zmm8r = vtype2::permutexvar(rev_index2, index_zmm[8]); - index_type index_zmm9r = vtype2::permutexvar(rev_index2, index_zmm[9]); - index_type index_zmm10r = vtype2::permutexvar(rev_index2, index_zmm[10]); - index_type index_zmm11r = vtype2::permutexvar(rev_index2, index_zmm[11]); - index_type index_zmm12r = vtype2::permutexvar(rev_index2, index_zmm[12]); - index_type index_zmm13r = vtype2::permutexvar(rev_index2, index_zmm[13]); - index_type index_zmm14r = vtype2::permutexvar(rev_index2, index_zmm[14]); - index_type index_zmm15r = vtype2::permutexvar(rev_index2, index_zmm[15]); - - reg_t key_reg_t1 = vtype1::min(key_zmm[0], 
key_zmm15r); - reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm14r); - reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm13r); - reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm12r); - reg_t key_reg_t5 = vtype1::min(key_zmm[4], key_zmm11r); - reg_t key_reg_t6 = vtype1::min(key_zmm[5], key_zmm10r); - reg_t key_reg_t7 = vtype1::min(key_zmm[6], key_zmm9r); - reg_t key_reg_t8 = vtype1::min(key_zmm[7], key_zmm8r); - - reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); - reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); - reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); - reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); - reg_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); - reg_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); - reg_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); - reg_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); - - index_type index_reg_t1 = vtype2::mask_mov( - index_zmm15r, vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm15r); - index_type index_reg_t2 = vtype2::mask_mov( - index_zmm14r, vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm14r); - index_type index_reg_t3 = vtype2::mask_mov( - index_zmm13r, vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm13r); - index_type index_reg_t4 = vtype2::mask_mov( - index_zmm12r, vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm12r); - - index_type index_reg_t5 = vtype2::mask_mov( - index_zmm11r, vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm[4]); - index_type index_zmm_m5 = vtype2::mask_mov( - index_zmm[4], vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm11r); - index_type index_reg_t6 = vtype2::mask_mov( - index_zmm10r, vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm[5]); - index_type index_zmm_m6 = vtype2::mask_mov( - index_zmm[5], vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm10r); - index_type index_reg_t7 = vtype2::mask_mov( - index_zmm9r, vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm[6]); - index_type index_zmm_m7 = vtype2::mask_mov( - index_zmm[6], vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm9r); - index_type index_reg_t8 = vtype2::mask_mov( - index_zmm8r, vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm[7]); - index_type index_zmm_m8 = vtype2::mask_mov( - index_zmm[7], vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm8r); - - reg_t key_reg_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); - reg_t key_reg_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); - reg_t key_reg_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); - reg_t key_reg_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); - reg_t key_reg_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); - reg_t key_reg_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); - reg_t key_reg_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); - reg_t key_reg_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); - index_type index_reg_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); - index_type index_reg_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); - index_type index_reg_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); - index_type index_reg_t12 = vtype2::permutexvar(rev_index2, index_zmm_m5); - index_type 
index_reg_t13 = vtype2::permutexvar(rev_index2, index_zmm_m4); - index_type index_reg_t14 = vtype2::permutexvar(rev_index2, index_zmm_m3); - index_type index_reg_t15 = vtype2::permutexvar(rev_index2, index_zmm_m2); - index_type index_reg_t16 = vtype2::permutexvar(rev_index2, index_zmm_m1); - - COEX(key_reg_t1, key_reg_t5, index_reg_t1, index_reg_t5); - COEX(key_reg_t2, key_reg_t6, index_reg_t2, index_reg_t6); - COEX(key_reg_t3, key_reg_t7, index_reg_t3, index_reg_t7); - COEX(key_reg_t4, key_reg_t8, index_reg_t4, index_reg_t8); - COEX(key_reg_t9, key_reg_t13, index_reg_t9, index_reg_t13); - COEX( - key_reg_t10, key_reg_t14, index_reg_t10, index_reg_t14); - COEX( - key_reg_t11, key_reg_t15, index_reg_t11, index_reg_t15); - COEX( - key_reg_t12, key_reg_t16, index_reg_t12, index_reg_t16); - - COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); - COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); - COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); - COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); - COEX(key_reg_t9, key_reg_t11, index_reg_t9, index_reg_t11); - COEX( - key_reg_t10, key_reg_t12, index_reg_t10, index_reg_t12); - COEX( - key_reg_t13, key_reg_t15, index_reg_t13, index_reg_t15); - COEX( - key_reg_t14, key_reg_t16, index_reg_t14, index_reg_t16); - - COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); - COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); - COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); - COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); - COEX(key_reg_t9, key_reg_t10, index_reg_t9, index_reg_t10); - COEX( - key_reg_t11, key_reg_t12, index_reg_t11, index_reg_t12); - COEX( - key_reg_t13, key_reg_t14, index_reg_t13, index_reg_t14); - COEX( - key_reg_t15, key_reg_t16, index_reg_t15, index_reg_t16); - // - key_zmm[0] - = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); - key_zmm[1] - = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); - key_zmm[2] - = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); - key_zmm[3] - = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); - key_zmm[4] - = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); - key_zmm[5] - = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); - key_zmm[6] - = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); - key_zmm[7] - = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); - key_zmm[8] - = bitonic_merge_zmm_64bit(key_reg_t9, index_reg_t9); - key_zmm[9] = bitonic_merge_zmm_64bit(key_reg_t10, - index_reg_t10); - key_zmm[10] = bitonic_merge_zmm_64bit(key_reg_t11, - index_reg_t11); - key_zmm[11] = bitonic_merge_zmm_64bit(key_reg_t12, - index_reg_t12); - key_zmm[12] = bitonic_merge_zmm_64bit(key_reg_t13, - index_reg_t13); - key_zmm[13] = bitonic_merge_zmm_64bit(key_reg_t14, - index_reg_t14); - key_zmm[14] = bitonic_merge_zmm_64bit(key_reg_t15, - index_reg_t15); - key_zmm[15] = bitonic_merge_zmm_64bit(key_reg_t16, - index_reg_t16); - - index_zmm[0] = index_reg_t1; - index_zmm[1] = index_reg_t2; - index_zmm[2] = index_reg_t3; - index_zmm[3] = index_reg_t4; - index_zmm[4] = index_reg_t5; - index_zmm[5] = index_reg_t6; - index_zmm[6] = index_reg_t7; - index_zmm[7] = index_reg_t8; - index_zmm[8] = index_reg_t9; - index_zmm[9] = index_reg_t10; - index_zmm[10] = index_reg_t11; - index_zmm[11] = index_reg_t12; - index_zmm[12] = index_reg_t13; - index_zmm[13] = index_reg_t14; - index_zmm[14] = index_reg_t15; - index_zmm[15] = index_reg_t16; + using kreg_t = typename keyType::reg_t; + using vreg_t = typename valueType::reg_t; + + static_assert(numVecs > 0, 
"numVecs should be > 0"); + if constexpr (numVecs > 1) { + if (N * 2 <= numVecs * keyType::numlanes) { + kvsort_n_vec(keys, values, N); + return; + } + } + + kreg_t keyVecs[numVecs]; + vreg_t valueVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } + + // Unmasked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyVecs[i] = keyType::loadu(keys + i * keyType::numlanes); + valueVecs[i] = valueType::loadu(values + i * valueType::numlanes); + } + // Masked part of the load + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyVecs[i] = keyType::mask_loadu( + keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); + valueVecs[i] = valueType::mask_loadu( + valueType::zmm_max(), ioMasks[j], values + i * valueType::numlanes); + } + + // Sort each loaded vector + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs; i++) { + sort_vec_dispatch(keyVecs[i], valueVecs[i]); + } + + // Run the full merger + bitonic_fullmerge_n_vec(keyVecs, valueVecs); + + // Unmasked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = 0; i < numVecs / 2; i++) { + keyType::storeu(keys + i * keyType::numlanes, keyVecs[i]); + valueType::storeu(values + i * valueType::numlanes, valueVecs[i]); + } + // Masked part of the store + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + keyType::mask_storeu(keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); + valueType::mask_storeu(values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); + } +} + +template +X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) +{ + using indexType = typename index_64bit_vector_type::type; + + static_assert(keyType::numlanes == indexType::numlanes, "invalid pairing of value/index types"); + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + argsort_n_vec(keys, indices, N); } -#endif // AVX512_KEYVALUE_NETWORKS + +template +X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, typename valueType::type_t *values, int N) +{ + static_assert(keyType::numlanes == valueType::numlanes, "invalid pairing of key/value types"); + + constexpr int numVecs = maxN / keyType::numlanes; + constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); + constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); + static_assert(powerOfTwo == true && isMultiple == true, + "maxN must be keyType::numlanes times a power of 2"); + + kvsort_n_vec(keys, values, N); +} + +#endif \ No newline at end of file From dc767bdc01f3dedc4f8386e1430dd14a81d0ddad Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Nov 2023 11:46:41 -0800 Subject: [PATCH 2/9] Enabled kvsort to build, it may be broken however --- src/avx512-64bit-keyvaluesort.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/avx512-64bit-keyvaluesort.hpp 
b/src/avx512-64bit-keyvaluesort.hpp index d28ad61f..4d50d5b7 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -245,7 +245,7 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys, */ if (right + 1 - left <= 128) { - sort_128_64bit( + kvsort_n( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } From cac3f7db2dabb3d35a435f5af0707db692276fe9 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Nov 2023 11:48:54 -0800 Subject: [PATCH 3/9] clang-format --- src/avx512-64bit-common.h | 6 +- src/avx512-64bit-keyvaluesort.hpp | 2 - src/xss-network-keyvaluesort.hpp | 151 ++++++++++++++++++------------ 3 files changed, 96 insertions(+), 63 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index d0fed753..bde8dd71 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -33,7 +33,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() @@ -218,7 +218,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() @@ -393,7 +393,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 4d50d5b7..1f446c68 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -182,8 +182,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } - - template struct index_64bit_vector_type; -template <> struct index_64bit_vector_type<8> { using type = zmm_vector;}; -template <> struct index_64bit_vector_type<4> { using type = avx2_vector;}; +template +struct index_64bit_vector_type; +template <> +struct index_64bit_vector_type<8> { + using type = zmm_vector; +}; +template <> +struct index_64bit_vector_type<4> { + using type = avx2_vector; +}; template -X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key, typename valueType::reg_t &value){ +X86_SIMD_SORT_INLINE void +bitonic_merge_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) +{ constexpr int numlanes = keyType::numlanes; - if constexpr (numlanes == 8){ - key = bitonic_merge_zmm_64bit(key, value); - }else{ + if constexpr (numlanes == 8) { + key = bitonic_merge_zmm_64bit(key, value); + } + else { static_assert(numlanes == -1, "should not reach here"); UNUSED(key); UNUSED(value); @@ -138,23 +149,23 @@ X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key, t } template -X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, typename valueType::reg_t &value){ +X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, + typename valueType::reg_t &value) +{ constexpr int numlanes = keyType::numlanes; - if constexpr (numlanes == 8){ - key = sort_zmm_64bit(key, value); - }else{ + if constexpr (numlanes == 8) { + key = sort_zmm_64bit(key, value); + } + else { static_assert(numlanes == -1, "should not reach here"); UNUSED(key); UNUSED(value); } } - - -template -X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) +template +X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename 
keyType::reg_t *keys, + typename valueType::reg_t *values) { X86_SIMD_SORT_UNROLL_LOOP(64) for (int num = numVecs / 2; num >= 2; num /= 2) { @@ -164,18 +175,19 @@ X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, typ for (int i = 0; i < num / 2; i++) { arrsize_t index1 = i + j; arrsize_t index2 = i + j + num / 2; - COEX(keys[index1], keys[index2], values[index1], values[index2]); + COEX(keys[index1], + keys[index2], + values[index1], + values[index2]); } } } } - -template -X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) -{ +template +X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) +{ // Do the reverse part if constexpr (numVecs == 2) { keys[1] = keyType::reverse(keys[1]); @@ -189,12 +201,17 @@ X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, typ X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / 2; i++) { keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); - values[numVecs - i - 1] = valueType::reverse(values[numVecs - i - 1]); - - COEX(keys[i], keys[numVecs - i - 1], values[i], values[numVecs - i - 1]); - + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); + + COEX(keys[i], + keys[numVecs - i - 1], + values[i], + values[numVecs - i - 1]); + keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); - values[numVecs - i - 1] = valueType::reverse(values[numVecs - i - 1]); + values[numVecs - i - 1] + = valueType::reverse(values[numVecs - i - 1]); } } @@ -208,11 +225,10 @@ X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, typ } } -template -X86_SIMD_SORT_INLINE void bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, typename valueType::reg_t *values) +template +X86_SIMD_SORT_INLINE void +bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, + typename valueType::reg_t *values) { if constexpr (numPer > numVecs) { UNUSED(keys); @@ -222,18 +238,22 @@ X86_SIMD_SORT_INLINE void bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, else { X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / numPer; i++) { - bitonic_merge_n_vec(keys + i * numPer, values + i * numPer); + bitonic_merge_n_vec( + keys + i * numPer, values + i * numPer); } - bitonic_fullmerge_n_vec(keys, values); + bitonic_fullmerge_n_vec( + keys, values); } } template -X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, typename indexType::type_t *indices, int N) +X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, + typename indexType::type_t *indices, + int N) { using kreg_t = typename keyType::reg_t; using ireg_t = typename indexType::reg_t; - + static_assert(numVecs > 0, "numVecs should be > 0"); if constexpr (numVecs > 1) { if (N * 2 <= numVecs * keyType::numlanes) { @@ -259,17 +279,21 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, typename X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / 2; i++) { indexVecs[i] = indexType::loadu(indices + i * indexType::numlanes); - keyVecs[i] = keyType::i64gather(keys, indices + i * indexType::numlanes); + keyVecs[i] + = keyType::i64gather(keys, indices + i * indexType::numlanes); } // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - indexVecs[i] = indexType::mask_loadu( - indexType::zmm_max(), ioMasks[j], indices + i * indexType::numlanes); - - keyVecs[i] = keyType::template 
mask_i64gather(keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); + indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), + ioMasks[j], + indices + i * indexType::numlanes); + + keyVecs[i] = keyType::template mask_i64gather( + keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); } - + // Sort each loaded vector X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs; i++) { @@ -287,16 +311,19 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, typename // Masked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - indexType::mask_storeu(indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); + indexType::mask_storeu( + indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); } } template -X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, typename valueType::type_t *values, int N) +X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) { using kreg_t = typename keyType::reg_t; using vreg_t = typename valueType::reg_t; - + static_assert(numVecs > 0, "numVecs should be > 0"); if constexpr (numVecs > 1) { if (N * 2 <= numVecs * keyType::numlanes) { @@ -329,8 +356,9 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, typename for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { keyVecs[i] = keyType::mask_loadu( keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); - valueVecs[i] = valueType::mask_loadu( - valueType::zmm_max(), ioMasks[j], values + i * valueType::numlanes); + valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), + ioMasks[j], + values + i * valueType::numlanes); } // Sort each loaded vector @@ -351,17 +379,21 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, typename // Masked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - keyType::mask_storeu(keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); - valueType::mask_storeu(values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); + keyType::mask_storeu( + keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); + valueType::mask_storeu( + values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); } } template -X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) +X86_SIMD_SORT_INLINE void +argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) { using indexType = typename index_64bit_vector_type::type; - - static_assert(keyType::numlanes == indexType::numlanes, "invalid pairing of value/index types"); + + static_assert(keyType::numlanes == indexType::numlanes, + "invalid pairing of value/index types"); constexpr int numVecs = maxN / keyType::numlanes; constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); @@ -372,10 +404,13 @@ X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys, arrsize_t *i } template -X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, typename valueType::type_t *values, int N) -{ - static_assert(keyType::numlanes == valueType::numlanes, "invalid pairing of key/value types"); - +X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, + typename valueType::type_t *values, + int N) +{ + static_assert(keyType::numlanes == valueType::numlanes, + "invalid pairing of key/value types"); + constexpr int numVecs = maxN / 
keyType::numlanes; constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); From 2ac25457260c6c97851a340f024896938a6025b8 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Nov 2023 12:07:21 -0800 Subject: [PATCH 4/9] Removed unused swizzle_ops logic --- src/avx512-64bit-common.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index bde8dd71..911f5395 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -34,8 +34,6 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; - using swizzle_ops = avx2_32bit_swizzle_ops; - static type_t type_max() { return X86_SIMD_SORT_INFINITYF; @@ -219,8 +217,6 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; - using swizzle_ops = avx2_32bit_swizzle_ops; - static type_t type_max() { return X86_SIMD_SORT_MAX_UINT32; @@ -394,8 +390,6 @@ struct ymm_vector { using opmask_t = __mmask8; static const uint8_t numlanes = 8; - using swizzle_ops = avx2_32bit_swizzle_ops; - static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; From bd44ccf39313e2c7620b1f74d7393dde83ff766c Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 14 Nov 2023 14:38:43 -0800 Subject: [PATCH 5/9] Changed parameters for argsort/argselect --- src/avx512-64bit-argsort.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index ad8ecd23..fa91008b 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -396,8 +396,8 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, /* * Base case: use bitonic networks to sort arrays <= 64 */ - if (right + 1 - left <= 64) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -429,8 +429,8 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, /* * Base case: use bitonic networks to sort arrays <= 64 */ - if (right + 1 - left <= 64) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); From a7e88478478a32571cbd8e95d11a70d5d2817734 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 3 Jan 2024 12:51:23 -0800 Subject: [PATCH 6/9] Fixes logic for 32-bit systems --- src/avx512-64bit-argsort.hpp | 23 +++++++++++++---------- src/xss-network-keyvaluesort.hpp | 5 ++--- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index fa91008b..0ee72040 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -379,7 +379,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, } } -template +template X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -397,7 +397,7 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -406,12 +406,12 @@ X86_SIMD_SORT_INLINE void 
argsort_64bit_(type_t *arr, arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); + argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t *arg, arrsize_t pos, @@ -430,7 +430,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -439,10 +439,10 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, pivot_index, right, max_iters - 1); } @@ -454,6 +454,8 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) using vectype = typename std::conditional, zmm_vector>::type; + using indextype = typename std::conditional, zmm_vector>::type; + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -462,7 +464,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) } } UNUSED(hasnan); - argsort_64bit_( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -488,6 +490,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, using vectype = typename std::conditional, zmm_vector>::type; + using indextype = typename std::conditional, zmm_vector>::type; if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -497,7 +500,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, } } UNUSED(hasnan); - argselect_64bit_( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 4aca9270..82e9a992 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -386,14 +386,13 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, } } -template +template X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) { - using indexType = typename index_64bit_vector_type::type; - static_assert(keyType::numlanes == indexType::numlanes, "invalid pairing of value/index types"); + constexpr int numVecs = maxN / keyType::numlanes; constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); From 2282e703d56dac4b7161a49db8ba9468c3b14702 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 3 Jan 2024 12:52:04 -0800 Subject: [PATCH 7/9] clang-format --- src/avx512-64bit-argsort.hpp | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 0ee72040..917a75f1 100644 --- 
a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -397,7 +397,8 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -406,9 +407,11 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + argsort_64bit_( + arr, arg, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); + argsort_64bit_( + arr, arg, pivot_index, right, max_iters - 1); } template @@ -430,7 +433,8 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -454,8 +458,12 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) using vectype = typename std::conditional, zmm_vector>::type; - using indextype = typename std::conditional, zmm_vector>::type; - + using indextype = + typename std::conditional, + zmm_vector>::type; + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -490,7 +498,11 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, using vectype = typename std::conditional, zmm_vector>::type; - using indextype = typename std::conditional, zmm_vector>::type; + using indextype = + typename std::conditional, + zmm_vector>::type; if (arrsize > 1) { if constexpr (std::is_floating_point_v) { From e7e452f693a06a92258f9ba9e981e015bc44d73b Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 3 Jan 2024 12:58:14 -0800 Subject: [PATCH 8/9] Removed unused code --- src/xss-network-keyvaluesort.hpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 82e9a992..6e3fc111 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -4,17 +4,6 @@ #include "avx512-64bit-qsort.hpp" #include "avx2-64bit-qsort.hpp" -template -struct index_64bit_vector_type; -template <> -struct index_64bit_vector_type<8> { - using type = zmm_vector; -}; -template <> -struct index_64bit_vector_type<4> { - using type = avx2_vector; -}; - template Date: Wed, 3 Jan 2024 14:31:37 -0800 Subject: [PATCH 9/9] get rid of global argtype definition --- src/avx512-64bit-argsort.hpp | 202 +++++++++++++++++-------------- src/xss-network-keyvaluesort.hpp | 2 +- 2 files changed, 111 insertions(+), 93 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 917a75f1..799303d8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -65,24 +65,15 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) }); } -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ -using argtype = typename std::conditional, - zmm_vector>::type; -#else -using argtype = typename std::conditional, - 
zmm_vector>::type; -#endif -using argreg_t = typename argtype::reg_t; - /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, arrsize_t left, arrsize_t right, @@ -107,7 +98,11 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -131,7 +126,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, if (left == right) return left; /* less than vtype::numlanes elements in the array */ - using reg_t = typename vtype::reg_t; reg_t pivot_vec = vtype::set1(pivot); reg_t min_vec = vtype::set1(*smallest); reg_t max_vec = vtype::set1(*biggest); @@ -139,14 +133,15 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, if (right - left == vtype::numlanes) { argreg_t argvec = argtype::loadu(arg + left); reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); return left + (vtype::numlanes - amount_gt_pivot); @@ -183,37 +178,38 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, } // partition the current vector and save it on both sides of the array int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); ; r_store -= amount_gt_pivot; l_store += (vtype::numlanes - amount_gt_pivot); } /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -221,8 +217,10 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, } template + typename type_t = typename vtype::type_t, + typename argreg_t = typename argtype::reg_t> X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -232,7 +230,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, type_t *biggest) { if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( + return partition_avx512( arr, arg, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes , shortening the array */ @@ -305,14 +303,14 @@ 
X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } @@ -322,28 +320,28 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); l_store += (vtype::numlanes - amount_gt_pivot); r_store -= amount_gt_pivot; } @@ -379,7 +377,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, } } -template +template X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t *arg, arrsize_t left, @@ -397,24 +395,24 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n( + argsort_n( arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( + arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - argsort_64bit_( + argsort_64bit_( arr, arg, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - argsort_64bit_( + argsort_64bit_( arr, arg, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t *arg, arrsize_t pos, @@ -433,20 +431,20 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n( + argsort_n( arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( + arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, pivot_index, right, max_iters - 1); } @@ -455,14 +453,24 @@ template X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector 
for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; - using indextype = - typename std::conditional'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, zmm_vector>::type; +#endif if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -472,7 +480,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) } } UNUSED(hasnan); - argsort_64bit_( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -495,14 +503,24 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; - using indextype = - typename std::conditional'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, zmm_vector>::type; +#endif if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -512,7 +530,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, } } UNUSED(hasnan); - argselect_64bit_( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 6e3fc111..fccf33dc 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -408,4 +408,4 @@ X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, kvsort_n_vec(keys, values, N); } -#endif \ No newline at end of file +#endif
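
For readers following the series: the key/value networks in xss-network-keyvaluesort.hpp are built from one primitive, a compare-exchange (COEX) that orders a pair of key vectors and carries the paired value (or index) vectors along with them; bitonic_merge_n_vec reverses the upper half, compare-exchanges it against the lower half, and bitonic_clean_n_vec finishes with progressively smaller strides. The scalar sketch below shows how those steps compose. It is illustrative only: the names coex, kv_bitonic_merge, and kv_bitonic_sort are hypothetical, n is assumed to be a power of two, and the real code applies each step to whole SIMD registers rather than single elements.

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Scalar analogue of the COEX primitive used throughout the vectorized
    // networks: order one key pair and move the paired values with it.
    template <typename K, typename V>
    void coex(K &k1, K &k2, V &v1, V &v2)
    {
        if (k2 < k1) {
            std::swap(k1, k2);
            std::swap(v1, v2);
        }
    }

    // Bitonic merge of the span [lo, lo + n): assumes the span is already a
    // bitonic sequence and n is a power of two. Compare-exchange element i
    // with element i + n/2, then recurse into both halves.
    template <typename K, typename V>
    void kv_bitonic_merge(std::vector<K> &keys, std::vector<V> &vals,
                          std::size_t lo, std::size_t n)
    {
        if (n <= 1) return;
        const std::size_t half = n / 2;
        for (std::size_t i = lo; i < lo + half; ++i)
            coex(keys[i], keys[i + half], vals[i], vals[i + half]);
        kv_bitonic_merge(keys, vals, lo, half);
        kv_bitonic_merge(keys, vals, lo + half, half);
    }

    // Full bitonic sort: sort both halves, reverse the upper half so the whole
    // span becomes bitonic (the scalar counterpart of the keyType::reverse /
    // valueType::reverse calls in the networks above), then merge.
    template <typename K, typename V>
    void kv_bitonic_sort(std::vector<K> &keys, std::vector<V> &vals,
                         std::size_t lo, std::size_t n)
    {
        if (n <= 1) return;
        const std::size_t half = n / 2;
        kv_bitonic_sort(keys, vals, lo, half);
        kv_bitonic_sort(keys, vals, lo + half, half);
        std::reverse(keys.begin() + lo + half, keys.begin() + lo + n);
        std::reverse(vals.begin() + lo + half, vals.begin() + lo + n);
        kv_bitonic_merge(keys, vals, lo, n);
    }

    // Usage (keys and values must have the same power-of-two length):
    //   std::vector<double>      k = {3.0, 1.0, 4.0, 1.5};
    //   std::vector<std::size_t> v = {0, 1, 2, 3};
    //   kv_bitonic_sort(k, v, 0, k.size());

In the vectorized version each "element" of this sketch corresponds to a whole register, which is why the networks assert that numVecs is a power of two and why argsort_n / kvsort_n pad the tail lanes with zmm_max() sentinels through masked loads before running the network.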