diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
index 81d7d00e..5588cffa 100644
--- a/lib/x86simdsort-avx2.cpp
+++ b/lib/x86simdsort-avx2.cpp
@@ -1,6 +1,8 @@
 // AVX2 specific routines:
 #include "avx2-32bit-qsort.hpp"
 #include "avx2-64bit-qsort.hpp"
+#include "avx2-32bit-half.hpp"
+#include "xss-common-argsort.h"
 #include "x86simdsort-internal.h"
 
 #define DEFINE_ALL_METHODS(type) \
@@ -18,6 +20,17 @@ void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \
     { \
         avx2_partial_qsort(arr, k, arrsize, hasnan); \
+    }\
+    template <> \
+    std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
+    { \
+        return avx2_argsort(arr, arrsize, hasnan); \
+    } \
+    template <> \
+    std::vector<size_t> argselect( \
+            type *arr, size_t k, size_t arrsize, bool hasnan) \
+    { \
+        return avx2_argselect(arr, k, arrsize, hasnan); \
+    }
 
 namespace xss {
diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp
index 8ebbc6be..f088e4cd 100644
--- a/lib/x86simdsort.cpp
+++ b/lib/x86simdsort.cpp
@@ -189,12 +189,12 @@ DISPATCH_ALL(partial_qsort, (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(argsort,
              (ISA_LIST("none")),
-             (ISA_LIST("avx512_skx")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")),
+             (ISA_LIST("avx512_skx", "avx2")))
 DISPATCH_ALL(argselect,
              (ISA_LIST("none")),
-             (ISA_LIST("avx512_skx")),
-             (ISA_LIST("avx512_skx")))
+             (ISA_LIST("avx512_skx", "avx2")),
+             (ISA_LIST("avx512_skx", "avx2")))
 
 #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
     DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\
diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp
new file mode 100644
index 00000000..52697692
--- /dev/null
+++ b/src/avx2-32bit-half.hpp
@@ -0,0 +1,557 @@
+/*******************************************************************
+ * Copyright (C) 2022 Intel Corporation
+ * SPDX-License-Identifier: BSD-3-Clause
+ * Authors: Raghuveer Devulapalli
+ * ****************************************************************/
+
+#ifndef AVX2_HALF_32BIT
+#define AVX2_HALF_32BIT
+
+#include "xss-common-includes.h"
+#include "avx2-emu-funcs.hpp"
+
+/*
+ * Constants used in sorting 8 elements in ymm registers.
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) +{ + using swizzle = typename vtype::swizzle_ops; + + const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0); + const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0); + + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::reverse(ymm), oxCC); + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_half_swizzle_ops; + +template <> +struct avx2_half_vector { + using type_t = int32_t; + using reg_t = __m128i; + using regi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static regi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm_xor_si128(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm_cmpgt_epi32(x, y); + return _mm_or_si128(equal, greater); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_epi32( + src, (const int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void 
mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_epi32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = uint32_t; + using reg_t = __m128i; + using regi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static regi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_epi32( + src, (const int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((int *)mem, mask, x); + } + static 
reg_t min(reg_t x, reg_t y) + { + return _mm_min_epu32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = float; + using reg_t = __m128; + using regi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm_set1_ps(type_max()); + } + + static regi_t seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t set(float v1, float v2, float v3, float v4) + { + return _mm_set_ps(v1, v2, v3, v4); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_ps((const float *)mem, mask); + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int_half(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm_castps_si128(_mm_cmp_ps(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_ps( + src, (const float *)base, index, _mm_castsi128_ps(mask), scale); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_ps( + src, (const float *)base, index, _mm_castsi128_ps(mask), scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t 
mask, void const *mem) + { + reg_t dst = _mm_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_blendv_ps(x, y, _mm_castsi128_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_ps(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_permutevar_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_ps((float *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return _mm_castsi128_ps(v); + } + static __m128i cast_to(reg_t v) + { + return _mm_castps_si128(v); + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; + +struct avx2_32bit_half_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + if constexpr (scale == 2) { + return vtype::template shuffle<0b10110001>(reg); + } + else if constexpr (scale == 4) { + return vtype::template shuffle<0b01001110>(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m128i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m128i v1 = vtype::cast_to(reg); + __m128i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { v1 = _mm_blend_epi32(v1, v2, 0b0101); } + else if constexpr (scale == 4) { + v1 = _mm_blend_epi32(v1, v2, 0b0011); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +#endif // AVX2_HALF_32BIT diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index 521597cd..cf0fbd55 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -70,6 +70,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -225,6 +226,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -369,6 +371,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + 
static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index fd7f92af..764bdcf4 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -52,6 +52,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -76,6 +77,10 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _mm256_xor_si256(x, y); @@ -98,12 +103,19 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } template - static reg_t i64gather(__m256i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm256_mask_i32gather_epi64( + src, (const long long int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_epi64((int64_t const *)base, index, scale); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -211,6 +223,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -235,17 +248,27 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } template - static reg_t i64gather(__m256i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) { - return _mm256_i64gather_epi64( - (long long int const *)base, index, scale); + return _mm256_mask_i32gather_epi64( + src, (const long long int *)base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static opmask_t gt(reg_t x, reg_t y) { @@ -369,6 +392,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -407,7 +431,10 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) + { + return _mm256_set_pd(v1, v2, v3, v4); + } static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskload_pd((const double *)mem, mask); @@ -424,14 +451,27 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_pd( - src, base, index, 
_mm256_castsi256_pd(mask), scale); + return _mm256_mask_i64gather_pd(src, + (const type_t *)base, + index, + _mm256_castsi256_pd(mask), + scale); ; } template - static reg_t i64gather(__m256i index, void const *base) + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm256_mask_i32gather_pd(src, + (const type_t *)base, + index, + _mm256_castsi256_pd(mask), + scale); + ; + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_pd((double *)base, index, scale); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 9f6229f7..38489626 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -35,6 +35,21 @@ constexpr auto avx2_mask_helper_lut64 = [] { return lut; }(); +constexpr auto avx2_mask_helper_lut32_half = [] { + std::array, 16> lut {}; + for (int64_t i = 0; i <= 0xF; i++) { + std::array entry {}; + for (int j = 0; j < 4; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + constexpr auto avx2_compressstore_lut32_gen = [] { std::array, 256>, 2> lutPair {}; auto &permLut = lutPair[0]; @@ -65,6 +80,38 @@ constexpr auto avx2_compressstore_lut32_gen = [] { constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; +constexpr auto avx2_compressstore_lut32_half_gen = [] { + std::array, 16>, 2> lutPair {}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0}; + int right = 3; + int left = 0; + for (int j = 0; j < 4; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } + else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); + +constexpr auto avx2_compressstore_lut32_half_perm + = avx2_compressstore_lut32_half_gen[0]; +constexpr auto avx2_compressstore_lut32_half_left + = avx2_compressstore_lut32_half_gen[1]; + constexpr auto avx2_compressstore_lut64_gen = [] { std::array, 16> permLut {}; std::array, 16> leftLut {}; @@ -123,6 +170,19 @@ int32_t convert_avx2_mask_to_int_64bit(__m256i m) return _mm256_movemask_pd(_mm256_castsi256_pd(m)); } +X86_SIMD_SORT_INLINE +__m128i convert_int_to_avx2_mask_half(int32_t m) +{ + return _mm_loadu_si128( + (const __m128i *)avx2_mask_helper_lut32_half[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int_half(__m128i m) +{ + return _mm_movemask_ps(_mm_castsi128_ps(m)); +} + // Emulators for intrinsics missing from AVX2 compared to AVX512 template T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) @@ -139,6 +199,19 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) return std::max(arr[0], arr[7]); } +template +T avx2_emu_reduce_max32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::max(arr[0], arr[3]); +} + template T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { @@ -154,6 +227,19 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) return std::min(arr[0], arr[7]); } +template +T 
avx2_emu_reduce_min32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::min(arr[0], arr[3]); +} + template T avx2_emu_reduce_max64(typename avx2_vector::reg_t x) { @@ -191,7 +277,30 @@ void avx2_emu_mask_compressstoreu32(void *base_addr, const __m256i &left = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); + + vtype::mask_storeu(leftStore, left, temp); +} + +template +void avx2_emu_mask_compressstoreu32_half( + void *base_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); + const __m128i &left = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_left[shortMask] + .data()); + + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::mask_storeu(leftStore, left, temp); } @@ -232,7 +341,31 @@ int avx2_double_compressstore32(void *left_addr, const __m256i &perm = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); + + vtype::storeu(leftStore, temp); + vtype::storeu(rightStore, temp); + + return _mm_popcnt_u32(shortMask); +} + +template +int avx2_double_compressstore32_half(void *left_addr, + void *right_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); + + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::storeu(leftStore, temp); vtype::storeu(rightStore, temp); diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index be806f5f..32d7419c 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -26,6 +26,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -208,6 +209,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -343,6 +345,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 1814699f..74615765 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -42,6 +42,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int 
partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -220,6 +221,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -398,6 +400,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 799303d8..3a475da8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,557 +7,7 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "xss-common-qsort.h" #include "avx512-64bit-common.h" -#include "xss-network-keyvaluesort.hpp" -#include - -template -X86_SIMD_SORT_INLINE void std_argselect_withnan( - T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) -{ - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - // sort indices according to corresponding array element - return arr[left] < arr[right]; - }); -} - -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - argtype::mask_compressstoreu( - arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_gt_pivot; -} -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - argreg_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot - = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - argreg_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left = vtype::i64gather(arr, arg + left); - argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec; - reg_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::i64gather(arr, arg + right); - } - else { - arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::i64gather(arr, arg + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( - arr, arg, left, right, pivot, smallest, biggest); - } - /* make array length divisible by vtype::numlanes , 
shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - // first and last vtype::numlanes values are partitioned at the end - reg_t vec_left[num_unroll], vec_right[num_unroll]; - argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); - argvec_right[ii] = argtype::loadu( - arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::i64gather( - arr, arg + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec[num_unroll]; - reg_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] - = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + right + ii * vtype::numlanes); - } - } - else { - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - } - - /* partition and save vec_left and vec_right */ - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE type_t 
get_pivot_64bit(type_t *arr, - arrsize_t *arg, - const arrsize_t left, - const arrsize_t right) -{ - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; - } -} - -template -X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - argsort_64bit_( - arr, arg, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - argsort_64bit_( - arr, arg, pivot_index, right, max_iters - 1); -} - -template -X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( - arr, arg, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( - arr, arg, pos, pivot_index, right, max_iters - 1); -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ - using vectype = typename std::conditional, - zmm_vector>::type; - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ - using argtype = - typename std::conditional, - zmm_vector>::type; -#else - using argtype = - typename std::conditional, - zmm_vector>::type; -#endif - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool 
hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ - using vectype = typename std::conditional, - zmm_vector>::type; - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ - using argtype = - typename std::conditional, - zmm_vector>::type; -#else - using argtype = - typename std::conditional, - zmm_vector>::type; -#endif - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* To maintain compatibility with NumPy build */ -template -X86_SIMD_SORT_INLINE void -avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) -{ - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); -} - -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) -{ - avx512_argsort(arr, reinterpret_cast(arg), arrsize); -} +#include "xss-common-argsort.h" #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 911f5395..65ee85db 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -33,6 +33,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -86,6 +87,10 @@ struct ymm_vector { { return ((0x1ull << num_to_read) - 0x1ull); } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; + } template static opmask_t fpclass(reg_t x) { @@ -216,6 +221,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -389,6 +395,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -569,6 +576,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -747,6 +755,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -917,6 +926,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using 
swizzle_ops = avx512_64bit_swizzle_ops; diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 21958027..f44209fa 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -23,6 +23,7 @@ struct zmm_vector<_Float16> { static const uint8_t numlanes = 32; static constexpr int network_sort_threshold = 128; static constexpr int partition_unroll_factor = 0; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h new file mode 100644 index 00000000..c826a527 --- /dev/null +++ b/src/xss-common-argsort.h @@ -0,0 +1,716 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef XSS_COMMON_ARGSORT +#define XSS_COMMON_ARGSORT + +#include "xss-common-qsort.h" +#include "xss-network-keyvaluesort.hpp" +#include + +template +X86_SIMD_SORT_INLINE void std_argselect_withnan( + T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](arrsize_t a, arrsize_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + argtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} + +/* + * Parition one AVX2 register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
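+ * Unlike the AVX-512 variant above, which compresses via opmask-based
+ * mask_compressstoreu, this path resizes the comparison mask to the index
+ * vector's width and writes both halves with argtype::double_compressstore,
+ * returning the number of elements greater than or equal to the pivot.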
+ */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); + typename argtype::opmask_t ge_mask + = resize_mask(ge_mask_vtype); + + auto l_store = arg + left; + auto r_store = arg + right - vtype::numlanes; + + int amount_ge_pivot + = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); + + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + + return amount_ge_pivot; +} + +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + if constexpr (vtype::vec_type == simd_type::AVX512) { + return partition_vec_avx512(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else if constexpr (vtype::vec_type == simd_type::AVX2) { + return partition_vec_avx2(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else { + static_assert(sizeof(argreg_t) == 0, "Should not get here"); + } +} + +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argreg_t argvec = argtype::loadu(arg + left); + reg_t vec = vtype::i64gather(arr, arg + left); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argreg_t argvec_left = argtype::loadu(arg + left); + reg_t vec_left = vtype::i64gather(arr, arg + left); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec; + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded 
from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::i64gather(arr, arg + right); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::i64gather(arr, arg + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, arg, left, right, pivot, smallest, biggest); + } + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + 
arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + arrsize_t *arg, + const arrsize_t left, + const arrsize_t right) +{ + if constexpr (vtype::numlanes == 8) { + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } + } + else if constexpr (vtype::numlanes == 4) { + if (right - left >= vtype::numlanes) { + // median of 4 + arrsize_t size = (right - left) / 4; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]]); + // pivot will never be a nan, since there are no nan's! 
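+            // 4-lane case (AVX2 kernels): take four samples at stride
+            // (right - left) / 4, sort them with the 4-element network, and
+            // use the upper median (lane 2) as the pivot.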
+            reg_t sort = vtype::sort_vec(rand_vec);
+            return ((type_t *)&sort)[2];
+        }
+        else {
+            return arr[arg[right]];
+        }
+    }
+}
+
+template <typename vtype, typename argtype, typename type_t>
+X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
+                                         arrsize_t *arg,
+                                         arrsize_t left,
+                                         arrsize_t right,
+                                         arrsize_t max_iters)
+{
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std_argsort(arr, arg, left, right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 256
+     */
+    if (right + 1 - left <= 256) {
+        argsort_n<vtype, argtype, 256>(
+                arr, arg + left, (int32_t)(right + 1 - left));
+        return;
+    }
+    type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
+            arr, arg, left, right + 1, pivot, &smallest, &biggest);
+    if (pivot != smallest)
+        argsort_64bit_<vtype, argtype>(
+                arr, arg, left, pivot_index - 1, max_iters - 1);
+    if (pivot != biggest)
+        argsort_64bit_<vtype, argtype>(
+                arr, arg, pivot_index, right, max_iters - 1);
+}
+
+template <typename vtype, typename argtype, typename type_t>
+X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
+                                           arrsize_t *arg,
+                                           arrsize_t pos,
+                                           arrsize_t left,
+                                           arrsize_t right,
+                                           arrsize_t max_iters)
+{
+    /*
+     * Resort to std::sort if quicksort isn't making any progress
+     */
+    if (max_iters <= 0) {
+        std_argsort(arr, arg, left, right + 1);
+        return;
+    }
+    /*
+     * Base case: use bitonic networks to sort arrays <= 256
+     */
+    if (right + 1 - left <= 256) {
+        argsort_n<vtype, argtype, 256>(
+                arr, arg + left, (int32_t)(right + 1 - left));
+        return;
+    }
+    type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
+    type_t smallest = vtype::type_max();
+    type_t biggest = vtype::type_min();
+    arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
+            arr, arg, left, right + 1, pivot, &smallest, &biggest);
+    if ((pivot != smallest) && (pos < pivot_index))
+        argselect_64bit_<vtype, argtype>(
+                arr, arg, pos, left, pivot_index - 1, max_iters - 1);
+    else if ((pivot != biggest) && (pos >= pivot_index))
+        argselect_64bit_<vtype, argtype>(
+                arr, arg, pos, pivot_index, right, max_iters - 1);
+}
+
+/* argsort methods for 32-bit and 64-bit dtypes */
+template <typename T>
+X86_SIMD_SORT_INLINE void
+avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
+{
+    /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
+                                              ymm_vector<T>,
+                                              zmm_vector<T>>::type;
+
+/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
+ * undefined template 'zmm_vector'*/
+#ifdef __APPLE__
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      ymm_vector<uint32_t>,
+                                      zmm_vector<int64_t>>::type;
+#else
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      ymm_vector<uint32_t>,
+                                      zmm_vector<uint64_t>>::type;
+#endif
+
+    if (arrsize > 1) {
+        if constexpr (std::is_floating_point_v<T>) {
+            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
+                std_argsort_withnan(arr, arg, 0, arrsize);
+                return;
+            }
+        }
+        UNUSED(hasnan);
+        argsort_64bit_<vectype, argtype>(
+                arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
+    }
+}
+
+template <typename T>
+X86_SIMD_SORT_INLINE std::vector<arrsize_t>
+avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
+{
+    std::vector<arrsize_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argsort(arr, indices.data(), arrsize, hasnan);
+    return indices;
+}
+
+/* argsort methods for 32-bit and 64-bit dtypes */
+template <typename T>
+X86_SIMD_SORT_INLINE void
+avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
+{
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
+                                              avx2_half_vector<T>,
+                                              avx2_vector<T>>::type;
+
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      avx2_half_vector<uint32_t>,
+                                      avx2_vector<int64_t>>::type;
+
+    if (arrsize > 1) {
+        if constexpr (std::is_floating_point_v<T>) {
+            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
+                std_argsort_withnan(arr, arg, 0, arrsize);
+                return;
+            }
+        }
+        UNUSED(hasnan);
+        argsort_64bit_<vectype, argtype>(
+                arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
+    }
+}
+
+template <typename T>
+X86_SIMD_SORT_INLINE std::vector<arrsize_t>
+avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
+{
+    std::vector<arrsize_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx2_argsort(arr, indices.data(), arrsize, hasnan);
+    return indices;
+}
+
+/* argselect methods for 32-bit and 64-bit dtypes */
+template <typename T>
+X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
+                                           arrsize_t *arg,
+                                           arrsize_t k,
+                                           arrsize_t arrsize,
+                                           bool hasnan = false)
+{
+    /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
+                                              ymm_vector<T>,
+                                              zmm_vector<T>>::type;
+
+/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
+ * undefined template 'zmm_vector'*/
+#ifdef __APPLE__
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      ymm_vector<uint32_t>,
+                                      zmm_vector<int64_t>>::type;
+#else
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      ymm_vector<uint32_t>,
+                                      zmm_vector<uint64_t>>::type;
+#endif
+
+    if (arrsize > 1) {
+        if constexpr (std::is_floating_point_v<T>) {
+            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
+                std_argselect_withnan(arr, arg, k, 0, arrsize);
+                return;
+            }
+        }
+        UNUSED(hasnan);
+        argselect_64bit_<vectype, argtype>(
+                arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
+    }
+}
+
+template <typename T>
+X86_SIMD_SORT_INLINE std::vector<arrsize_t>
+avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
+{
+    std::vector<arrsize_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx512_argselect(arr, indices.data(), k, arrsize, hasnan);
+    return indices;
+}
+
+/* argselect methods for 32-bit and 64-bit dtypes */
+template <typename T>
+X86_SIMD_SORT_INLINE void avx2_argselect(T *arr,
+                                         arrsize_t *arg,
+                                         arrsize_t k,
+                                         arrsize_t arrsize,
+                                         bool hasnan = false)
+{
+    using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
+                                              avx2_half_vector<T>,
+                                              avx2_vector<T>>::type;
+
+    using argtype =
+            typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
+                                      avx2_half_vector<uint32_t>,
+                                      avx2_vector<int64_t>>::type;
+
+    if (arrsize > 1) {
+        if constexpr (std::is_floating_point_v<T>) {
+            if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
+                std_argselect_withnan(arr, arg, k, 0, arrsize);
+                return;
+            }
+        }
+        UNUSED(hasnan);
+        argselect_64bit_<vectype, argtype>(
+                arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
+    }
+}
+
+template <typename T>
+X86_SIMD_SORT_INLINE std::vector<arrsize_t>
+avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
+{
+    std::vector<arrsize_t> indices(arrsize);
+    std::iota(indices.begin(), indices.end(), 0);
+    avx2_argselect(arr, indices.data(), k, arrsize, hasnan);
+    return indices;
+}
+
+#endif // XSS_COMMON_ARGSORT
diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h
index c373ba54..fadb81f5 100644
--- a/src/xss-common-includes.h
+++ b/src/xss-common-includes.h
@@ -71,6 +71,25 @@
 #define X86_SIMD_SORT_UNROLL_LOOP(num)
 #endif
 
+template <typename... T>
+constexpr bool always_false = false;
+
+/*
+ * Constants used in sorting 8 elements in ZMM registers.
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1 +#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 + typedef size_t arrsize_t; template @@ -82,4 +101,9 @@ struct ymm_vector; template struct avx2_vector; +template +struct avx2_half_vector; + +enum class simd_type : int { AVX2, AVX512 }; + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index e76d9f6a..7b89ba21 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -87,7 +87,7 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) else { in = vtype::loadu(arr + ii); } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); + auto nanmask = vtype::convert_mask_to_int(vtype::template fpclass<0x01 | 0x80>(in)); if (nanmask != 0x00) { found_nan = true; break; diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index cf7a674a..03f15155 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,9 +1,29 @@ #ifndef XSS_KEYVALUE_NETWORKS #define XSS_KEYVALUE_NETWORKS -#include "avx512-32bit-qsort.hpp" -#include "avx512-64bit-qsort.hpp" -#include "avx2-64bit-qsort.hpp" +#include "xss-common-includes.h" + +template +typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask) +{ + using inT = typename keyType::opmask_t; + using outT = typename valueType::opmask_t; + + if constexpr (sizeof(inT) == sizeof(outT)) { //std::is_same_v) { + return mask; + } + /* convert __m256i to __m128i */ + else if constexpr (sizeof(inT) == 32 && sizeof(outT) == 16) { + return _mm_castps_si128(_mm256_cvtpd_ps(_mm256_castsi256_pd(mask))); + } + /* convert __m128i to __m256i */ + else if constexpr (sizeof(inT) == 16 && sizeof(outT) == 32) { + return _mm256_cvtepi32_epi64(mask); + } + else { + static_assert(always_false, "Error in func resize_mask"); + } +} template (vtype1::eq(key_t1, key1)); + + reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1); + reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2); key1 = key_t1; key2 = key_t2; @@ -38,7 +58,10 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, opmask_t mask) { reg_t1 tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); + indexes1 = vtype2::mask_mov( + indexes2, + resize_mask(vtype1::eq(tmp_keys, in1)), + indexes1); return tmp_keys; // 0 -> min, 1 -> max } @@ -194,6 +217,38 @@ X86_SIMD_SORT_INLINE reg_t sort_reg_8lanes(reg_t key_zmm, index_type &index_zmm) return key_zmm; } +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm) +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template 
swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + key_zmm = cmp_merge(key_zmm, + vtype1::reverse(key_zmm), + index_zmm, + vtype2::reverse(index_zmm), + oxCC); + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + // Assumes zmm is bitonic and performs a recursive half cleaner template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm, + index_type &index_zmm) +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + template X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key, @@ -239,8 +324,11 @@ bitonic_merge_dispatch(typename keyType::reg_t &key, else if constexpr (numlanes == 16) { key = bitonic_merge_reg_16lanes(key, value); } + else if constexpr (numlanes == 4) { + key = bitonic_merge_ymm_64bit(key, value); + } else { - static_assert(numlanes == -1, "No implementation"); + static_assert(always_false, "bitonic_merge_dispatch: No implementation"); UNUSED(key); UNUSED(value); } @@ -257,8 +345,11 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, else if constexpr (numlanes == 16) { key = sort_reg_16lanes(key, value); } + else if constexpr (numlanes == 4) { + key = sort_ymm_64bit(key, value); + } else { - static_assert(numlanes == -1, "No implementation"); + static_assert(always_false, "sort_vec_dispatch: No implementation"); UNUSED(key); UNUSED(value); } @@ -385,14 +476,17 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, } // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), - ioMasks[j], - indices + i * indexType::numlanes); + for (int i = numVecs / 2; i < numVecs; i++) { + indexVecs[i] = indexType::mask_loadu( + indexType::zmm_max(), + resize_mask(ioMasks[i - numVecs / 2]), + indices + i * indexType::numlanes); keyVecs[i] = keyType::template mask_i64gather( - keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); + typename keyType::type_t)>(keyType::zmm_max(), + ioMasks[i - numVecs / 2], + indexVecs[i], + keys); } // Sort each loaded vector @@ -413,7 +507,9 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { indexType::mask_storeu( - indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); + indices + i * indexType::numlanes, + resize_mask(ioMasks[i - numVecs / 2]), + indexVecs[i]); } }
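
As a quick illustration of the call path this diff enables (a sketch, not part of the patch): with "avx2" added to the argsort/argselect ISA lists, the public entry points can fall back to the new AVX2 kernels on machines without AVX-512. The header name x86simdsort.h and the exact x86simdsort::argsort/argselect signatures used below are assumptions inferred from the dispatch declarations in lib/x86simdsort.cpp, not something defined in this diff.

// Sketch only: assumes x86simdsort.h declares
//   std::vector<size_t> x86simdsort::argsort(T *arr, size_t size, bool hasnan)
//   std::vector<size_t> x86simdsort::argselect(T *arr, size_t k, size_t size, bool hasnan)
#include "x86simdsort.h"
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> data = {3.5f, -1.0f, 7.25f, 0.0f, 2.5f, -6.0f};

    // Indices that sort `data` ascending; expected {5, 1, 3, 4, 0, 2}.
    std::vector<size_t> order
            = x86simdsort::argsort(data.data(), data.size(), false);

    // Indices rearranged so that position k = 2 holds the index of the
    // 3rd-smallest element; the two sides are only partitioned, not sorted.
    std::vector<size_t> nth
            = x86simdsort::argselect(data.data(), 2, data.size(), false);

    for (size_t i : order)
        std::printf("%zu ", i);
    std::printf("| data[nth[2]] = %f\n", data[nth[2]]);
    return 0;
}

On an AVX2-only CPU the dispatcher resolves these calls to the avx2_argsort/avx2_argselect kernels added above; when AVX-512 is available, the avx512_skx variants still take precedence, following the order of the ISA_LIST entries.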