diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 5588cffa..81d7d00e 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -1,8 +1,6 @@ // AVX2 specific routines: #include "avx2-32bit-qsort.hpp" #include "avx2-64bit-qsort.hpp" -#include "avx2-32bit-half.hpp" -#include "xss-common-argsort.h" #include "x86simdsort-internal.h" #define DEFINE_ALL_METHODS(type) \ @@ -20,17 +18,6 @@ void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ { \ avx2_partial_qsort(arr, k, arrsize, hasnan); \ - }\ - template <> \ - std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ - { \ - return avx2_argsort(arr, arrsize, hasnan); \ - } \ - template <> \ - std::vector argselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan) \ - { \ - return avx2_argselect(arr, k, arrsize, hasnan); \ } namespace xss { diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index f088e4cd..8ebbc6be 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -189,12 +189,12 @@ DISPATCH_ALL(partial_qsort, (ISA_LIST("avx512_skx", "avx2"))) DISPATCH_ALL(argsort, (ISA_LIST("none")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) DISPATCH_ALL(argselect, (ISA_LIST("none")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \ DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\ diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp deleted file mode 100644 index 5a6ee5b5..00000000 --- a/src/avx2-32bit-half.hpp +++ /dev/null @@ -1,557 +0,0 @@ -/******************************************************************* - * Copyright (C) 2022 Intel Corporation - * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli - * ****************************************************************/ - -#ifndef AVX2_HALF_32BIT -#define AVX2_HALF_32BIT - -#include "xss-common-qsort.h" -#include "avx2-emu-funcs.hpp" - -/* - * Constants used in sorting 8 elements in a ymm registers. 
Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ - -// ymm 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 - -/* - * Assumes ymm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) -{ - using swizzle = typename vtype::swizzle_ops; - - const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0); - const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0); - - ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); - ymm = cmp_merge(ymm, vtype::reverse(ymm), oxCC); - ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); - return ymm; -} - -struct avx2_32bit_half_swizzle_ops; - -template <> -struct avx2_half_vector { - using type_t = int32_t; - using reg_t = __m128i; - using ymmi_t = __m128i; - using opmask_t = __m128i; - static const uint8_t numlanes = 4; - static constexpr simd_type vec_type = simd_type::AVX2; - - using swizzle_ops = avx2_32bit_half_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT32; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT32; - } - static reg_t zmm_max() - { - return _mm_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? - static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask_half(mask); - } - static ymmi_t seti(int v1, int v2, int v3, int v4) - { - return _mm_set_epi32(v1, v2, v3, v4); - } - static reg_t set(int v1, int v2, int v3, int v4) - { - return _mm_set_epi32(v1, v2, v3, v4); - } - static opmask_t kxor_opmask(opmask_t x, opmask_t y) - { - return _mm_xor_si128(x, y); - } - static opmask_t ge(reg_t x, reg_t y) - { - opmask_t equal = eq(x, y); - opmask_t greater = _mm_cmpgt_epi32(x, y); - return _mm_castps_si128( - _mm_or_ps(_mm_castsi128_ps(equal), _mm_castsi128_ps(greater))); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm_cmpeq_epi32(x, y); - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i64gather_epi32( - src, (const int *)base, index, mask, scale); - } - static reg_t i64gather(type_t *arr, arrsize_t *ind) - { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); - } - static reg_t loadu(void const *mem) - { - return _mm_loadu_si128((reg_t const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm_max_epi32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu32_half(mem, mask, x); - } - static reg_t maskz_loadu(opmask_t mask, void const *mem) - { - return _mm_maskload_epi32((const int *)mem, mask); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm_maskload_epi32((type_t *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), - _mm_castsi128_ps(y), - _mm_castsi128_ps(mask))); - } - static void mask_storeu(void *mem, opmask_t mask, reg_t x) - { - return _mm_maskstore_epi32((type_t *)mem, mask, x); - } - static reg_t min(reg_t x, 
reg_t y) - { - return _mm_min_epi32(x, y); - } - static reg_t permutexvar(__m128i idx, reg_t ymm) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } - static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } - static reg_t reverse(reg_t ymm) - { - const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); - return permutexvar(rev_index, ymm); - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32_half(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32_half(v); - } - static reg_t set1(type_t v) - { - return _mm_set1_epi32(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm_shuffle_epi32(ymm, mask); - } - static void storeu(void *mem, reg_t x) - { - _mm_storeu_si128((__m128i *)mem, x); - } - static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit_half>(x); - } - static reg_t cast_from(__m128i v) - { - return v; - } - static __m128i cast_to(reg_t v) - { - return v; - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32_half( - left_addr, right_addr, k, reg); - } -}; -template <> -struct avx2_half_vector { - using type_t = uint32_t; - using reg_t = __m128i; - using ymmi_t = __m128i; - using opmask_t = __m128i; - static const uint8_t numlanes = 4; - static constexpr simd_type vec_type = simd_type::AVX2; - - using swizzle_ops = avx2_32bit_half_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT32; - } - static type_t type_min() - { - return 0; - } - static reg_t zmm_max() - { - return _mm_set1_epi32(type_max()); - } - static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask_half(mask); - } - static ymmi_t seti(int v1, int v2, int v3, int v4) - { - return _mm_set_epi32(v1, v2, v3, v4); - } - static reg_t set(int v1, int v2, int v3, int v4) - { - return _mm_set_epi32(v1, v2, v3, v4); - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i64gather_epi32( - src, (const int *)base, index, mask, scale); - } - static reg_t i64gather(type_t *arr, arrsize_t *ind) - { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); - } - static opmask_t ge(reg_t x, reg_t y) - { - reg_t maxi = max(x, y); - return eq(maxi, x); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm_cmpeq_epi32(x, y); - } - static reg_t loadu(void const *mem) - { - return _mm_loadu_si128((reg_t const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm_max_epu32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu32_half(mem, mask, x); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm_maskload_epi32((const int *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), - _mm_castsi128_ps(y), - _mm_castsi128_ps(mask))); - } - static void mask_storeu(void *mem, opmask_t mask, reg_t x) - { - return _mm_maskstore_epi32((int *)mem, mask, x); - } - static reg_t min(reg_t x, reg_t y) - { - return _mm_min_epu32(x, y); - } - static reg_t permutexvar(__m128i idx, reg_t ymm) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } 
- static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } - static reg_t reverse(reg_t ymm) - { - const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); - return permutexvar(rev_index, ymm); - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32_half(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32_half(v); - } - static reg_t set1(type_t v) - { - return _mm_set1_epi32(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm_shuffle_epi32(ymm, mask); - } - static void storeu(void *mem, reg_t x) - { - _mm_storeu_si128((__m128i *)mem, x); - } - static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit_half>(x); - } - static reg_t cast_from(__m128i v) - { - return v; - } - static __m128i cast_to(reg_t v) - { - return v; - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32_half( - left_addr, right_addr, k, reg); - } -}; -template <> -struct avx2_half_vector { - using type_t = float; - using reg_t = __m128; - using ymmi_t = __m128i; - using opmask_t = __m128i; - static const uint8_t numlanes = 4; - static constexpr simd_type vec_type = simd_type::AVX2; - - using swizzle_ops = avx2_32bit_half_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYF; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITYF; - } - static reg_t zmm_max() - { - return _mm_set1_ps(type_max()); - } - - static ymmi_t seti(int v1, int v2, int v3, int v4) - { - return _mm_set_epi32(v1, v2, v3, v4); - } - static reg_t set(float v1, float v2, float v3, float v4) - { - return _mm_set_ps(v1, v2, v3, v4); - } - static reg_t maskz_loadu(opmask_t mask, void const *mem) - { - return _mm_maskload_ps((const float *)mem, mask); - } - static opmask_t ge(reg_t x, reg_t y) - { - return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_GE_OQ)); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_EQ_OQ)); - } - static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask_half(mask); - } - static int32_t convert_mask_to_int(opmask_t mask) - { - return convert_avx2_mask_to_int_half(mask); - } - template - static opmask_t fpclass(reg_t x) - { - if constexpr (type == (0x01 | 0x80)) { - return _mm_castps_si128(_mm_cmp_ps(x, x, _CMP_UNORD_Q)); - } - else { - static_assert(type == (0x01 | 0x80), "should not reach here"); - } - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i64gather_ps( - src, (const float *)base, index, _mm_castsi128_ps(mask), scale); - } - static reg_t i64gather(type_t *arr, arrsize_t *ind) - { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); - } - static reg_t loadu(void const *mem) - { - return _mm_loadu_ps((float const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm_max_ps(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu32_half(mem, mask, x); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm_maskload_ps((type_t *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm_blendv_ps(x, y, _mm_castsi128_ps(mask)); - } - static void mask_storeu(void *mem, 
opmask_t mask, reg_t x) - { - return _mm_maskstore_ps((type_t *)mem, mask, x); - } - static reg_t min(reg_t x, reg_t y) - { - return _mm_min_ps(x, y); - } - static reg_t permutexvar(__m128i idx, reg_t ymm) - { - return _mm_permutevar_ps(ymm, idx); - } - static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_permutevar_ps(ymm, idx); - } - static reg_t reverse(reg_t ymm) - { - const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); - return permutexvar(rev_index, ymm); - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32_half(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32_half(v); - } - static reg_t set1(type_t v) - { - return _mm_set1_ps(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); - } - static void storeu(void *mem, reg_t x) - { - _mm_storeu_ps((float *)mem, x); - } - static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit_half>(x); - } - static reg_t cast_from(__m128i v) - { - return _mm_castsi128_ps(v); - } - static __m128i cast_to(reg_t v) - { - return _mm_castps_si128(v); - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32_half( - left_addr, right_addr, k, reg); - } -}; - -struct avx2_32bit_half_swizzle_ops { - template - X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) - { - __m128i v = vtype::cast_to(reg); - - if constexpr (scale == 2) { - __m128 vf = _mm_castsi128_ps(v); - vf = _mm_permute_ps(vf, 0b10110001); - v = _mm_castps_si128(vf); - } - else if constexpr (scale == 4) { - __m128 vf = _mm_castsi128_ps(v); - vf = _mm_permute_ps(vf, 0b01001110); - v = _mm_castps_si128(vf); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v); - } - - template - X86_SIMD_SORT_INLINE typename vtype::reg_t - reverse_n(typename vtype::reg_t reg) - { - __m128i v = vtype::cast_to(reg); - - if constexpr (scale == 2) { return swap_n(reg); } - else if constexpr (scale == 4) { - return vtype::reverse(reg); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v); - } - - template - X86_SIMD_SORT_INLINE typename vtype::reg_t - merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) - { - __m128i v1 = vtype::cast_to(reg); - __m128i v2 = vtype::cast_to(other); - - if constexpr (scale == 2) { v1 = _mm_blend_epi32(v1, v2, 0b0101); } - else if constexpr (scale == 4) { - v1 = _mm_blend_epi32(v1, v2, 0b0011); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v1); - } -}; - -#endif // AVX2_HALF_32BIT diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index cf0fbd55..521597cd 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -70,7 +70,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; - static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -226,7 +225,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; - static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -371,7 +369,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; - 
static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 709d98ef..6ffddbde 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -11,6 +11,15 @@ #include "xss-common-qsort.h" #include "avx2-emu-funcs.hpp" +/* + * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ymm 3, 2, 1, 0 +#define NETWORK_64BIT_R 0, 1, 2, 3 +#define NETWORK_64BIT_1 1, 0, 3, 2 + /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg @@ -52,7 +61,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -77,10 +85,6 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) - { - return _mm256_set_epi64x(v1, v2, v3, v4); - } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _mm256_xor_si256(x, y); @@ -103,12 +107,12 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64( - src, (const long long int *)base, index, mask, scale); + return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); } - static reg_t i64gather(type_t *arr, arrsize_t *ind) + template + static reg_t i64gather(__m256i index, void const *base) { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + return _mm256_i64gather_epi64((int64_t const *)base, index, scale); } static reg_t loadu(void const *mem) { @@ -216,7 +220,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -241,20 +244,17 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) - { - return _mm256_set_epi64x(v1, v2, v3, v4); - } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64( - src, (const long long int *)base, index, mask, scale); + return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); } - static reg_t i64gather(type_t *arr, arrsize_t *ind) + template + static reg_t i64gather(__m256i index, void const *base) { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + return _mm256_i64gather_epi64( + (long long int const *)base, index, scale); } static opmask_t gt(reg_t x, reg_t y) { @@ -378,7 +378,6 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -417,10 +416,7 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) - { - return _mm256_set_pd(v1, v2, v3, v4); - } + static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskload_pd((const double *)mem, mask); @@ -437,16 +433,14 @@ struct avx2_vector { static reg_t 
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_pd(src, - (const type_t *)base, - index, - _mm256_castsi256_pd(mask), - scale); + return _mm256_mask_i64gather_pd( + src, base, index, _mm256_castsi256_pd(mask), scale); ; } - static reg_t i64gather(type_t *arr, arrsize_t *ind) + template + static reg_t i64gather(__m256i index, void const *base) { - return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + return _mm256_i64gather_pd((double *)base, index, scale); } static reg_t loadu(void const *mem) { diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 6e40d2a6..9f6229f7 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -35,21 +35,6 @@ constexpr auto avx2_mask_helper_lut64 = [] { return lut; }(); -constexpr auto avx2_mask_helper_lut32_half = [] { - std::array, 16> lut {}; - for (int64_t i = 0; i <= 0xF; i++) { - std::array entry {}; - for (int j = 0; j < 4; j++) { - if (((i >> j) & 1) == 1) - entry[j] = 0xFFFFFFFF; - else - entry[j] = 0; - } - lut[i] = entry; - } - return lut; -}(); - constexpr auto avx2_compressstore_lut32_gen = [] { std::array, 256>, 2> lutPair {}; auto &permLut = lutPair[0]; @@ -80,38 +65,6 @@ constexpr auto avx2_compressstore_lut32_gen = [] { constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; -constexpr auto avx2_compressstore_lut32_half_gen = [] { - std::array, 16>, 2> lutPair {}; - auto &permLut = lutPair[0]; - auto &leftLut = lutPair[1]; - for (int64_t i = 0; i <= 0xF; i++) { - std::array indices {}; - std::array leftEntry = {0, 0, 0, 0}; - int right = 3; - int left = 0; - for (int j = 0; j < 4; j++) { - bool ge = (i >> j) & 1; - if (ge) { - indices[right] = j; - right--; - } - else { - indices[left] = j; - leftEntry[left] = 0xFFFFFFFF; - left++; - } - } - permLut[i] = indices; - leftLut[i] = leftEntry; - } - return lutPair; -}(); - -constexpr auto avx2_compressstore_lut32_half_perm - = avx2_compressstore_lut32_half_gen[0]; -constexpr auto avx2_compressstore_lut32_half_left - = avx2_compressstore_lut32_half_gen[1]; - constexpr auto avx2_compressstore_lut64_gen = [] { std::array, 16> permLut {}; std::array, 16> leftLut {}; @@ -170,19 +123,6 @@ int32_t convert_avx2_mask_to_int_64bit(__m256i m) return _mm256_movemask_pd(_mm256_castsi256_pd(m)); } -X86_SIMD_SORT_INLINE -__m128i convert_int_to_avx2_mask_half(int32_t m) -{ - return _mm_loadu_si128( - (const __m128i *)avx2_mask_helper_lut32_half[m].data()); -} - -X86_SIMD_SORT_INLINE -int32_t convert_avx2_mask_to_int_half(__m128i m) -{ - return _mm_movemask_ps(_mm_castsi128_ps(m)); -} - // Emulators for intrinsics missing from AVX2 compared to AVX512 template T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) @@ -199,19 +139,6 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) return std::max(arr[0], arr[7]); } -template -T avx2_emu_reduce_max32_half(typename avx2_half_vector::reg_t x) -{ - using vtype = avx2_half_vector; - using reg_t = typename vtype::reg_t; - - reg_t inter1 = vtype::max( - x, vtype::template shuffle(x)); - T arr[vtype::numlanes]; - vtype::storeu(arr, inter1); - return std::max(arr[0], arr[3]); -} - template T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { @@ -227,19 +154,6 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) return std::min(arr[0], arr[7]); } -template -T avx2_emu_reduce_min32_half(typename avx2_half_vector::reg_t x) -{ - using vtype = avx2_half_vector; - 
using reg_t = typename vtype::reg_t; - - reg_t inter1 = vtype::min( - x, vtype::template shuffle(x)); - T arr[vtype::numlanes]; - vtype::storeu(arr, inter1); - return std::min(arr[0], arr[3]); -} - template T avx2_emu_reduce_max64(typename avx2_vector::reg_t x) { @@ -282,29 +196,6 @@ void avx2_emu_mask_compressstoreu32(void *base_addr, vtype::mask_storeu(leftStore, left, temp); } -template -void avx2_emu_mask_compressstoreu32_half( - void *base_addr, - typename avx2_half_vector::opmask_t k, - typename avx2_half_vector::reg_t reg) -{ - using vtype = avx2_half_vector; - - T *leftStore = (T *)base_addr; - - int32_t shortMask = convert_avx2_mask_to_int_half(k); - const __m128i &perm = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] - .data()); - const __m128i &left = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_left[shortMask] - .data()); - - typename vtype::reg_t temp = vtype::permutevar(reg, perm); - - vtype::mask_storeu(leftStore, left, temp); -} - template void avx2_emu_mask_compressstoreu64(void *base_addr, typename avx2_vector::opmask_t k, @@ -349,30 +240,6 @@ int avx2_double_compressstore32(void *left_addr, return _mm_popcnt_u32(shortMask); } -template -int avx2_double_compressstore32_half(void *left_addr, - void *right_addr, - typename avx2_half_vector::opmask_t k, - typename avx2_half_vector::reg_t reg) -{ - using vtype = avx2_half_vector; - - T *leftStore = (T *)left_addr; - T *rightStore = (T *)right_addr; - - int32_t shortMask = convert_avx2_mask_to_int_half(k); - const __m128i &perm = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] - .data()); - - typename vtype::reg_t temp = vtype::permutevar(reg, perm); - - vtype::storeu(leftStore, temp); - vtype::storeu(rightStore, temp); - - return _mm_popcnt_u32(shortMask); -} - template int32_t avx2_double_compressstore64(void *left_addr, void *right_addr, diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 32d7419c..be806f5f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -26,7 +26,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -209,7 +208,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -345,7 +343,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index f06cfff0..2d101b88 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -41,7 +41,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -181,7 +180,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -321,7 +319,6 @@ struct 
zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 3a475da8..c4084c68 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,7 +7,713 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" -#include "xss-common-argsort.h" +#include "xss-network-keyvaluesort.hpp" +#include + +template +X86_SIMD_SORT_INLINE void std_argselect_withnan( + T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](arrsize_t a, arrsize_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ +using argtype = typename std::conditional, + zmm_vector>::type; +#else +using argtype = typename std::conditional, + zmm_vector>::type; +#endif +using argreg_t = typename argtype::reg_t; + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + argtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argreg_t argvec = argtype::loadu(arg + left); + reg_t vec = vtype::i64gather(arr, arg + left); + int32_t amount_gt_pivot = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argreg_t argvec_left = argtype::loadu(arg + left); + reg_t vec_left = vtype::i64gather(arr, arg + left); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec; + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::i64gather(arr, arg + right); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::i64gather(arr, arg + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, arg, left, right, pivot, smallest, biggest); + } + /* make array length 
divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template 
+X86_SIMD_SORT_INLINE void +argsort_8_64bit(type_t *arr, arrsize_t *arg, int32_t N) +{ + using reg_t = typename vtype::reg_t; + typename vtype::opmask_t load_mask = (0x01 << N) - 0x01; + argreg_t argzmm = argtype::maskz_loadu(load_mask, arg); + reg_t arrzmm = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask, argzmm, arr); + arrzmm = sort_zmm_64bit(arrzmm, argzmm); + argtype::mask_storeu(arg, load_mask, argzmm); +} + +template +X86_SIMD_SORT_INLINE void +argsort_16_64bit(type_t *arr, arrsize_t *arg, int32_t N) +{ + if (N <= 8) { + argsort_8_64bit(arr, arg, N); + return; + } + using reg_t = typename vtype::reg_t; + typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + argreg_t argzmm1 = argtype::loadu(arg); + argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8); + reg_t arrzmm1 = vtype::i64gather(arr, arg); + reg_t arrzmm2 = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask, argzmm2, arr); + arrzmm1 = sort_zmm_64bit(arrzmm1, argzmm1); + arrzmm2 = sort_zmm_64bit(arrzmm2, argzmm2); + bitonic_merge_two_zmm_64bit( + arrzmm1, arrzmm2, argzmm1, argzmm2); + argtype::storeu(arg, argzmm1); + argtype::mask_storeu(arg + 8, load_mask, argzmm2); +} + +template +X86_SIMD_SORT_INLINE void +argsort_32_64bit(type_t *arr, arrsize_t *arg, int32_t N) +{ + if (N <= 16) { + argsort_16_64bit(arr, arg, N); + return; + } + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + reg_t arrzmm[4]; + argreg_t argzmm[4]; + + X86_SIMD_SORT_UNROLL_LOOP(2) + for (int ii = 0; ii < 2; ++ii) { + argzmm[ii] = argtype::loadu(arg + 8 * ii); + arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); + arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); + } + + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + opmask_t load_mask[2] = {0xFF, 0xFF}; + X86_SIMD_SORT_UNROLL_LOOP(2) + for (int ii = 0; ii < 2; ++ii) { + load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; + argzmm[ii + 2] = argtype::maskz_loadu(load_mask[ii], arg + 16 + 8 * ii); + arrzmm[ii + 2] = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask[ii], argzmm[ii + 2], arr); + arrzmm[ii + 2] = sort_zmm_64bit(arrzmm[ii + 2], + argzmm[ii + 2]); + } + + bitonic_merge_two_zmm_64bit( + arrzmm[0], arrzmm[1], argzmm[0], argzmm[1]); + bitonic_merge_two_zmm_64bit( + arrzmm[2], arrzmm[3], argzmm[2], argzmm[3]); + bitonic_merge_four_zmm_64bit(arrzmm, argzmm); + + argtype::storeu(arg, argzmm[0]); + argtype::storeu(arg + 8, argzmm[1]); + argtype::mask_storeu(arg + 16, load_mask[0], argzmm[2]); + argtype::mask_storeu(arg + 24, load_mask[1], argzmm[3]); +} + +template +X86_SIMD_SORT_INLINE void +argsort_64_64bit(type_t *arr, arrsize_t *arg, int32_t N) +{ + if (N <= 32) { + argsort_32_64bit(arr, arg, N); + return; + } + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + reg_t arrzmm[8]; + argreg_t argzmm[8]; + + X86_SIMD_SORT_UNROLL_LOOP(4) + for (int ii = 0; ii < 4; ++ii) { + argzmm[ii] = argtype::loadu(arg + 8 * ii); + arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii); + arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); + } + + opmask_t load_mask[4] = {0xFF, 0xFF, 0xFF, 0xFF}; + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + X86_SIMD_SORT_UNROLL_LOOP(4) + for (int ii = 0; ii < 4; ++ii) { + load_mask[ii] = (combined_mask >> (ii * 8)) & 0xFF; + argzmm[ii + 4] = argtype::maskz_loadu(load_mask[ii], arg + 32 + 8 * ii); + arrzmm[ii + 4] = vtype::template mask_i64gather( + vtype::zmm_max(), load_mask[ii], argzmm[ii + 4], arr); + arrzmm[ii + 4] = 
sort_zmm_64bit(arrzmm[ii + 4], + argzmm[ii + 4]); + } + + X86_SIMD_SORT_UNROLL_LOOP(4) + for (int ii = 0; ii < 8; ii = ii + 2) { + bitonic_merge_two_zmm_64bit( + arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); + } + bitonic_merge_four_zmm_64bit(arrzmm, argzmm); + bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); + bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); + + X86_SIMD_SORT_UNROLL_LOOP(4) + for (int ii = 0; ii < 4; ++ii) { + argtype::storeu(arg + 8 * ii, argzmm[ii]); + } + X86_SIMD_SORT_UNROLL_LOOP(4) + for (int ii = 0; ii < 4; ++ii) { + argtype::mask_storeu(arg + 32 + 8 * ii, load_mask[ii], argzmm[ii + 4]); + } +} + +/* arsort 128 doesn't seem to make much of a difference to perf*/ +//template +//X86_SIMD_SORT_INLINE void +//argsort_128_64bit(type_t *arr, arrsize_t *arg, int32_t N) +//{ +// if (N <= 64) { +// argsort_64_64bit(arr, arg, N); +// return; +// } +// using reg_t = typename vtype::reg_t; +// using opmask_t = typename vtype::opmask_t; +// reg_t arrzmm[16]; +// argreg_t argzmm[16]; +// +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 8; ++ii) { +// argzmm[ii] = argtype::loadu(arg + 8*ii); +// arrzmm[ii] = vtype::i64gather(argzmm[ii], arr); +// arrzmm[ii] = sort_zmm_64bit(arrzmm[ii], argzmm[ii]); +// } +// +// opmask_t load_mask[8] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +// if (N != 128) { +// uarrsize_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 8; ++ii) { +// load_mask[ii] = (combined_mask >> (ii*8)) & 0xFF; +// } +// } +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 8; ++ii) { +// argzmm[ii+8] = argtype::maskz_loadu(load_mask[ii], arg + 64 + 8*ii); +// arrzmm[ii+8] = vtype::template mask_i64gather(vtype::zmm_max(), load_mask[ii], argzmm[ii+8], arr); +// arrzmm[ii+8] = sort_zmm_64bit(arrzmm[ii+8], argzmm[ii+8]); +// } +// +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 16; ii = ii + 2) { +// bitonic_merge_two_zmm_64bit(arrzmm[ii], arrzmm[ii + 1], argzmm[ii], argzmm[ii + 1]); +// } +// bitonic_merge_four_zmm_64bit(arrzmm, argzmm); +// bitonic_merge_four_zmm_64bit(arrzmm + 4, argzmm + 4); +// bitonic_merge_four_zmm_64bit(arrzmm + 8, argzmm + 8); +// bitonic_merge_four_zmm_64bit(arrzmm + 12, argzmm + 12); +// bitonic_merge_eight_zmm_64bit(arrzmm, argzmm); +// bitonic_merge_eight_zmm_64bit(arrzmm+8, argzmm+8); +// bitonic_merge_sixteen_zmm_64bit(arrzmm, argzmm); +// +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 8; ++ii) { +// argtype::storeu(arg + 8*ii, argzmm[ii]); +// } +//X86_SIMD_SORT_UNROLL_LOOP(8) +// for (int ii = 0; ii < 8; ++ii) { +// argtype::mask_storeu(arg + 64 + 8*ii, load_mask[ii], argzmm[ii + 8]); +// } +//} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + arrsize_t *arg, + const arrsize_t left, + const arrsize_t right) +{ + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! 
+ reg_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } +} + +template +X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 64) { + argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) + argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); +} + +template +X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 64) { + argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_64bit_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_64bit_( + arr, arg, pos, pivot_index, right, max_iters - 1); +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using vectype = typename std::conditional, + zmm_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + zmm_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + 
std::iota(indices.begin(), indices.end(), 0); + avx512_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + +/* To maintain compatibility with NumPy build */ +template +X86_SIMD_SORT_INLINE void +avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) +{ + avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); +} + +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) +{ + avx512_argsort(arr, reinterpret_cast(arg), arrsize); +} #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 65ee85db..909f3b2b 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -8,7 +8,6 @@ #define AVX512_64BIT_COMMON #include "xss-common-includes.h" -#include "avx2-32bit-qsort.hpp" /* * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic @@ -33,7 +32,6 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -87,10 +85,6 @@ struct ymm_vector { { return ((0x1ull << num_to_read) - 0x1ull); } - static int32_t convert_mask_to_int(opmask_t mask) - { - return mask; - } template static opmask_t fpclass(reg_t x) { @@ -200,19 +194,6 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } - static reg_t cast_from(__m256i v) - { - return _mm256_castsi256_ps(v); - } - static __m256i cast_to(reg_t v) - { - return _mm256_castps_si256(v); - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } }; template <> struct ymm_vector { @@ -221,7 +202,6 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -374,19 +354,6 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t cast_from(__m256i v) - { - return v; - } - static __m256i cast_to(reg_t v) - { - return v; - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } }; template <> struct ymm_vector { @@ -395,7 +362,6 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; - static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -548,19 +514,6 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t cast_from(__m256i v) - { - return v; - } - static __m256i cast_to(reg_t v) - { - return v; - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } }; template <> struct zmm_vector { @@ -576,7 +529,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -755,7 +707,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -926,7 +877,6 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor 
= 8; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index 1f446c68..55f79bb1 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -182,6 +182,356 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t1 *keys, return l_store; } +template +X86_SIMD_SORT_INLINE void +sort_8_64bit(type1_t *keys, type2_t *indexes, int32_t N) +{ + typename vtype1::opmask_t load_mask = (0x01 << N) - 0x01; + typename vtype1::reg_t key_zmm + = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys); + + typename vtype2::reg_t index_zmm + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes); + vtype1::mask_storeu(keys, + load_mask, + sort_zmm_64bit(key_zmm, index_zmm)); + vtype2::mask_storeu(indexes, load_mask, index_zmm); +} + +template +X86_SIMD_SORT_INLINE void +sort_16_64bit(type1_t *keys, type2_t *indexes, int32_t N) +{ + if (N <= 8) { + sort_8_64bit(keys, indexes, N); + return; + } + using reg_t = typename vtype1::reg_t; + using index_type = typename vtype2::reg_t; + + typename vtype1::opmask_t load_mask = (0x01 << (N - 8)) - 0x01; + + reg_t key_zmm1 = vtype1::loadu(keys); + reg_t key_zmm2 = vtype1::mask_loadu(vtype1::zmm_max(), load_mask, keys + 8); + + index_type index_zmm1 = vtype2::loadu(indexes); + index_type index_zmm2 + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask, indexes + 8); + + key_zmm1 = sort_zmm_64bit(key_zmm1, index_zmm1); + key_zmm2 = sort_zmm_64bit(key_zmm2, index_zmm2); + bitonic_merge_two_zmm_64bit( + key_zmm1, key_zmm2, index_zmm1, index_zmm2); + + vtype2::storeu(indexes, index_zmm1); + vtype2::mask_storeu(indexes + 8, load_mask, index_zmm2); + + vtype1::storeu(keys, key_zmm1); + vtype1::mask_storeu(keys + 8, load_mask, key_zmm2); +} + +template +X86_SIMD_SORT_INLINE void +sort_32_64bit(type1_t *keys, type2_t *indexes, int32_t N) +{ + if (N <= 16) { + sort_16_64bit(keys, indexes, N); + return; + } + using reg_t = typename vtype1::reg_t; + using opmask_t = typename vtype2::opmask_t; + using index_type = typename vtype2::reg_t; + reg_t key_zmm[4]; + index_type index_zmm[4]; + + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + uint64_t combined_mask = (0x1ull << (N - 16)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + key_zmm[2] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 16); + key_zmm[3] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 24); + + index_zmm[2] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 16); + index_zmm[3] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 24); + + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::mask_storeu(indexes + 16, load_mask1, index_zmm[2]); + vtype2::mask_storeu(indexes + 24, load_mask2, index_zmm[3]); + + 
vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::mask_storeu(keys + 16, load_mask1, key_zmm[2]); + vtype1::mask_storeu(keys + 24, load_mask2, key_zmm[3]); +} + +template +X86_SIMD_SORT_INLINE void +sort_64_64bit(type1_t *keys, type2_t *indexes, int32_t N) +{ + if (N <= 32) { + sort_32_64bit(keys, indexes, N); + return; + } + using reg_t = typename vtype1::reg_t; + using opmask_t = typename vtype1::opmask_t; + using index_type = typename vtype2::reg_t; + reg_t key_zmm[8]; + index_type index_zmm[8]; + + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] = vtype1::loadu(keys + 24); + + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + // N-32 >= 1 + uint64_t combined_mask = (0x1ull << (N - 32)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + key_zmm[4] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 32); + key_zmm[5] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 40); + key_zmm[6] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 48); + key_zmm[7] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 56); + + index_zmm[4] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 32); + index_zmm[5] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 40); + index_zmm[6] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 48); + index_zmm[7] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 56); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::mask_storeu(indexes + 32, load_mask1, index_zmm[4]); + vtype2::mask_storeu(indexes + 40, load_mask2, index_zmm[5]); + vtype2::mask_storeu(indexes + 48, load_mask3, index_zmm[6]); + vtype2::mask_storeu(indexes + 56, load_mask4, index_zmm[7]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::mask_storeu(keys + 32, load_mask1, key_zmm[4]); + vtype1::mask_storeu(keys + 40, load_mask2, key_zmm[5]); + 
vtype1::mask_storeu(keys + 48, load_mask3, key_zmm[6]); + vtype1::mask_storeu(keys + 56, load_mask4, key_zmm[7]); +} + +template +X86_SIMD_SORT_INLINE void +sort_128_64bit(type1_t *keys, type2_t *indexes, int32_t N) +{ + if (N <= 64) { + sort_64_64bit(keys, indexes, N); + return; + } + using reg_t = typename vtype1::reg_t; + using index_type = typename vtype2::reg_t; + using opmask_t = typename vtype1::opmask_t; + reg_t key_zmm[16]; + index_type index_zmm[16]; + + key_zmm[0] = vtype1::loadu(keys); + key_zmm[1] = vtype1::loadu(keys + 8); + key_zmm[2] = vtype1::loadu(keys + 16); + key_zmm[3] = vtype1::loadu(keys + 24); + key_zmm[4] = vtype1::loadu(keys + 32); + key_zmm[5] = vtype1::loadu(keys + 40); + key_zmm[6] = vtype1::loadu(keys + 48); + key_zmm[7] = vtype1::loadu(keys + 56); + + index_zmm[0] = vtype2::loadu(indexes); + index_zmm[1] = vtype2::loadu(indexes + 8); + index_zmm[2] = vtype2::loadu(indexes + 16); + index_zmm[3] = vtype2::loadu(indexes + 24); + index_zmm[4] = vtype2::loadu(indexes + 32); + index_zmm[5] = vtype2::loadu(indexes + 40); + index_zmm[6] = vtype2::loadu(indexes + 48); + index_zmm[7] = vtype2::loadu(indexes + 56); + key_zmm[0] = sort_zmm_64bit(key_zmm[0], index_zmm[0]); + key_zmm[1] = sort_zmm_64bit(key_zmm[1], index_zmm[1]); + key_zmm[2] = sort_zmm_64bit(key_zmm[2], index_zmm[2]); + key_zmm[3] = sort_zmm_64bit(key_zmm[3], index_zmm[3]); + key_zmm[4] = sort_zmm_64bit(key_zmm[4], index_zmm[4]); + key_zmm[5] = sort_zmm_64bit(key_zmm[5], index_zmm[5]); + key_zmm[6] = sort_zmm_64bit(key_zmm[6], index_zmm[6]); + key_zmm[7] = sort_zmm_64bit(key_zmm[7], index_zmm[7]); + + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + if (N != 128) { + uint64_t combined_mask = (0x1ull << (N - 64)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + } + key_zmm[8] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask1, keys + 64); + key_zmm[9] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask2, keys + 72); + key_zmm[10] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask3, keys + 80); + key_zmm[11] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask4, keys + 88); + key_zmm[12] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask5, keys + 96); + key_zmm[13] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask6, keys + 104); + key_zmm[14] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask7, keys + 112); + key_zmm[15] = vtype1::mask_loadu(vtype1::zmm_max(), load_mask8, keys + 120); + + index_zmm[8] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask1, indexes + 64); + index_zmm[9] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask2, indexes + 72); + index_zmm[10] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask3, indexes + 80); + index_zmm[11] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask4, indexes + 88); + index_zmm[12] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask5, indexes + 96); + index_zmm[13] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask6, indexes + 104); + index_zmm[14] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask7, indexes + 112); + index_zmm[15] + = vtype2::mask_loadu(vtype2::zmm_max(), load_mask8, indexes + 120); + 
key_zmm[8] = sort_zmm_64bit(key_zmm[8], index_zmm[8]); + key_zmm[9] = sort_zmm_64bit(key_zmm[9], index_zmm[9]); + key_zmm[10] = sort_zmm_64bit(key_zmm[10], index_zmm[10]); + key_zmm[11] = sort_zmm_64bit(key_zmm[11], index_zmm[11]); + key_zmm[12] = sort_zmm_64bit(key_zmm[12], index_zmm[12]); + key_zmm[13] = sort_zmm_64bit(key_zmm[13], index_zmm[13]); + key_zmm[14] = sort_zmm_64bit(key_zmm[14], index_zmm[14]); + key_zmm[15] = sort_zmm_64bit(key_zmm[15], index_zmm[15]); + + bitonic_merge_two_zmm_64bit( + key_zmm[0], key_zmm[1], index_zmm[0], index_zmm[1]); + bitonic_merge_two_zmm_64bit( + key_zmm[2], key_zmm[3], index_zmm[2], index_zmm[3]); + bitonic_merge_two_zmm_64bit( + key_zmm[4], key_zmm[5], index_zmm[4], index_zmm[5]); + bitonic_merge_two_zmm_64bit( + key_zmm[6], key_zmm[7], index_zmm[6], index_zmm[7]); + bitonic_merge_two_zmm_64bit( + key_zmm[8], key_zmm[9], index_zmm[8], index_zmm[9]); + bitonic_merge_two_zmm_64bit( + key_zmm[10], key_zmm[11], index_zmm[10], index_zmm[11]); + bitonic_merge_two_zmm_64bit( + key_zmm[12], key_zmm[13], index_zmm[12], index_zmm[13]); + bitonic_merge_two_zmm_64bit( + key_zmm[14], key_zmm[15], index_zmm[14], index_zmm[15]); + bitonic_merge_four_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_four_zmm_64bit(key_zmm + 4, index_zmm + 4); + bitonic_merge_four_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_four_zmm_64bit(key_zmm + 12, index_zmm + 12); + bitonic_merge_eight_zmm_64bit(key_zmm, index_zmm); + bitonic_merge_eight_zmm_64bit(key_zmm + 8, index_zmm + 8); + bitonic_merge_sixteen_zmm_64bit(key_zmm, index_zmm); + vtype2::storeu(indexes, index_zmm[0]); + vtype2::storeu(indexes + 8, index_zmm[1]); + vtype2::storeu(indexes + 16, index_zmm[2]); + vtype2::storeu(indexes + 24, index_zmm[3]); + vtype2::storeu(indexes + 32, index_zmm[4]); + vtype2::storeu(indexes + 40, index_zmm[5]); + vtype2::storeu(indexes + 48, index_zmm[6]); + vtype2::storeu(indexes + 56, index_zmm[7]); + vtype2::mask_storeu(indexes + 64, load_mask1, index_zmm[8]); + vtype2::mask_storeu(indexes + 72, load_mask2, index_zmm[9]); + vtype2::mask_storeu(indexes + 80, load_mask3, index_zmm[10]); + vtype2::mask_storeu(indexes + 88, load_mask4, index_zmm[11]); + vtype2::mask_storeu(indexes + 96, load_mask5, index_zmm[12]); + vtype2::mask_storeu(indexes + 104, load_mask6, index_zmm[13]); + vtype2::mask_storeu(indexes + 112, load_mask7, index_zmm[14]); + vtype2::mask_storeu(indexes + 120, load_mask8, index_zmm[15]); + + vtype1::storeu(keys, key_zmm[0]); + vtype1::storeu(keys + 8, key_zmm[1]); + vtype1::storeu(keys + 16, key_zmm[2]); + vtype1::storeu(keys + 24, key_zmm[3]); + vtype1::storeu(keys + 32, key_zmm[4]); + vtype1::storeu(keys + 40, key_zmm[5]); + vtype1::storeu(keys + 48, key_zmm[6]); + vtype1::storeu(keys + 56, key_zmm[7]); + vtype1::mask_storeu(keys + 64, load_mask1, key_zmm[8]); + vtype1::mask_storeu(keys + 72, load_mask2, key_zmm[9]); + vtype1::mask_storeu(keys + 80, load_mask3, key_zmm[10]); + vtype1::mask_storeu(keys + 88, load_mask4, key_zmm[11]); + vtype1::mask_storeu(keys + 96, load_mask5, key_zmm[12]); + vtype1::mask_storeu(keys + 104, load_mask6, key_zmm[13]); + vtype1::mask_storeu(keys + 112, load_mask7, key_zmm[14]); + vtype1::mask_storeu(keys + 120, load_mask8, key_zmm[15]); +} + template ( + sort_128_64bit( keys + left, indexes + left, (int32_t)(right + 1 - left)); return; } diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index f44209fa..21958027 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -23,7 +23,6 @@ 
struct zmm_vector<_Float16> { static const uint8_t numlanes = 32; static constexpr int network_sort_threshold = 128; static constexpr int partition_unroll_factor = 0; - static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h deleted file mode 100644 index 67aa2002..00000000 --- a/src/xss-common-argsort.h +++ /dev/null @@ -1,705 +0,0 @@ -/******************************************************************* - * Copyright (C) 2022 Intel Corporation - * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli - * ****************************************************************/ - -#ifndef XSS_COMMON_ARGSORT -#define XSS_COMMON_ARGSORT - -#include "xss-common-qsort.h" -#include "xss-network-keyvaluesort.hpp" -#include - -template -X86_SIMD_SORT_INLINE void std_argselect_withnan( - T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) -{ - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - // sort indices according to corresponding array element - return arr[left] < arr[right]; - }); -} - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ -using argtypeAVX512 = - typename std::conditional, - zmm_vector>::type; -#else -using argtypeAVX512 = - typename std::conditional, - zmm_vector>::type; -#endif - -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - argtype::mask_compressstoreu( - arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_gt_pivot; -} - -/* - * Parition one AVX2 register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); - typename argtype::opmask_t ge_mask - = extend_mask(ge_mask_vtype); - - auto l_store = arg + left; - auto r_store = arg + right - vtype::numlanes; - - int amount_ge_pivot - = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); - - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - - return amount_ge_pivot; -} - -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - if constexpr (vtype::vec_type == simd_type::AVX512) { - return partition_vec_avx512(arg, - left, - right, - arg_vec, - curr_vec, - pivot_vec, - smallest_vec, - biggest_vec); - } - else if constexpr (vtype::vec_type == simd_type::AVX2) { - return partition_vec_avx2(arg, - left, - right, - arg_vec, - curr_vec, - pivot_vec, - smallest_vec, - biggest_vec); - } - else { - static_assert(sizeof(argreg_t) == 0, "Should not get here"); - } -} - -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - using argreg_t = typename argtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - argreg_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot - = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - argreg_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left = vtype::i64gather(arr, arg + left); - argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec; - reg_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded 
from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::i64gather(arr, arg + right); - } - else { - arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::i64gather(arr, arg + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( - arr, arg, left, right, pivot, smallest, biggest); - } - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - using argreg_t = typename argtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - // first and last vtype::numlanes values are partitioned at the end - reg_t vec_left[num_unroll], vec_right[num_unroll]; - argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); - argvec_right[ii] = argtype::loadu( - arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::i64gather( - arr, arg + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec[num_unroll]; - reg_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - 
arg_vec[ii] - = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + right + ii * vtype::numlanes); - } - } - else { - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - } - - /* partition and save vec_left and vec_right */ - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - arrsize_t *arg, - const arrsize_t left, - const arrsize_t right) -{ - if constexpr (vtype::numlanes == 8) { - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; - } - } - else if constexpr (vtype::numlanes == 4) { - if (right - left >= vtype::numlanes) { - // median of 4 - arrsize_t size = (right - left) / 4; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]]); - // pivot will never be a nan, since there are no nan's! 
- reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[2]; - } - else { - return arr[arg[right]]; - } - } -} - -template -X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - using argtype = typename std::conditional, - zmm_vector>::type; - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); -} - -template -X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - using argtype = typename std::conditional, - zmm_vector>::type; - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( - arr, arg, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( - arr, arg, pos, pivot_index, right, max_iters - 1); -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - using vectype = typename std::conditional, - zmm_vector>::type; - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - using vectype = typename std::conditional, - avx2_vector>::type; - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx2_argsort(T *arr, arrsize_t arrsize, bool 
hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx2_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - using vectype = typename std::conditional, - zmm_vector>::type; - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - using vectype = typename std::conditional, - avx2_vector>::type; - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx2_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* To maintain compatibility with NumPy build */ -template -X86_SIMD_SORT_INLINE void -avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) -{ - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); -} - -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) -{ - avx512_argsort(arr, reinterpret_cast(arg), arrsize); -} - -#endif // XSS_COMMON_ARGSORT diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 9f793e37..c373ba54 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -82,9 +82,4 @@ struct ymm_vector; template struct avx2_vector; -template -struct avx2_half_vector; - -enum class simd_type : int { AVX2, AVX512 }; - #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 097efceb..e76d9f6a 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -87,8 +87,7 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) else { in = vtype::loadu(arr + ii); } - auto nanmask = vtype::convert_mask_to_int( - vtype::template fpclass<0x01 | 0x80>(in)); + opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); if (nanmask != 0x00) { found_nan = true; break; diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 334cb560..cec1cb7a 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,34 +1,5 @@ -#ifndef XSS_KEYVALUE_NETWORKS -#define XSS_KEYVALUE_NETWORKS - -#include "xss-common-includes.h" - -template -struct index_64bit_vector_type; -template <> -struct index_64bit_vector_type<8> { - using type = zmm_vector; -}; -template <> -struct 
index_64bit_vector_type<4> { - using type = avx2_vector; -}; - -template -typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) -{ - if constexpr (keyType::vec_type == simd_type::AVX512) { return mask; } - else if constexpr (keyType::vec_type == simd_type::AVX2) { - if constexpr (sizeof(mask) == 32) { return mask; } - else { - return _mm256_cvtepi32_epi64(mask); - } - } - else { - static_assert(keyType::vec_type == simd_type::AVX512, - "Should not reach here"); - } -} +#ifndef AVX512_KEYVALUE_NETWORKS +#define AVX512_KEYVALUE_NETWORKS template (vtype1::eq(key_t1, key1)); - - reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1); - reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2); + reg_t2 index_t1 + = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1); + reg_t2 index_t2 + = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2); key1 = key_t1; key2 = key_t2; @@ -63,24 +34,10 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, opmask_t mask) { reg_t1 tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = vtype2::mask_mov( - indexes2, - extend_mask(vtype1::eq(tmp_keys, in1)), - indexes1); + indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); return tmp_keys; // 0 -> min, 1 -> max } -/* - * Constants used in sorting 8 elements in a ZMM registers. Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - template -X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm) -{ - using key_swizzle = typename vtype1::swizzle_ops; - using index_swizzle = typename vtype2::swizzle_ops; - - const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); - const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); - - key_zmm = cmp_merge( - key_zmm, - key_swizzle::template swap_n(key_zmm), - index_zmm, - index_swizzle::template swap_n(index_zmm), - oxAA); - key_zmm = cmp_merge(key_zmm, - vtype1::reverse(key_zmm), - index_zmm, - vtype2::reverse(index_zmm), - oxCC); - key_zmm = cmp_merge( - key_zmm, - key_swizzle::template swap_n(key_zmm), - index_zmm, - index_swizzle::template swap_n(index_zmm), - oxAA); - return key_zmm; -} - // Assumes zmm is bitonic and performs a recursive half cleaner template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm, - index_type &index_zmm) -{ - using key_swizzle = typename vtype1::swizzle_ops; - using index_swizzle = typename vtype2::swizzle_ops; - - const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); - const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); - - // 2) half_cleaner[4] - key_zmm = cmp_merge( - key_zmm, - key_swizzle::template swap_n(key_zmm), - index_zmm, - index_swizzle::template swap_n(index_zmm), - oxCC); - // 3) half_cleaner[1] - key_zmm = cmp_merge( - key_zmm, - key_swizzle::template swap_n(key_zmm), - index_zmm, - index_swizzle::template swap_n(index_zmm), - oxAA); - return key_zmm; -} - -template -X86_SIMD_SORT_INLINE void -bitonic_merge_dispatch(typename keyType::reg_t &key, - typename valueType::reg_t &value) +X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_64bit(reg_t &key_zmm1, + reg_t &key_zmm2, + index_type &index_zmm1, + index_type &index_zmm2) { - constexpr int numlanes = keyType::numlanes; - if constexpr (numlanes 
== 8) { - key = bitonic_merge_zmm_64bit(key, value); - } - else if constexpr (numlanes == 4) { - key = bitonic_merge_ymm_64bit(key, value); - } - else { - static_assert(numlanes == -1, "should not reach here"); - UNUSED(key); - UNUSED(value); - } -} + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + // 1) First step of a merging network: coex of zmm1 and zmm2 reversed + key_zmm2 = vtype1::permutexvar(rev_index1, key_zmm2); + index_zmm2 = vtype2::permutexvar(rev_index2, index_zmm2); -template -X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, - typename valueType::reg_t &value) -{ - constexpr int numlanes = keyType::numlanes; - if constexpr (numlanes == 8) { - key = sort_zmm_64bit(key, value); - } - else if constexpr (numlanes == 4) { - key = sort_ymm_64bit(key, value); - } - else { - static_assert(numlanes == -1, "should not reach here"); - UNUSED(key); - UNUSED(value); - } -} + reg_t key_zmm3 = vtype1::min(key_zmm1, key_zmm2); + reg_t key_zmm4 = vtype1::max(key_zmm1, key_zmm2); -template -X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(typename keyType::reg_t *keys, - typename valueType::reg_t *values) -{ - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int num = numVecs / 2; num >= 2; num /= 2) { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int j = 0; j < numVecs; j += num) { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < num / 2; i++) { - arrsize_t index1 = i + j; - arrsize_t index2 = i + j + num / 2; - COEX(keys[index1], - keys[index2], - values[index1], - values[index2]); - } - } - } -} + typename vtype1::opmask_t movmask = vtype1::eq(key_zmm3, key_zmm1); -template -X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(typename keyType::reg_t *keys, - typename valueType::reg_t *values) -{ - // Do the reverse part - if constexpr (numVecs == 2) { - keys[1] = keyType::reverse(keys[1]); - values[1] = valueType::reverse(values[1]); - COEX(keys[0], keys[1], values[0], values[1]); - keys[1] = keyType::reverse(keys[1]); - values[1] = valueType::reverse(values[1]); - } - else if constexpr (numVecs > 2) { - // Reverse upper half - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); - values[numVecs - i - 1] - = valueType::reverse(values[numVecs - i - 1]); - - COEX(keys[i], - keys[numVecs - i - 1], - values[i], - values[numVecs - i - 1]); - - keys[numVecs - i - 1] = keyType::reverse(keys[numVecs - i - 1]); - values[numVecs - i - 1] - = valueType::reverse(values[numVecs - i - 1]); - } - } - - // Call cleaner - bitonic_clean_n_vec(keys, values); - - // Now do bitonic_merge - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - bitonic_merge_dispatch(keys[i], values[i]); - } -} + index_type index_zmm3 = vtype2::mask_mov(index_zmm2, movmask, index_zmm1); + index_type index_zmm4 = vtype2::mask_mov(index_zmm1, movmask, index_zmm2); -template -X86_SIMD_SORT_INLINE void -bitonic_fullmerge_n_vec(typename keyType::reg_t *keys, - typename valueType::reg_t *values) -{ - if constexpr (numPer > numVecs) { - UNUSED(keys); - UNUSED(values); - return; - } - else { - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / numPer; i++) { - bitonic_merge_n_vec( - keys + i * numPer, values + i * numPer); - } - bitonic_fullmerge_n_vec( - keys, values); - } -} + /* need to reverse the lower registers to keep the correct order */ + key_zmm4 = vtype1::permutexvar(rev_index1, key_zmm4); + index_zmm4 = 
vtype2::permutexvar(rev_index2, index_zmm4); -template -X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, - typename indexType::type_t *indices, - int N) -{ - using kreg_t = typename keyType::reg_t; - using ireg_t = typename indexType::reg_t; - - static_assert(numVecs > 0, "numVecs should be > 0"); - if constexpr (numVecs > 1) { - if (N * 2 <= numVecs * keyType::numlanes) { - argsort_n_vec(keys, indices, N); - return; - } - } - - kreg_t keyVecs[numVecs]; - ireg_t indexVecs[numVecs]; - - // Generate masks for loading and storing - typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), - (uint64_t)keyType::numlanes); - ioMasks[j] = keyType::get_partial_loadmask(num_to_read); - } - - // Unmasked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - indexVecs[i] = indexType::loadu(indices + i * indexType::numlanes); - keyVecs[i] - = keyType::i64gather(keys, indices + i * indexType::numlanes); - } - // Masked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2; i < numVecs; i++) { - indexVecs[i] = indexType::mask_loadu( - indexType::zmm_max(), - extend_mask(ioMasks[i - numVecs / 2]), - indices + i * indexType::numlanes); - - keyVecs[i] = keyType::template mask_i64gather(keyType::zmm_max(), - ioMasks[i - numVecs / 2], - indexVecs[i], - keys); - } - - // Sort each loaded vector - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - sort_vec_dispatch(keyVecs[i], indexVecs[i]); - } - - // Run the full merger - bitonic_fullmerge_n_vec(keyVecs, indexVecs); - - // Unmasked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - indexType::storeu(indices + i * indexType::numlanes, indexVecs[i]); - } - // Masked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - indexType::mask_storeu( - indices + i * indexType::numlanes, - extend_mask(ioMasks[i - numVecs / 2]), - indexVecs[i]); - } + // 2) Recursive half cleaner for each + key_zmm1 = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + key_zmm2 = bitonic_merge_zmm_64bit(key_zmm4, index_zmm4); + index_zmm1 = index_zmm3; + index_zmm2 = index_zmm4; } - -template -X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys, - typename valueType::type_t *values, - int N) +// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive +// half cleaner +template +X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(reg_t *key_zmm, + index_type *index_zmm) { - using kreg_t = typename keyType::reg_t; - using vreg_t = typename valueType::reg_t; - - static_assert(numVecs > 0, "numVecs should be > 0"); - if constexpr (numVecs > 1) { - if (N * 2 <= numVecs * keyType::numlanes) { - kvsort_n_vec(keys, values, N); - return; - } - } - - kreg_t keyVecs[numVecs]; - vreg_t valueVecs[numVecs]; - - // Generate masks for loading and storing - typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), - (uint64_t)keyType::numlanes); - ioMasks[j] = keyType::get_partial_loadmask(num_to_read); - } - - // Unmasked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - 
keyVecs[i] = keyType::loadu(keys + i * keyType::numlanes); - valueVecs[i] = valueType::loadu(values + i * valueType::numlanes); - } - // Masked part of the load - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - keyVecs[i] = keyType::mask_loadu( - keyType::zmm_max(), ioMasks[j], keys + i * keyType::numlanes); - valueVecs[i] = valueType::mask_loadu(valueType::zmm_max(), - ioMasks[j], - values + i * valueType::numlanes); - } - - // Sort each loaded vector - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs; i++) { - sort_vec_dispatch(keyVecs[i], valueVecs[i]); - } - - // Run the full merger - bitonic_fullmerge_n_vec(keyVecs, valueVecs); - - // Unmasked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = 0; i < numVecs / 2; i++) { - keyType::storeu(keys + i * keyType::numlanes, keyVecs[i]); - valueType::storeu(values + i * valueType::numlanes, valueVecs[i]); - } - // Masked part of the store - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - keyType::mask_storeu( - keys + i * keyType::numlanes, ioMasks[j], keyVecs[i]); - valueType::mask_storeu( - values + i * valueType::numlanes, ioMasks[j], valueVecs[i]); - } + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + // 1) First step of a merging network + reg_t key_zmm2r = vtype1::permutexvar(rev_index1, key_zmm[2]); + reg_t key_zmm3r = vtype1::permutexvar(rev_index1, key_zmm[3]); + index_type index_zmm2r = vtype2::permutexvar(rev_index2, index_zmm[2]); + index_type index_zmm3r = vtype2::permutexvar(rev_index2, index_zmm[3]); + + reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm3r); + reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm2r); + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm3r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm2r); + + typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); + typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); + + index_type index_reg_t1 + = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); + index_type index_reg_t2 + = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); + + // 2) Recursive half clearer: 16 + reg_t key_reg_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t4 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t3 = vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t4 = vtype2::permutexvar(rev_index2, index_zmm_m1); + + reg_t key_zmm0 = vtype1::min(key_reg_t1, key_reg_t2); + reg_t key_zmm1 = vtype1::max(key_reg_t1, key_reg_t2); + reg_t key_zmm2 = vtype1::min(key_reg_t3, key_reg_t4); + reg_t key_zmm3 = vtype1::max(key_reg_t3, key_reg_t4); + + movmask1 = vtype1::eq(key_zmm0, key_reg_t1); + movmask2 = vtype1::eq(key_zmm2, key_reg_t3); + + index_type index_zmm0 + = vtype2::mask_mov(index_reg_t2, movmask1, index_reg_t1); + index_type index_zmm1 + = vtype2::mask_mov(index_reg_t1, movmask1, index_reg_t2); + index_type index_zmm2 + = vtype2::mask_mov(index_reg_t4, movmask2, index_reg_t3); + index_type index_zmm3 + = vtype2::mask_mov(index_reg_t3, movmask2, index_reg_t4); + + key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); + key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); + key_zmm[2] = 
bitonic_merge_zmm_64bit(key_zmm2, index_zmm2); + key_zmm[3] = bitonic_merge_zmm_64bit(key_zmm3, index_zmm3); + + index_zmm[0] = index_zmm0; + index_zmm[1] = index_zmm1; + index_zmm[2] = index_zmm2; + index_zmm[3] = index_zmm3; } -template -X86_SIMD_SORT_INLINE void -argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N) +template +X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(reg_t *key_zmm, + index_type *index_zmm) { - using indexType = typename index_64bit_vector_type::type; - - static_assert(keyType::numlanes == indexType::numlanes, - "invalid pairing of value/index types"); - constexpr int numVecs = maxN / keyType::numlanes; - constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); - constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); - static_assert(powerOfTwo == true && isMultiple == true, - "maxN must be keyType::numlanes times a power of 2"); - - argsort_n_vec(keys, indices, N); + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + reg_t key_zmm4r = vtype1::permutexvar(rev_index1, key_zmm[4]); + reg_t key_zmm5r = vtype1::permutexvar(rev_index1, key_zmm[5]); + reg_t key_zmm6r = vtype1::permutexvar(rev_index1, key_zmm[6]); + reg_t key_zmm7r = vtype1::permutexvar(rev_index1, key_zmm[7]); + index_type index_zmm4r = vtype2::permutexvar(rev_index2, index_zmm[4]); + index_type index_zmm5r = vtype2::permutexvar(rev_index2, index_zmm[5]); + index_type index_zmm6r = vtype2::permutexvar(rev_index2, index_zmm[6]); + index_type index_zmm7r = vtype2::permutexvar(rev_index2, index_zmm[7]); + + reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm7r); + reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm6r); + reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm5r); + reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm4r); + + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm7r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm6r); + reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm5r); + reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm4r); + + typename vtype1::opmask_t movmask1 = vtype1::eq(key_reg_t1, key_zmm[0]); + typename vtype1::opmask_t movmask2 = vtype1::eq(key_reg_t2, key_zmm[1]); + typename vtype1::opmask_t movmask3 = vtype1::eq(key_reg_t3, key_zmm[2]); + typename vtype1::opmask_t movmask4 = vtype1::eq(key_reg_t4, key_zmm[3]); + + index_type index_reg_t1 + = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); + index_type index_reg_t2 + = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); + index_type index_reg_t3 + = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); + index_type index_zmm_m3 + = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); + index_type index_reg_t4 + = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); + index_type index_zmm_m4 + = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); + + reg_t key_reg_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_reg_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_reg_t7 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t8 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t5 = vtype2::permutexvar(rev_index2, index_zmm_m4); + index_type index_reg_t6 = vtype2::permutexvar(rev_index2, index_zmm_m3); + index_type index_reg_t7 = 
vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t8 = vtype2::permutexvar(rev_index2, index_zmm_m1); + + COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); + COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); + COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); + COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); + COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); + COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); + COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); + COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); + key_zmm[0] + = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); + + index_zmm[0] = index_reg_t1; + index_zmm[1] = index_reg_t2; + index_zmm[2] = index_reg_t3; + index_zmm[3] = index_reg_t4; + index_zmm[4] = index_reg_t5; + index_zmm[5] = index_reg_t6; + index_zmm[6] = index_reg_t7; + index_zmm[7] = index_reg_t8; } -template -X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys, - typename valueType::type_t *values, - int N) +template +X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(reg_t *key_zmm, + index_type *index_zmm) { - static_assert(keyType::numlanes == valueType::numlanes, - "invalid pairing of key/value types"); - - constexpr int numVecs = maxN / keyType::numlanes; - constexpr bool isMultiple = (maxN == (keyType::numlanes * numVecs)); - constexpr bool powerOfTwo = (numVecs != 0 && !(numVecs & (numVecs - 1))); - static_assert(powerOfTwo == true && isMultiple == true, - "maxN must be keyType::numlanes times a power of 2"); - - kvsort_n_vec(keys, values, N); + const typename vtype1::regi_t rev_index1 = vtype1::seti(NETWORK_64BIT_2); + const typename vtype2::regi_t rev_index2 = vtype2::seti(NETWORK_64BIT_2); + reg_t key_zmm8r = vtype1::permutexvar(rev_index1, key_zmm[8]); + reg_t key_zmm9r = vtype1::permutexvar(rev_index1, key_zmm[9]); + reg_t key_zmm10r = vtype1::permutexvar(rev_index1, key_zmm[10]); + reg_t key_zmm11r = vtype1::permutexvar(rev_index1, key_zmm[11]); + reg_t key_zmm12r = vtype1::permutexvar(rev_index1, key_zmm[12]); + reg_t key_zmm13r = vtype1::permutexvar(rev_index1, key_zmm[13]); + reg_t key_zmm14r = vtype1::permutexvar(rev_index1, key_zmm[14]); + reg_t key_zmm15r = vtype1::permutexvar(rev_index1, key_zmm[15]); + + index_type index_zmm8r = vtype2::permutexvar(rev_index2, index_zmm[8]); + index_type index_zmm9r = vtype2::permutexvar(rev_index2, index_zmm[9]); + index_type index_zmm10r = vtype2::permutexvar(rev_index2, index_zmm[10]); + index_type index_zmm11r = vtype2::permutexvar(rev_index2, index_zmm[11]); + index_type index_zmm12r = vtype2::permutexvar(rev_index2, index_zmm[12]); + index_type index_zmm13r = vtype2::permutexvar(rev_index2, index_zmm[13]); + index_type index_zmm14r = vtype2::permutexvar(rev_index2, index_zmm[14]); + index_type index_zmm15r = vtype2::permutexvar(rev_index2, index_zmm[15]); + + reg_t key_reg_t1 = vtype1::min(key_zmm[0], key_zmm15r); + reg_t key_reg_t2 = vtype1::min(key_zmm[1], key_zmm14r); + reg_t key_reg_t3 = vtype1::min(key_zmm[2], key_zmm13r); + 
reg_t key_reg_t4 = vtype1::min(key_zmm[3], key_zmm12r); + reg_t key_reg_t5 = vtype1::min(key_zmm[4], key_zmm11r); + reg_t key_reg_t6 = vtype1::min(key_zmm[5], key_zmm10r); + reg_t key_reg_t7 = vtype1::min(key_zmm[6], key_zmm9r); + reg_t key_reg_t8 = vtype1::min(key_zmm[7], key_zmm8r); + + reg_t key_zmm_m1 = vtype1::max(key_zmm[0], key_zmm15r); + reg_t key_zmm_m2 = vtype1::max(key_zmm[1], key_zmm14r); + reg_t key_zmm_m3 = vtype1::max(key_zmm[2], key_zmm13r); + reg_t key_zmm_m4 = vtype1::max(key_zmm[3], key_zmm12r); + reg_t key_zmm_m5 = vtype1::max(key_zmm[4], key_zmm11r); + reg_t key_zmm_m6 = vtype1::max(key_zmm[5], key_zmm10r); + reg_t key_zmm_m7 = vtype1::max(key_zmm[6], key_zmm9r); + reg_t key_zmm_m8 = vtype1::max(key_zmm[7], key_zmm8r); + + index_type index_reg_t1 = vtype2::mask_mov( + index_zmm15r, vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm[0]); + index_type index_zmm_m1 = vtype2::mask_mov( + index_zmm[0], vtype1::eq(key_reg_t1, key_zmm[0]), index_zmm15r); + index_type index_reg_t2 = vtype2::mask_mov( + index_zmm14r, vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm[1]); + index_type index_zmm_m2 = vtype2::mask_mov( + index_zmm[1], vtype1::eq(key_reg_t2, key_zmm[1]), index_zmm14r); + index_type index_reg_t3 = vtype2::mask_mov( + index_zmm13r, vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm[2]); + index_type index_zmm_m3 = vtype2::mask_mov( + index_zmm[2], vtype1::eq(key_reg_t3, key_zmm[2]), index_zmm13r); + index_type index_reg_t4 = vtype2::mask_mov( + index_zmm12r, vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm[3]); + index_type index_zmm_m4 = vtype2::mask_mov( + index_zmm[3], vtype1::eq(key_reg_t4, key_zmm[3]), index_zmm12r); + + index_type index_reg_t5 = vtype2::mask_mov( + index_zmm11r, vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm[4]); + index_type index_zmm_m5 = vtype2::mask_mov( + index_zmm[4], vtype1::eq(key_reg_t5, key_zmm[4]), index_zmm11r); + index_type index_reg_t6 = vtype2::mask_mov( + index_zmm10r, vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm[5]); + index_type index_zmm_m6 = vtype2::mask_mov( + index_zmm[5], vtype1::eq(key_reg_t6, key_zmm[5]), index_zmm10r); + index_type index_reg_t7 = vtype2::mask_mov( + index_zmm9r, vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm[6]); + index_type index_zmm_m7 = vtype2::mask_mov( + index_zmm[6], vtype1::eq(key_reg_t7, key_zmm[6]), index_zmm9r); + index_type index_reg_t8 = vtype2::mask_mov( + index_zmm8r, vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm[7]); + index_type index_zmm_m8 = vtype2::mask_mov( + index_zmm[7], vtype1::eq(key_reg_t8, key_zmm[7]), index_zmm8r); + + reg_t key_reg_t9 = vtype1::permutexvar(rev_index1, key_zmm_m8); + reg_t key_reg_t10 = vtype1::permutexvar(rev_index1, key_zmm_m7); + reg_t key_reg_t11 = vtype1::permutexvar(rev_index1, key_zmm_m6); + reg_t key_reg_t12 = vtype1::permutexvar(rev_index1, key_zmm_m5); + reg_t key_reg_t13 = vtype1::permutexvar(rev_index1, key_zmm_m4); + reg_t key_reg_t14 = vtype1::permutexvar(rev_index1, key_zmm_m3); + reg_t key_reg_t15 = vtype1::permutexvar(rev_index1, key_zmm_m2); + reg_t key_reg_t16 = vtype1::permutexvar(rev_index1, key_zmm_m1); + index_type index_reg_t9 = vtype2::permutexvar(rev_index2, index_zmm_m8); + index_type index_reg_t10 = vtype2::permutexvar(rev_index2, index_zmm_m7); + index_type index_reg_t11 = vtype2::permutexvar(rev_index2, index_zmm_m6); + index_type index_reg_t12 = vtype2::permutexvar(rev_index2, index_zmm_m5); + index_type index_reg_t13 = vtype2::permutexvar(rev_index2, index_zmm_m4); + index_type index_reg_t14 = vtype2::permutexvar(rev_index2, index_zmm_m3); + 
index_type index_reg_t15 = vtype2::permutexvar(rev_index2, index_zmm_m2); + index_type index_reg_t16 = vtype2::permutexvar(rev_index2, index_zmm_m1); + + COEX(key_reg_t1, key_reg_t5, index_reg_t1, index_reg_t5); + COEX(key_reg_t2, key_reg_t6, index_reg_t2, index_reg_t6); + COEX(key_reg_t3, key_reg_t7, index_reg_t3, index_reg_t7); + COEX(key_reg_t4, key_reg_t8, index_reg_t4, index_reg_t8); + COEX(key_reg_t9, key_reg_t13, index_reg_t9, index_reg_t13); + COEX( + key_reg_t10, key_reg_t14, index_reg_t10, index_reg_t14); + COEX( + key_reg_t11, key_reg_t15, index_reg_t11, index_reg_t15); + COEX( + key_reg_t12, key_reg_t16, index_reg_t12, index_reg_t16); + + COEX(key_reg_t1, key_reg_t3, index_reg_t1, index_reg_t3); + COEX(key_reg_t2, key_reg_t4, index_reg_t2, index_reg_t4); + COEX(key_reg_t5, key_reg_t7, index_reg_t5, index_reg_t7); + COEX(key_reg_t6, key_reg_t8, index_reg_t6, index_reg_t8); + COEX(key_reg_t9, key_reg_t11, index_reg_t9, index_reg_t11); + COEX( + key_reg_t10, key_reg_t12, index_reg_t10, index_reg_t12); + COEX( + key_reg_t13, key_reg_t15, index_reg_t13, index_reg_t15); + COEX( + key_reg_t14, key_reg_t16, index_reg_t14, index_reg_t16); + + COEX(key_reg_t1, key_reg_t2, index_reg_t1, index_reg_t2); + COEX(key_reg_t3, key_reg_t4, index_reg_t3, index_reg_t4); + COEX(key_reg_t5, key_reg_t6, index_reg_t5, index_reg_t6); + COEX(key_reg_t7, key_reg_t8, index_reg_t7, index_reg_t8); + COEX(key_reg_t9, key_reg_t10, index_reg_t9, index_reg_t10); + COEX( + key_reg_t11, key_reg_t12, index_reg_t11, index_reg_t12); + COEX( + key_reg_t13, key_reg_t14, index_reg_t13, index_reg_t14); + COEX( + key_reg_t15, key_reg_t16, index_reg_t15, index_reg_t16); + // + key_zmm[0] + = bitonic_merge_zmm_64bit(key_reg_t1, index_reg_t1); + key_zmm[1] + = bitonic_merge_zmm_64bit(key_reg_t2, index_reg_t2); + key_zmm[2] + = bitonic_merge_zmm_64bit(key_reg_t3, index_reg_t3); + key_zmm[3] + = bitonic_merge_zmm_64bit(key_reg_t4, index_reg_t4); + key_zmm[4] + = bitonic_merge_zmm_64bit(key_reg_t5, index_reg_t5); + key_zmm[5] + = bitonic_merge_zmm_64bit(key_reg_t6, index_reg_t6); + key_zmm[6] + = bitonic_merge_zmm_64bit(key_reg_t7, index_reg_t7); + key_zmm[7] + = bitonic_merge_zmm_64bit(key_reg_t8, index_reg_t8); + key_zmm[8] + = bitonic_merge_zmm_64bit(key_reg_t9, index_reg_t9); + key_zmm[9] = bitonic_merge_zmm_64bit(key_reg_t10, + index_reg_t10); + key_zmm[10] = bitonic_merge_zmm_64bit(key_reg_t11, + index_reg_t11); + key_zmm[11] = bitonic_merge_zmm_64bit(key_reg_t12, + index_reg_t12); + key_zmm[12] = bitonic_merge_zmm_64bit(key_reg_t13, + index_reg_t13); + key_zmm[13] = bitonic_merge_zmm_64bit(key_reg_t14, + index_reg_t14); + key_zmm[14] = bitonic_merge_zmm_64bit(key_reg_t15, + index_reg_t15); + key_zmm[15] = bitonic_merge_zmm_64bit(key_reg_t16, + index_reg_t16); + + index_zmm[0] = index_reg_t1; + index_zmm[1] = index_reg_t2; + index_zmm[2] = index_reg_t3; + index_zmm[3] = index_reg_t4; + index_zmm[4] = index_reg_t5; + index_zmm[5] = index_reg_t6; + index_zmm[6] = index_reg_t7; + index_zmm[7] = index_reg_t8; + index_zmm[8] = index_reg_t9; + index_zmm[9] = index_reg_t10; + index_zmm[10] = index_reg_t11; + index_zmm[11] = index_reg_t12; + index_zmm[12] = index_reg_t13; + index_zmm[13] = index_reg_t14; + index_zmm[14] = index_reg_t15; + index_zmm[15] = index_reg_t16; } - -#endif \ No newline at end of file +#endif // AVX512_KEYVALUE_NETWORKS
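
Side note on the restored key-value networks above: they rest on two building blocks, (a) masked loads that pad the tail of a partial vector with the type's maximum so the padding lanes sort to the end (the vtype1::mask_loadu(vtype1::zmm_max(), ...) pattern), and (b) key/value compare-exchange steps that min/max the keys and move the index lanes using the equality mask (the pattern behind COEX<vtype1, vtype2> and the mask_mov calls). The following standalone sketch is not part of the patch and does not use the library's vtype/zmm_vector wrappers; it uses raw AVX-512F intrinsics, the helper names load_padded and coex_kv are invented for illustration, and it performs only a single compare-exchange against a reversed copy rather than the full bitonic network.

// Build with: g++ -O2 -mavx512f sketch.cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Load up to n (n <= 8) 64-bit keys, padding unused lanes with INT64_MAX so
// they land at the end of any ascending sort.
static __m512i load_padded(const int64_t *keys, int n)
{
    __mmask8 m = (__mmask8)((1u << n) - 1u);
    return _mm512_mask_loadu_epi64(_mm512_set1_epi64(INT64_MAX), m, keys);
}

// One key/value compare-exchange: keys are min/max'ed, and the index vectors
// are blended with the mask of lanes where k1 already held the minimum, so
// each index keeps travelling with its key.
static void coex_kv(__m512i &k1, __m512i &k2, __m512i &i1, __m512i &i2)
{
    __m512i kmin = _mm512_min_epi64(k1, k2);
    __m512i kmax = _mm512_max_epi64(k1, k2);
    __mmask8 kept = _mm512_cmpeq_epi64_mask(kmin, k1); // lanes where k1 was the min
    __m512i imin = _mm512_mask_mov_epi64(i2, kept, i1);
    __m512i imax = _mm512_mask_mov_epi64(i1, kept, i2);
    k1 = kmin; k2 = kmax; i1 = imin; i2 = imax;
}

int main()
{
    int64_t keys[5] = {42, 7, 99, 7, 13};          // partial block of 5 keys
    int64_t idx[8]  = {0, 1, 2, 3, 4, 5, 6, 7};    // their positions

    __m512i k = load_padded(keys, 5);
    __m512i i = _mm512_loadu_si512(idx);

    // A full sort would run the whole bitonic network; here we only pair the
    // vector with a reversed copy and do one COEX step.
    __m512i rev = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); // reverse permutation
    __m512i kr = _mm512_permutexvar_epi64(rev, k);
    __m512i ir = _mm512_permutexvar_epi64(rev, i);
    coex_kv(k, kr, i, ir);

    int64_t out_k[8], out_i[8];
    _mm512_storeu_si512(out_k, k);
    _mm512_storeu_si512(out_i, i);
    for (int j = 0; j < 8; j++)
        printf("key %lld stays with original index %lld\n",
               (long long)out_k[j], (long long)out_i[j]);
    return 0;
}

The printed pairs show that padded lanes carry INT64_MAX keys and that each index moves with its key through the compare-exchange, which is the invariant the restored sort_8_64bit/.../sort_128_64bit routines maintain across the whole network.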