From 29712514d05519fba6797100d879b8810174b8a3 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 8 Nov 2023 13:41:49 -0800 Subject: [PATCH 01/10] Support for AVX2 argsort/argselect --- lib/x86simdsort-avx2.cpp | 13 + lib/x86simdsort.cpp | 8 +- src/avx2-32bit-half.hpp | 591 +++++++++++++++++++++++++++++++ src/avx2-64bit-qsort.hpp | 53 ++- src/avx2-emu-funcs.hpp | 127 +++++++ src/avx512-64bit-argsort.hpp | 190 ++++++++-- src/avx512-64bit-common.h | 4 + src/xss-common-includes.h | 3 + src/xss-common-qsort.h | 2 +- src/xss-network-keyvaluesort.hpp | 156 +++++++- 10 files changed, 1086 insertions(+), 61 deletions(-) create mode 100644 src/avx2-32bit-half.hpp diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 81d7d00e..eb25bccf 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -1,6 +1,8 @@ // AVX2 specific routines: #include "avx2-32bit-qsort.hpp" #include "avx2-64bit-qsort.hpp" +#include "avx2-32bit-half.hpp" +#include "avx512-64bit-argsort.hpp" #include "x86simdsort-internal.h" #define DEFINE_ALL_METHODS(type) \ @@ -18,6 +20,17 @@ void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ { \ avx2_partial_qsort(arr, k, arrsize, hasnan); \ + }\ + template <> \ + std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ + { \ + return avx2_argsort(arr, arrsize, hasnan); \ + } \ + template <> \ + std::vector argselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan) \ + { \ + return avx2_argselect(arr, k, arrsize, hasnan); \ } namespace xss { diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 8ebbc6be..f088e4cd 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -189,12 +189,12 @@ DISPATCH_ALL(partial_qsort, (ISA_LIST("avx512_skx", "avx2"))) DISPATCH_ALL(argsort, (ISA_LIST("none")), - (ISA_LIST("avx512_skx")), - (ISA_LIST("avx512_skx"))) + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) DISPATCH_ALL(argselect, (ISA_LIST("none")), - (ISA_LIST("avx512_skx")), - (ISA_LIST("avx512_skx"))) + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \ DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\ diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp new file mode 100644 index 00000000..fc738dc7 --- /dev/null +++ b/src/avx2-32bit-half.hpp @@ -0,0 +1,591 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX2_HALF_32BIT +#define AVX2_HALF_32BIT + +#include "xss-common-qsort.h" +#include "avx2-emu-funcs.hpp" + +/* + * Constants used in sorting 8 elements in a ymm registers. 
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) +{ + //static_assert(vtype::numlanes == 0, "This function is not implemented"); + typename vtype::type_t buffer[vtype::numlanes]; + vtype::storeu(buffer, ymm); + std::sort(&buffer[0], &buffer[vtype::numlanes], comparison_func); + return vtype::loadu(buffer); + /* + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; + */ +} + +struct avx2_32bit_half_swizzle_ops; + +template <> +struct avx2_half_vector { + using type_t = int32_t; + using reg_t = __m128i; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
+ static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t + set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm_xor_si128(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm_cmpgt_epi32(x, y); + return _mm_castps_si128(_mm_or_ps(_mm_castsi128_ps(equal), + _mm_castsi128_ps(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32(src, (const int *) base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_epi32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = uint32_t; + using reg_t = __m128i; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t 
type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm_set1_epi32(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static reg_t + set(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_epi32(src, (const int *) base, index, mask, scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_si128((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_epi32((int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_epu32(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_si128((__m128i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return v; + } + static __m128i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_half_vector { + using type_t = float; + using reg_t = __m128; + using ymmi_t = __m128i; + using opmask_t = __m128i; + static const uint8_t numlanes = 4; + + using swizzle_ops = avx2_32bit_half_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm_set1_ps(type_max()); + } + + static ymmi_t + seti(int v1, int v2, int v3, int v4) + { + return _mm_set_epi32(v1, v2, v3, v4); + } + static 
reg_t + set(float v1, float v2, float v3, float v4) + { + return _mm_set_ps(v1, v2, v3, v4); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm_maskload_ps((const float *)mem, mask); + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm_castps_si128(_mm_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask_half(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int_half(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm_castps_si128(_mm_cmp_ps(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i64gather_ps(src, (const float*) base, index, _mm_castsi128_ps(mask), scale); + } + static reg_t i64gather(type_t *arr, arrsize_t *ind) + { + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); + } + static reg_t loadu(void const *mem) + { + return _mm_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu32_half(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm_blendv_ps(x, y, _mm_castsi128_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm_min_ps(x, y); + } + static reg_t permutexvar(__m128i idx, reg_t ymm) + { + return _mm_permutevar_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m128i idx) + { + return _mm_permutevar_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); + return permutexvar(rev_index, ymm); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32_half(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32_half(v); + } + static reg_t set1(type_t v) + { + return _mm_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm_castsi128_ps( + _mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm_storeu_ps((float *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit_half>(x); + } + static reg_t cast_from(__m128i v) + { + return _mm_castsi128_ps(v); + } + static __m128i cast_to(reg_t v) + { + return _mm_castps_si128(v); + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32_half( + left_addr, right_addr, k, reg); + } +}; + +struct avx2_32bit_half_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m128i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m128 vf = _mm_castsi128_ps(v); + vf = _mm_permute_ps(vf, 0b10110001); + v = _mm_castps_si128(vf); + } + else if constexpr 
(scale == 4) { + __m128 vf = _mm_castsi128_ps(v); + vf = _mm_permute_ps(vf, 0b01001110); + v = _mm_castps_si128(vf); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m128i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m128i v1 = vtype::cast_to(reg); + __m128i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm_blend_epi32(v1, v2, 0b0101); + } + else if constexpr (scale == 4) { + v1 = _mm_blend_epi32(v1, v2, 0b0011); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +#endif // AVX2_HALF_32BIT diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index fd7f92af..bdd33a45 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -76,6 +76,13 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } static opmask_t kxor_opmask(opmask_t x, opmask_t y) { return _mm256_xor_si256(x, y); @@ -98,12 +105,14 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64(src, (const long long int *) base, index, mask, scale); } - template - static reg_t i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_epi64((int64_t const *)base, index, scale); + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -235,17 +244,25 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4) + { + return _mm256_set_epi64x(v1, v2, v3, v4); + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, base, index, mask, scale); + return _mm256_mask_i64gather_epi64(src, (const long long int *) base, index, mask, scale); } - template - static reg_t i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_epi64( - (long long int const *)base, index, scale); + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static opmask_t gt(reg_t x, reg_t y) { @@ -407,7 +424,13 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - + static reg_t set(type_t v1, + type_t v2, + type_t v3, + type_t v4) + { + return _mm256_set_pd(v1, v2, v3, v4); + } static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskload_pd((const double *)mem, mask); @@ -425,13 +448,15 @@ struct avx2_vector { mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { return _mm256_mask_i64gather_pd( - src, base, index, _mm256_castsi256_pd(mask), scale); + src, (const type_t *) base, index, _mm256_castsi256_pd(mask), scale); ; } - template - static reg_t 
i64gather(__m256i index, void const *base) + static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return _mm256_i64gather_pd((double *)base, index, scale); + return set(arr[ind[3]], + arr[ind[2]], + arr[ind[1]], + arr[ind[0]]); } static reg_t loadu(void const *mem) { diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 9f6229f7..2336439c 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -35,6 +35,21 @@ constexpr auto avx2_mask_helper_lut64 = [] { return lut; }(); +constexpr auto avx2_mask_helper_lut32_half = [] { + std::array, 16> lut {}; + for (int64_t i = 0; i <= 0xF; i++) { + std::array entry {}; + for (int j = 0; j < 4; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + constexpr auto avx2_compressstore_lut32_gen = [] { std::array, 256>, 2> lutPair {}; auto &permLut = lutPair[0]; @@ -65,6 +80,36 @@ constexpr auto avx2_compressstore_lut32_gen = [] { constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; +constexpr auto avx2_compressstore_lut32_half_gen = [] { + std::array, 16>, 2> lutPair {}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0}; + int right = 3; + int left = 0; + for (int j = 0; j < 4; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } + else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); + +constexpr auto avx2_compressstore_lut32_half_perm = avx2_compressstore_lut32_half_gen[0]; +constexpr auto avx2_compressstore_lut32_half_left = avx2_compressstore_lut32_half_gen[1]; + constexpr auto avx2_compressstore_lut64_gen = [] { std::array, 16> permLut {}; std::array, 16> leftLut {}; @@ -123,6 +168,19 @@ int32_t convert_avx2_mask_to_int_64bit(__m256i m) return _mm256_movemask_pd(_mm256_castsi256_pd(m)); } +X86_SIMD_SORT_INLINE +__m128i convert_int_to_avx2_mask_half(int32_t m) +{ + return _mm_loadu_si128( + (const __m128i *)avx2_mask_helper_lut32_half[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int_half(__m128i m) +{ + return _mm_movemask_ps(_mm_castsi128_ps(m)); +} + // Emulators for intrinsics missing from AVX2 compared to AVX512 template T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) @@ -139,6 +197,19 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) return std::max(arr[0], arr[7]); } +template +T avx2_emu_reduce_max32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::max(arr[0], arr[3]); +} + template T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { @@ -154,6 +225,19 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) return std::min(arr[0], arr[7]); } +template +T avx2_emu_reduce_min32_half(typename avx2_half_vector::reg_t x) +{ + using vtype = avx2_half_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + T arr[vtype::numlanes]; + vtype::storeu(arr, inter1); + return std::min(arr[0], arr[3]); +} + template T avx2_emu_reduce_max64(typename avx2_vector::reg_t x) { @@ -196,6 +280,26 @@ void 
avx2_emu_mask_compressstoreu32(void *base_addr, vtype::mask_storeu(leftStore, left, temp); } +template +void avx2_emu_mask_compressstoreu32_half(void *base_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask].data()); + const __m128i &left = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); +} + template void avx2_emu_mask_compressstoreu64(void *base_addr, typename avx2_vector::opmask_t k, @@ -240,6 +344,29 @@ int avx2_double_compressstore32(void *left_addr, return _mm_popcnt_u32(shortMask); } +template +int avx2_double_compressstore32_half(void *left_addr, + void *right_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) +{ + using vtype = avx2_half_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = convert_avx2_mask_to_int_half(k); + const __m128i &perm = _mm_loadu_si128( + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::storeu(leftStore, temp); + vtype::storeu(rightStore, temp); + + return _mm_popcnt_u32(shortMask); +} + template int32_t avx2_double_compressstore64(void *left_addr, void *right_addr, diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 799303d8..2fda486d 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -8,10 +8,14 @@ #define AVX512_ARGSORT_64BIT #include "xss-common-qsort.h" -#include "avx512-64bit-common.h" +//#include "avx512-64bit-common.h" +//#include "avx2-32bit-half.hpp" #include "xss-network-keyvaluesort.hpp" #include +template +struct avx2_half_vector; + template X86_SIMD_SORT_INLINE void std_argselect_withnan( T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) @@ -69,12 +73,8 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, arrsize_t left, arrsize_t right, const argreg_t arg_vec, @@ -94,6 +94,55 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, *biggest_vec = vtype::max(curr_vec, *biggest_vec); return amount_gt_pivot; } + +/* + * Parition one AVX2 register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); + typename argtype::opmask_t ge_mask = extend_mask(ge_mask_vtype); + + auto l_store = arg + left; + auto r_store = arg + right - vtype::numlanes; + + int amount_ge_pivot = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); + + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + + return amount_ge_pivot; +} + +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + if constexpr (sizeof (argreg_t) == 64){ + return partition_vec_avx512(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); + }else if constexpr (sizeof (argreg_t) == 32){ + return partition_vec_avx2(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); + }else{ + static_assert(sizeof(argreg_t) == 0, "Should not get here"); + } +} + /* * Parition an array based on the pivot and returns the index of the * last element that is less than equal to the pivot. @@ -250,6 +299,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, return left; /* less than vtype::numlanes elements in the array */ using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; reg_t pivot_vec = vtype::set1(pivot); reg_t min_vec = vtype::set1(*smallest); reg_t max_vec = vtype::set1(*biggest); @@ -356,24 +406,42 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, const arrsize_t left, const arrsize_t right) { - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; + if constexpr (vtype::numlanes == 8){ + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } + }else if constexpr (vtype::numlanes == 4){ + if (right - left >= vtype::numlanes) { + // median of 4 + arrsize_t size = (right - left) / 4; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]]); + // pivot will never be a nan, since there are no nan's! 
+ reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[2]; + } + else { + return arr[arg[right]]; + } } } @@ -384,6 +452,9 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t right, arrsize_t max_iters) { + using argtype = typename std::conditional, + zmm_vector>::type; /* * Resort to std::sort if quicksort isnt making any progress */ @@ -420,6 +491,9 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t right, arrsize_t max_iters) { + using argtype = typename std::conditional, + zmm_vector>::type; /* * Resort to std::sort if quicksort isnt making any progress */ @@ -495,6 +569,37 @@ avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) return indices; } +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + /* argselect methods for 32-bit and 64-bit dtypes */ template X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, @@ -545,6 +650,41 @@ avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) return indices; } +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + /* To maintain compatibility with NumPy build */ template X86_SIMD_SORT_INLINE void diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 911f5395..34d4013f 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -86,6 +86,10 @@ struct ymm_vector { { return ((0x1ull << num_to_read) - 0x1ull); } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; + } template static opmask_t fpclass(reg_t x) { diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index c373ba54..736259c4 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -82,4 +82,7 @@ struct ymm_vector; template struct avx2_vector; +template +struct avx2_half_vector; + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index e76d9f6a..7b89ba21 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -87,7 +87,7 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t 
size) else { in = vtype::loadu(arr + ii); } - opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); + auto nanmask = vtype::convert_mask_to_int(vtype::template fpclass<0x01 | 0x80>(in)); if (nanmask != 0x00) { found_nan = true; break; diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index cf7a674a..67d45be7 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,9 +1,37 @@ #ifndef XSS_KEYVALUE_NETWORKS #define XSS_KEYVALUE_NETWORKS +<<<<<<< HEAD #include "avx512-32bit-qsort.hpp" #include "avx512-64bit-qsort.hpp" #include "avx2-64bit-qsort.hpp" +======= +#include "xss-common-includes.h" + +template +struct index_64bit_vector_type; +template <> +struct index_64bit_vector_type<8> { + using type = zmm_vector; +}; +template <> +struct index_64bit_vector_type<4> { + using type = avx2_vector; +}; + +template +typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask){ + if constexpr (sizeof(typename valueType::reg_t) == 64){ + return mask; + }else{ + if constexpr (sizeof(mask) == 32){ + return mask; + }else{ + return _mm256_cvtepi32_epi64(mask); + } + } +} +>>>>>>> d1e90bb (Support for AVX2 argsort/argselect) template (vtype1::eq(key_t1, key1)); reg_t2 index_t1 - = vtype2::mask_mov(index2, vtype1::eq(key_t1, key1), index1); + = vtype2::mask_mov(index2, eqMask, index1); reg_t2 index_t2 - = vtype2::mask_mov(index1, vtype1::eq(key_t1, key1), index2); + = vtype2::mask_mov(index1, eqMask, index2); key1 = key_t1; key2 = key_t2; @@ -38,10 +68,21 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, opmask_t mask) { reg_t1 tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = vtype2::mask_mov(indexes2, vtype1::eq(tmp_keys, in1), indexes1); + indexes1 = vtype2::mask_mov(indexes2, extend_mask(vtype1::eq(tmp_keys, in1)), indexes1); return tmp_keys; // 0 -> min, 1 -> max } +/* + * Constants used in sorting 8 elements in a ZMM registers. 
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 + template +X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm) +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA + = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC + = vtype1::seti(-1, -1, 0, 0); + + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + key_zmm = cmp_merge( + key_zmm, + vtype1::reverse(key_zmm), + index_zmm, + vtype2::reverse(index_zmm), + oxCC); + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + // Assumes zmm is bitonic and performs a recursive half cleaner template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm, + index_type &index_zmm) +{ + using key_swizzle = typename vtype1::swizzle_ops; + using index_swizzle = typename vtype2::swizzle_ops; + + const typename vtype1::opmask_t oxAA + = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC + = vtype1::seti(-1, -1, 0, 0); + + // 2) half_cleaner[4] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxCC); + // 3) half_cleaner[1] + key_zmm = cmp_merge( + key_zmm, + key_swizzle::template swap_n(key_zmm), + index_zmm, + index_swizzle::template swap_n(index_zmm), + oxAA); + return key_zmm; +} + template X86_SIMD_SORT_INLINE void bitonic_merge_dispatch(typename keyType::reg_t &key, @@ -234,10 +342,16 @@ bitonic_merge_dispatch(typename keyType::reg_t &key, { constexpr int numlanes = keyType::numlanes; if constexpr (numlanes == 8) { +<<<<<<< HEAD key = bitonic_merge_reg_8lanes(key, value); } else if constexpr (numlanes == 16) { key = bitonic_merge_reg_16lanes(key, value); +======= + key = bitonic_merge_zmm_64bit(key, value); + }else if constexpr (numlanes == 4){ + key = bitonic_merge_ymm_64bit(key, value); +>>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } else { static_assert(numlanes == -1, "No implementation"); @@ -252,10 +366,16 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, { constexpr int numlanes = keyType::numlanes; if constexpr (numlanes == 8) { +<<<<<<< HEAD key = sort_reg_8lanes(key, value); } else if constexpr (numlanes == 16) { key = sort_reg_16lanes(key, value); +======= + key = sort_zmm_64bit(key, value); + }else if constexpr (numlanes == 4){ + key = sort_ymm_64bit(key, value); +>>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } else { static_assert(numlanes == -1, "No implementation"); @@ -366,16 +486,6 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, kreg_t keyVecs[numVecs]; ireg_t indexVecs[numVecs]; - // Generate masks for loading and storing - typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; - X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), - (uint64_t)keyType::numlanes); - ioMasks[j] = 
keyType::get_partial_loadmask(num_to_read); - } - // Unmasked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / 2; i++) { @@ -385,14 +495,21 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, } // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) - for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + for (int i = numVecs / 2; i < numVecs; i++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + + auto indexMask = indexType::get_partial_loadmask(num_to_read); + auto keyMask = keyType::get_partial_loadmask(num_to_read); + indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), - ioMasks[j], + indexMask, indices + i * indexType::numlanes); keyVecs[i] = keyType::template mask_i64gather( - keyType::zmm_max(), ioMasks[j], indexVecs[i], keys); + keyType::zmm_max(), keyMask, indexVecs[i], keys); } // Sort each loaded vector @@ -412,8 +529,13 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, // Masked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + + auto indexMask = indexType::get_partial_loadmask(num_to_read); indexType::mask_storeu( - indices + i * indexType::numlanes, ioMasks[j], indexVecs[i]); + indices + i * indexType::numlanes, indexMask, indexVecs[i]); } } From 0a4b74044195647672bbb29ab0918420a14a327c Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 1 Dec 2023 16:32:00 -0800 Subject: [PATCH 02/10] Changed to work with new 32-bit index code --- src/avx2-32bit-half.hpp | 3 +++ src/avx2-32bit-qsort.hpp | 3 +++ src/avx2-64bit-qsort.hpp | 3 +++ src/avx512-16bit-qsort.hpp | 3 +++ src/avx512-32bit-qsort.hpp | 3 +++ src/avx512-64bit-common.h | 6 ++++++ src/avx512fp16-16bit-qsort.hpp | 1 + src/xss-common-includes.h | 4 ++++ src/xss-network-keyvaluesort.hpp | 6 ++++-- 9 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp index fc738dc7..2716b6e8 100644 --- a/src/avx2-32bit-half.hpp +++ b/src/avx2-32bit-half.hpp @@ -71,6 +71,7 @@ struct avx2_half_vector { using ymmi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_half_swizzle_ops; @@ -226,6 +227,7 @@ struct avx2_half_vector { using ymmi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_half_swizzle_ops; @@ -371,6 +373,7 @@ struct avx2_half_vector { using ymmi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_half_swizzle_ops; diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index 521597cd..cf0fbd55 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -70,6 +70,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -225,6 +226,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = 
simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; @@ -369,6 +371,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 4; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_32bit_swizzle_ops; diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index bdd33a45..64d71363 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -52,6 +52,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -220,6 +221,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; @@ -386,6 +388,7 @@ struct avx2_vector { static constexpr int network_sort_threshold = 64; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX2; using swizzle_ops = avx2_64bit_swizzle_ops; diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index be806f5f..32d7419c 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -26,6 +26,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -208,6 +209,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; @@ -343,6 +345,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 1814699f..74615765 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -42,6 +42,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -220,6 +221,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; @@ -398,6 +400,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 512; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_32bit_swizzle_ops; diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 34d4013f..65ee85db 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -33,6 +33,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -220,6 +221,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t 
numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -393,6 +395,7 @@ struct ymm_vector { using regi_t = __m256i; using opmask_t = __mmask8; static const uint8_t numlanes = 8; + static constexpr simd_type vec_type = simd_type::AVX512; static type_t type_max() { @@ -573,6 +576,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -751,6 +755,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; @@ -921,6 +926,7 @@ struct zmm_vector { static constexpr int network_sort_threshold = 256; #endif static constexpr int partition_unroll_factor = 8; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_64bit_swizzle_ops; diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 21958027..f44209fa 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -23,6 +23,7 @@ struct zmm_vector<_Float16> { static const uint8_t numlanes = 32; static constexpr int network_sort_threshold = 128; static constexpr int partition_unroll_factor = 0; + static constexpr simd_type vec_type = simd_type::AVX512; using swizzle_ops = avx512_16bit_swizzle_ops; diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 736259c4..addd369a 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -85,4 +85,8 @@ struct avx2_vector; template struct avx2_half_vector; +enum class simd_type:int{ + AVX2, AVX512 +}; + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 67d45be7..36eae44f 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -21,14 +21,16 @@ struct index_64bit_vector_type<4> { template typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask){ - if constexpr (sizeof(typename valueType::reg_t) == 64){ + if constexpr (keyType::vec_type == simd_type::AVX512){ return mask; - }else{ + }else if constexpr (keyType::vec_type == simd_type::AVX2){ if constexpr (sizeof(mask) == 32){ return mask; }else{ return _mm256_cvtepi32_epi64(mask); } + }else{ + static_assert(keyType::vec_type == simd_type::AVX512, "Should not reach here"); } } >>>>>>> d1e90bb (Support for AVX2 argsort/argselect) From 33147b1ca98c446939b59e3715ceaf682c0e3abb Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 5 Dec 2023 12:39:58 -0800 Subject: [PATCH 03/10] Minor cleanup changes --- src/avx2-32bit-half.hpp | 37 +++++++++++--------------------- src/avx512-64bit-argsort.hpp | 9 ++------ src/xss-network-keyvaluesort.hpp | 28 +++++++++++------------- 3 files changed, 28 insertions(+), 46 deletions(-) diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp index 2716b6e8..f6cbd3d0 100644 --- a/src/avx2-32bit-half.hpp +++ b/src/avx2-32bit-half.hpp @@ -29,37 +29,26 @@ template X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) { - //static_assert(vtype::numlanes == 0, "This function is not implemented"); - typename vtype::type_t buffer[vtype::numlanes]; - vtype::storeu(buffer, ymm); - std::sort(&buffer[0], &buffer[vtype::numlanes], comparison_func); - return vtype::loadu(buffer); - /* - const typename 
vtype::opmask_t oxAA = _mm256_set_epi32( - 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); - const typename vtype::opmask_t oxCC = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); - const typename vtype::opmask_t oxF0 = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); - - const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); + using swizzle = typename vtype::swizzle_ops; + + const typename vtype::opmask_t oxAA + = vtype::seti(-1, 0, -1, 0); + const typename vtype::opmask_t oxCC + = vtype::seti(-1, -1, 0, 0); + ymm = cmp_merge( ymm, - vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), - oxCC); - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + swizzle::template swap_n(ymm), + oxAA); ymm = cmp_merge( ymm, - vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), + vtype::reverse(ymm), oxCC); ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); + ymm, + swizzle::template swap_n(ymm), + oxAA); return ymm; - */ } struct avx2_32bit_half_swizzle_ops; diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 2fda486d..670ed97a 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -8,14 +8,9 @@ #define AVX512_ARGSORT_64BIT #include "xss-common-qsort.h" -//#include "avx512-64bit-common.h" -//#include "avx2-32bit-half.hpp" #include "xss-network-keyvaluesort.hpp" #include -template -struct avx2_half_vector; - template X86_SIMD_SORT_INLINE void std_argselect_withnan( T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) @@ -134,9 +129,9 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, reg_t *smallest_vec, reg_t *biggest_vec) { - if constexpr (sizeof (argreg_t) == 64){ + if constexpr (vtype::vec_type == simd_type::AVX512){ return partition_vec_avx512(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); - }else if constexpr (sizeof (argreg_t) == 32){ + }else if constexpr (vtype::vec_type == simd_type::AVX2){ return partition_vec_avx2(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); }else{ static_assert(sizeof(argreg_t) == 0, "Should not get here"); diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 36eae44f..4bddb692 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -487,6 +487,16 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, kreg_t keyVecs[numVecs]; ireg_t indexVecs[numVecs]; + + // Generate masks for loading and storing + typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; + X86_SIMD_SORT_UNROLL_LOOP(64) + for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), + (uint64_t)keyType::numlanes); + ioMasks[j] = keyType::get_partial_loadmask(num_to_read); + } // Unmasked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) @@ -498,20 +508,13 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2; i < numVecs; i++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), - (uint64_t)keyType::numlanes); - - auto indexMask = indexType::get_partial_loadmask(num_to_read); - auto keyMask = 
keyType::get_partial_loadmask(num_to_read); - indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), - indexMask, + extend_mask(ioMasks[i - numVecs/2]), indices + i * indexType::numlanes); keyVecs[i] = keyType::template mask_i64gather( - keyType::zmm_max(), keyMask, indexVecs[i], keys); + keyType::zmm_max(), ioMasks[i - numVecs / 2], indexVecs[i], keys); } // Sort each loaded vector @@ -531,13 +534,8 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, // Masked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - uint64_t num_to_read - = std::min((uint64_t)std::max(0, N - i * keyType::numlanes), - (uint64_t)keyType::numlanes); - - auto indexMask = indexType::get_partial_loadmask(num_to_read); indexType::mask_storeu( - indices + i * indexType::numlanes, indexMask, indexVecs[i]); + indices + i * indexType::numlanes, extend_mask(ioMasks[i - numVecs/2]), indexVecs[i]); } } From e89a50a32b5f1581584421729efc1392201394bf Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Dec 2023 12:20:47 -0800 Subject: [PATCH 04/10] Fixes builds for things that directly included avx512-64bit-argsort.hpp --- lib/x86simdsort-avx2.cpp | 2 +- src/avx512-64bit-argsort.hpp | 689 +--------------------------------- src/xss-common-argsort.h | 705 +++++++++++++++++++++++++++++++++++ 3 files changed, 708 insertions(+), 688 deletions(-) create mode 100644 src/xss-common-argsort.h diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index eb25bccf..5588cffa 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -2,7 +2,7 @@ #include "avx2-32bit-qsort.hpp" #include "avx2-64bit-qsort.hpp" #include "avx2-32bit-half.hpp" -#include "avx512-64bit-argsort.hpp" +#include "xss-common-argsort.h" #include "x86simdsort-internal.h" #define DEFINE_ALL_METHODS(type) \ diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 670ed97a..3a475da8 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,692 +7,7 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "xss-common-qsort.h" -#include "xss-network-keyvaluesort.hpp" -#include - -template -X86_SIMD_SORT_INLINE void std_argselect_withnan( - T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) -{ - std::nth_element(arg + left, - arg + k, - arg + right, - [arr](arrsize_t a, arrsize_t b) -> bool { - if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { - return arr[a] < arr[b]; - } - else if (std::isnan(arr[a])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { - return arr[left] < arr[right]; - } - else if (std::isnan(arr[left])) { - return false; - } - else { - return true; - } - }); -} - -/* argsort using std::sort */ -template -X86_SIMD_SORT_INLINE void -std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) -{ - std::sort(arg + left, - arg + right, - [arr](arrsize_t left, arrsize_t right) -> bool { - // sort indices according to corresponding array element - return arr[left] < arr[right]; - }); -} - -/* - * Parition one ZMM register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); - int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); - argtype::mask_compressstoreu( - arg + left, vtype::knot_opmask(gt_mask), arg_vec); - argtype::mask_compressstoreu( - arg + right - amount_gt_pivot, gt_mask, arg_vec); - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - return amount_gt_pivot; -} - -/* - * Parition one AVX2 register based on the pivot and returns the index of the - * last element that is less than equal to the pivot. - */ -template -X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - /* which elements are larger than the pivot */ - typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); - typename argtype::opmask_t ge_mask = extend_mask(ge_mask_vtype); - - auto l_store = arg + left; - auto r_store = arg + right - vtype::numlanes; - - int amount_ge_pivot = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); - - *smallest_vec = vtype::min(curr_vec, *smallest_vec); - *biggest_vec = vtype::max(curr_vec, *biggest_vec); - - return amount_ge_pivot; -} - -template -X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, - arrsize_t left, - arrsize_t right, - const argreg_t arg_vec, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t *smallest_vec, - reg_t *biggest_vec) -{ - if constexpr (vtype::vec_type == simd_type::AVX512){ - return partition_vec_avx512(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); - }else if constexpr (vtype::vec_type == simd_type::AVX2){ - return partition_vec_avx2(arg, left, right, arg_vec, curr_vec, pivot_vec, smallest_vec, biggest_vec); - }else{ - static_assert(sizeof(argreg_t) == 0, "Should not get here"); - } -} - -/* - * Parition an array based on the pivot and returns the index of the - * last element that is less than equal to the pivot. 
- */ -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - /* make array length divisible by vtype::numlanes , shortening the array */ - for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - if (right - left == vtype::numlanes) { - argreg_t argvec = argtype::loadu(arg + left); - reg_t vec = vtype::i64gather(arr, arg + left); - int32_t amount_gt_pivot - = partition_vec(arg, - left, - left + vtype::numlanes, - argvec, - vec, - pivot_vec, - &min_vec, - &max_vec); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return left + (vtype::numlanes - amount_gt_pivot); - } - - // first and last vtype::numlanes values are partitioned at the end - argreg_t argvec_left = argtype::loadu(arg + left); - reg_t vec_left = vtype::i64gather(arr, arg + left); - argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); - reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += vtype::numlanes; - right -= vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec; - reg_t curr_vec; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= vtype::numlanes; - arg_vec = argtype::loadu(arg + right); - curr_vec = vtype::i64gather(arr, arg + right); - } - else { - arg_vec = argtype::loadu(arg + left); - curr_vec = vtype::i64gather(arr, arg + left); - left += vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec, - curr_vec, - pivot_vec, - &min_vec, - &max_vec); - ; - r_store -= amount_gt_pivot; - l_store += (vtype::numlanes - amount_gt_pivot); - } - - /* partition and save vec_left and vec_right */ - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left, - vec_left, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - amount_gt_pivot = partition_vec(arg, - l_store, - l_store + vtype::numlanes, - argvec_right, - vec_right, - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - -template -X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - type_t pivot, - type_t *smallest, - type_t *biggest) -{ - if (right - left <= 8 * num_unroll * vtype::numlanes) { - return partition_avx512( - arr, arg, left, right, pivot, smallest, biggest); - } - /* make array length divisible by vtype::numlanes , 
shortening the array */ - for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; - --i) { - *smallest = std::min(*smallest, arr[arg[left]], comparison_func); - *biggest = std::max(*biggest, arr[arg[left]], comparison_func); - if (!comparison_func(arr[arg[left]], pivot)) { - std::swap(arg[left], arg[--right]); - } - else { - ++left; - } - } - - if (left == right) - return left; /* less than vtype::numlanes elements in the array */ - - using reg_t = typename vtype::reg_t; - using argreg_t = typename argtype::reg_t; - reg_t pivot_vec = vtype::set1(pivot); - reg_t min_vec = vtype::set1(*smallest); - reg_t max_vec = vtype::set1(*biggest); - - // first and last vtype::numlanes values are partitioned at the end - reg_t vec_left[num_unroll], vec_right[num_unroll]; - argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); - vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); - argvec_right[ii] = argtype::loadu( - arg + (right - vtype::numlanes * (num_unroll - ii))); - vec_right[ii] = vtype::i64gather( - arr, arg + (right - vtype::numlanes * (num_unroll - ii))); - } - // store points of the vectors - arrsize_t r_store = right - vtype::numlanes; - arrsize_t l_store = left; - // indices for loading the elements - left += num_unroll * vtype::numlanes; - right -= num_unroll * vtype::numlanes; - while (right - left != 0) { - argreg_t arg_vec[num_unroll]; - reg_t curr_vec[num_unroll]; - /* - * if fewer elements are stored on the right side of the array, - * then next elements are loaded from the right side, - * otherwise from the left side - */ - if ((r_store + vtype::numlanes) - right < left - l_store) { - right -= num_unroll * vtype::numlanes; - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] - = argtype::loadu(arg + right + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + right + ii * vtype::numlanes); - } - } - else { - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); - curr_vec[ii] = vtype::i64gather( - arr, arg + left + ii * vtype::numlanes); - } - left += num_unroll * vtype::numlanes; - } - // partition the current vector and save it on both sides of the array - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - arg_vec[ii], - curr_vec[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - } - - /* partition and save vec_left and vec_right */ - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_left[ii], - vec_left[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - X86_SIMD_SORT_UNROLL_LOOP(8) - for (int ii = 0; ii < num_unroll; ++ii) { - int32_t amount_gt_pivot - = partition_vec(arg, - l_store, - r_store + vtype::numlanes, - argvec_right[ii], - vec_right[ii], - pivot_vec, - &min_vec, - &max_vec); - l_store += (vtype::numlanes - amount_gt_pivot); - r_store -= amount_gt_pivot; - } - *smallest = vtype::reducemin(min_vec); - *biggest = vtype::reducemax(max_vec); - return l_store; -} - 
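The partition helpers deleted above (re-homed by a later patch into src/xss-common-argsort.h) all reduce to one compress-store step per vector: gather the keys for a block of indices, compare them against the pivot, then scatter the matching indices to the two ends of the index buffer. The following stand-alone sketch shows that single step with raw AVX-512 intrinsics; the buffer layout, values, and compile flags are illustrative only, not the library's templated partition_vec_avx512.

```cpp
// One partition step of an argsort: indices of keys < pivot are packed to the
// left end of idx[], indices of keys >= pivot to the right end.
// Illustrative only; requires AVX-512F + POPCNT (e.g. g++ -O2 -mavx512f -mpopcnt).
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main()
{
    int64_t keys[8] = {7, 1, 9, 3, 8, 2, 6, 4};
    int64_t idx[8];
    const int64_t pivot = 5;

    __m512i key_vec   = _mm512_loadu_si512(keys);
    __m512i arg_vec   = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); // lane i holds index i
    __m512i pivot_vec = _mm512_set1_epi64(pivot);

    // lanes whose key is >= pivot belong on the right side
    __mmask8 ge_mask = _mm512_cmp_epi64_mask(key_vec, pivot_vec, _MM_CMPINT_NLT);
    int amount_ge    = _mm_popcnt_u32((unsigned)ge_mask);

    // compress-store the index vector twice, once with each half of the mask
    _mm512_mask_compressstoreu_epi64(idx, (__mmask8)~ge_mask, arg_vec);
    _mm512_mask_compressstoreu_epi64(idx + 8 - amount_ge, ge_mask, arg_vec);

    for (int i = 0; i < 8; i++)
        printf("idx %lld -> key %lld\n", (long long)idx[i], (long long)keys[idx[i]]);
    return 0;
}
```

The library version additionally keeps running min/max vectors of the keys it has seen, so the recursion can terminate early when a partition turns out to contain a single value.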
-template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - arrsize_t *arg, - const arrsize_t left, - const arrsize_t right) -{ - if constexpr (vtype::numlanes == 8){ - if (right - left >= vtype::numlanes) { - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]], - arr[arg[left + 5 * size]], - arr[arg[left + 6 * size]], - arr[arg[left + 7 * size]], - arr[arg[left + 8 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = sort_zmm_64bit(rand_vec); - return ((type_t *)&sort)[4]; - } - else { - return arr[arg[right]]; - } - }else if constexpr (vtype::numlanes == 4){ - if (right - left >= vtype::numlanes) { - // median of 4 - arrsize_t size = (right - left) / 4; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[arg[left + size]], - arr[arg[left + 2 * size]], - arr[arg[left + 3 * size]], - arr[arg[left + 4 * size]]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[2]; - } - else { - return arr[arg[right]]; - } - } -} - -template -X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - using argtype = typename std::conditional, - zmm_vector>::type; - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if (pivot != smallest) - argsort_64bit_( - arr, arg, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) - argsort_64bit_( - arr, arg, pivot_index, right, max_iters - 1); -} - -template -X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, - arrsize_t *arg, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - arrsize_t max_iters) -{ - using argtype = typename std::conditional, - zmm_vector>::type; - /* - * Resort to std::sort if quicksort isnt making any progress - */ - if (max_iters <= 0) { - std_argsort(arr, arg, left, right + 1); - return; - } - /* - * Base case: use bitonic networks to sort arrays <= 64 - */ - if (right + 1 - left <= 256) { - argsort_n( - arr, arg + left, (int32_t)(right + 1 - left)); - return; - } - type_t pivot = get_pivot_64bit(arr, arg, left, right); - type_t smallest = vtype::type_max(); - type_t biggest = vtype::type_min(); - arrsize_t pivot_index = partition_avx512_unrolled( - arr, arg, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( - arr, arg, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( - arr, arg, pos, pivot_index, right, max_iters - 1); -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ - using vectype = typename 
std::conditional, - zmm_vector>::type; - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ - using argtype = - typename std::conditional, - zmm_vector>::type; -#else - using argtype = - typename std::conditional, - zmm_vector>::type; -#endif - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argsort methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void -avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) -{ - using vectype = typename std::conditional, - avx2_vector>::type; - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argsort_withnan(arr, arg, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argsort_64bit_( - arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx2_argsort(arr, indices.data(), arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ - using vectype = typename std::conditional, - zmm_vector>::type; - -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ - using argtype = - typename std::conditional, - zmm_vector>::type; -#else - using argtype = - typename std::conditional, - zmm_vector>::type; -#endif - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx512_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* argselect methods for 32-bit and 64-bit dtypes */ -template -X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, - arrsize_t *arg, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - using vectype = typename std::conditional, - avx2_vector>::type; - - if (arrsize > 1) { - if constexpr (std::is_floating_point_v) { - if ((hasnan) && (array_has_nan(arr, arrsize))) { - std_argselect_withnan(arr, arg, k, 0, arrsize); - return; - } - } - UNUSED(hasnan); - argselect_64bit_( - arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } -} - -template -X86_SIMD_SORT_INLINE std::vector -avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) 
-{ - std::vector indices(arrsize); - std::iota(indices.begin(), indices.end(), 0); - avx2_argselect(arr, indices.data(), k, arrsize, hasnan); - return indices; -} - -/* To maintain compatibility with NumPy build */ -template -X86_SIMD_SORT_INLINE void -avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) -{ - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); -} - -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) -{ - avx512_argsort(arr, reinterpret_cast(arg), arrsize); -} +#include "avx512-64bit-common.h" +#include "xss-common-argsort.h" #endif // AVX512_ARGSORT_64BIT diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h new file mode 100644 index 00000000..67aa2002 --- /dev/null +++ b/src/xss-common-argsort.h @@ -0,0 +1,705 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef XSS_COMMON_ARGSORT +#define XSS_COMMON_ARGSORT + +#include "xss-common-qsort.h" +#include "xss-network-keyvaluesort.hpp" +#include + +template +X86_SIMD_SORT_INLINE void std_argselect_withnan( + T *arr, arrsize_t *arg, arrsize_t k, arrsize_t left, arrsize_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](arrsize_t a, arrsize_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort_withnan(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); +} + +/* argsort using std::sort */ +template +X86_SIMD_SORT_INLINE void +std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) +{ + std::sort(arg + left, + arg + right, + [arr](arrsize_t left, arrsize_t right) -> bool { + // sort indices according to corresponding array element + return arr[left] < arr[right]; + }); +} + +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ +using argtypeAVX512 = + typename std::conditional, + zmm_vector>::type; +#else +using argtypeAVX512 = + typename std::conditional, + zmm_vector>::type; +#endif + +/* + * Parition one ZMM register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx512(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec); + int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask); + argtype::mask_compressstoreu( + arg + left, vtype::knot_opmask(gt_mask), arg_vec); + argtype::mask_compressstoreu( + arg + right - amount_gt_pivot, gt_mask, arg_vec); + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + return amount_gt_pivot; +} + +/* + * Parition one AVX2 register based on the pivot and returns the index of the + * last element that is less than equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + /* which elements are larger than the pivot */ + typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); + typename argtype::opmask_t ge_mask + = extend_mask(ge_mask_vtype); + + auto l_store = arg + left; + auto r_store = arg + right - vtype::numlanes; + + int amount_ge_pivot + = argtype::double_compressstore(l_store, r_store, ge_mask, arg_vec); + + *smallest_vec = vtype::min(curr_vec, *smallest_vec); + *biggest_vec = vtype::max(curr_vec, *biggest_vec); + + return amount_ge_pivot; +} + +template +X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg, + arrsize_t left, + arrsize_t right, + const argreg_t arg_vec, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t *smallest_vec, + reg_t *biggest_vec) +{ + if constexpr (vtype::vec_type == simd_type::AVX512) { + return partition_vec_avx512(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else if constexpr (vtype::vec_type == simd_type::AVX2) { + return partition_vec_avx2(arg, + left, + right, + arg_vec, + curr_vec, + pivot_vec, + smallest_vec, + biggest_vec); + } + else { + static_assert(sizeof(argreg_t) == 0, "Should not get here"); + } +} + +/* + * Parition an array based on the pivot and returns the index of the + * last element that is less than equal to the pivot. 
+ */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + if (right - left == vtype::numlanes) { + argreg_t argvec = argtype::loadu(arg + left); + reg_t vec = vtype::i64gather(arr, arg + left); + int32_t amount_gt_pivot + = partition_vec(arg, + left, + left + vtype::numlanes, + argvec, + vec, + pivot_vec, + &min_vec, + &max_vec); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return left + (vtype::numlanes - amount_gt_pivot); + } + + // first and last vtype::numlanes values are partitioned at the end + argreg_t argvec_left = argtype::loadu(arg + left); + reg_t vec_left = vtype::i64gather(arr, arg + left); + argreg_t argvec_right = argtype::loadu(arg + (right - vtype::numlanes)); + reg_t vec_right = vtype::i64gather(arr, arg + (right - vtype::numlanes)); + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += vtype::numlanes; + right -= vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec; + reg_t curr_vec; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= vtype::numlanes; + arg_vec = argtype::loadu(arg + right); + curr_vec = vtype::i64gather(arr, arg + right); + } + else { + arg_vec = argtype::loadu(arg + left); + curr_vec = vtype::i64gather(arr, arg + left); + left += vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec, + curr_vec, + pivot_vec, + &min_vec, + &max_vec); + ; + r_store -= amount_gt_pivot; + l_store += (vtype::numlanes - amount_gt_pivot); + } + + /* partition and save vec_left and vec_right */ + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left, + vec_left, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + amount_gt_pivot = partition_vec(arg, + l_store, + l_store + vtype::numlanes, + argvec_right, + vec_right, + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 8 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, arg, left, right, pivot, 
smallest, biggest); + } + /* make array length divisible by vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[arg[left]], comparison_func); + *biggest = std::max(*biggest, arr[arg[left]], comparison_func); + if (!comparison_func(arr[arg[left]], pivot)) { + std::swap(arg[left], arg[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using reg_t = typename vtype::reg_t; + using argreg_t = typename argtype::reg_t; + reg_t pivot_vec = vtype::set1(pivot); + reg_t min_vec = vtype::set1(*smallest); + reg_t max_vec = vtype::set1(*biggest); + + // first and last vtype::numlanes values are partitioned at the end + reg_t vec_left[num_unroll], vec_right[num_unroll]; + argreg_t argvec_left[num_unroll], argvec_right[num_unroll]; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + argvec_left[ii] = argtype::loadu(arg + left + vtype::numlanes * ii); + vec_left[ii] = vtype::i64gather(arr, arg + left + vtype::numlanes * ii); + argvec_right[ii] = argtype::loadu( + arg + (right - vtype::numlanes * (num_unroll - ii))); + vec_right[ii] = vtype::i64gather( + arr, arg + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + arrsize_t r_store = right - vtype::numlanes; + arrsize_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + argreg_t arg_vec[num_unroll]; + reg_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] + = argtype::loadu(arg + right + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + right + ii * vtype::numlanes); + } + } + else { + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + arg_vec[ii] = argtype::loadu(arg + left + ii * vtype::numlanes); + curr_vec[ii] = vtype::i64gather( + arr, arg + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } + // partition the current vector and save it on both sides of the array + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + arg_vec[ii], + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + } + + /* partition and save vec_left and vec_right */ + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_left[ii], + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + X86_SIMD_SORT_UNROLL_LOOP(8) + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_gt_pivot + = partition_vec(arg, + l_store, + r_store + vtype::numlanes, + argvec_right[ii], + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_gt_pivot); + r_store -= amount_gt_pivot; + } + *smallest = 
vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + arrsize_t *arg, + const arrsize_t left, + const arrsize_t right) +{ + if constexpr (vtype::numlanes == 8) { + if (right - left >= vtype::numlanes) { + // median of 8 + arrsize_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]], + arr[arg[left + 5 * size]], + arr[arg[left + 6 * size]], + arr[arg[left + 7 * size]], + arr[arg[left + 8 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; + } + else { + return arr[arg[right]]; + } + } + else if constexpr (vtype::numlanes == 4) { + if (right - left >= vtype::numlanes) { + // median of 4 + arrsize_t size = (right - left) / 4; + using reg_t = typename vtype::reg_t; + reg_t rand_vec = vtype::set(arr[arg[left + size]], + arr[arg[left + 2 * size]], + arr[arg[left + 3 * size]], + arr[arg[left + 4 * size]]); + // pivot will never be a nan, since there are no nan's! + reg_t sort = vtype::sort_vec(rand_vec); + return ((type_t *)&sort)[2]; + } + else { + return arr[arg[right]]; + } + } +} + +template +X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + using argtype = typename std::conditional, + zmm_vector>::type; + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if (pivot != smallest) + argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + if (pivot != biggest) + argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); +} + +template +X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, + arrsize_t *arg, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + arrsize_t max_iters) +{ + using argtype = typename std::conditional, + zmm_vector>::type; + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 256) { + argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + arrsize_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_64bit_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_64bit_( + arr, arg, pos, pivot_index, right, max_iters - 1); +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using 
vectype = typename std::conditional, + zmm_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + +/* argsort methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void +avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argsort_withnan(arr, arg, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argsort_64bit_( + arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argsort(arr, indices.data(), arrsize, hasnan); + return indices; +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + zmm_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx512_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + +/* argselect methods for 32-bit and 64-bit dtypes */ +template +X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, + arrsize_t *arg, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false) +{ + using vectype = typename std::conditional, + avx2_vector>::type; + + if (arrsize > 1) { + if constexpr (std::is_floating_point_v) { + if ((hasnan) && (array_has_nan(arr, arrsize))) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + return; + } + } + UNUSED(hasnan); + argselect_64bit_( + arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } +} + +template +X86_SIMD_SORT_INLINE std::vector +avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx2_argselect(arr, indices.data(), k, arrsize, hasnan); + return indices; +} + +/* To maintain compatibility with NumPy build */ +template +X86_SIMD_SORT_INLINE void +avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) +{ + avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); +} + +template +X86_SIMD_SORT_INLINE void +avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) +{ + avx512_argsort(arr, reinterpret_cast(arg), arrsize); +} + +#endif // XSS_COMMON_ARGSORT From 01bae64898bab7f804190c42ffc0e9e4b8c84e50 Mon Sep 
17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 3 Jan 2024 16:10:58 -0800 Subject: [PATCH 05/10] Fixes needed when rebasing --- src/xss-common-argsort.h | 77 ++++++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 18 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 67aa2002..68aaa550 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -482,16 +482,13 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, } } -template +template X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t *arg, arrsize_t left, arrsize_t right, arrsize_t max_iters) { - using argtype = typename std::conditional, - zmm_vector>::type; /* * Resort to std::sort if quicksort isnt making any progress */ @@ -503,7 +500,8 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -512,12 +510,14 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr, arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) - argsort_64bit_(arr, arg, left, pivot_index - 1, max_iters - 1); + argsort_64bit_( + arr, arg, left, pivot_index - 1, max_iters - 1); if (pivot != biggest) - argsort_64bit_(arr, arg, pivot_index, right, max_iters - 1); + argsort_64bit_( + arr, arg, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t *arg, arrsize_t pos, @@ -525,9 +525,6 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t right, arrsize_t max_iters) { - using argtype = typename std::conditional, - zmm_vector>::type; /* * Resort to std::sort if quicksort isnt making any progress */ @@ -539,7 +536,8 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, * Base case: use bitonic networks to sort arrays <= 64 */ if (right + 1 - left <= 256) { - argsort_n(arr, arg + left, (int32_t)(right + 1 - left)); + argsort_n( + arr, arg + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, arg, left, right); @@ -548,10 +546,10 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, arrsize_t pivot_index = partition_avx512_unrolled( arr, arg, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != biggest) && (pos >= pivot_index)) - argselect_64bit_( + argselect_64bit_( arr, arg, pos, pivot_index, right, max_iters - 1); } @@ -560,9 +558,25 @@ template X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; + +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, + zmm_vector>::type; +#endif + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -571,7 +585,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan 
= false) } } UNUSED(hasnan); - argsort_64bit_( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -594,6 +608,13 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) using vectype = typename std::conditional, avx2_vector>::type; + + using argtype = + typename std::conditional, + avx2_vector>::type; + + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -602,7 +623,7 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) } } UNUSED(hasnan); - argsort_64bit_( + argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -625,10 +646,25 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, arrsize_t arrsize, bool hasnan = false) { + /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional, zmm_vector>::type; +/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of + * undefined template 'zmm_vector'*/ +#ifdef __APPLE__ + using argtype = + typename std::conditional, + zmm_vector>::type; +#else + using argtype = + typename std::conditional, + zmm_vector>::type; +#endif + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -637,7 +673,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, } } UNUSED(hasnan); - argselect_64bit_( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } @@ -664,6 +700,11 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, avx2_half_vector, avx2_vector>::type; + using argtype = + typename std::conditional, + avx2_vector>::type; + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { @@ -672,7 +713,7 @@ X86_SIMD_SORT_INLINE void avx2_argselect(T *arr, } } UNUSED(hasnan); - argselect_64bit_( + argselect_64bit_( arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } From 7a83e523ed1cf19101c8b90ffb9ff4dc80c9d6a5 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 3 Jan 2024 16:20:30 -0800 Subject: [PATCH 06/10] clang-format --- src/avx2-32bit-half.hpp | 92 ++++++++++++-------------------- src/avx2-64bit-qsort.hpp | 43 ++++++--------- src/avx2-emu-funcs.hpp | 28 ++++++---- src/xss-common-argsort.h | 1 - src/xss-common-includes.h | 4 +- src/xss-network-keyvaluesort.hpp | 88 +++++++++++++++--------------- 6 files changed, 112 insertions(+), 144 deletions(-) diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp index f6cbd3d0..5a6ee5b5 100644 --- a/src/avx2-32bit-half.hpp +++ b/src/avx2-32bit-half.hpp @@ -30,24 +30,13 @@ template X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm) { using swizzle = typename vtype::swizzle_ops; - - const typename vtype::opmask_t oxAA - = vtype::seti(-1, 0, -1, 0); - const typename vtype::opmask_t oxCC - = vtype::seti(-1, -1, 0, 0); - - ymm = cmp_merge( - ymm, - swizzle::template swap_n(ymm), - oxAA); - ymm = cmp_merge( - ymm, - vtype::reverse(ymm), - oxCC); - ymm = cmp_merge( - ymm, - swizzle::template swap_n(ymm), - oxAA); + + const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0); + const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0); + + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::reverse(ymm), oxCC); + ymm = cmp_merge(ymm, swizzle::template swap_n(ymm), oxAA); return ymm; } @@ -61,7 +50,7 @@ struct avx2_half_vector { using opmask_t = __m128i; static const uint8_t 
numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; - + using swizzle_ops = avx2_32bit_half_swizzle_ops; static type_t type_max() @@ -81,13 +70,11 @@ struct avx2_half_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_half(mask); } - static ymmi_t - seti(int v1, int v2, int v3, int v4) + static ymmi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } - static reg_t - set(int v1, int v2, int v3, int v4) + static reg_t set(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } @@ -99,8 +86,8 @@ struct avx2_half_vector { { opmask_t equal = eq(x, y); opmask_t greater = _mm_cmpgt_epi32(x, y); - return _mm_castps_si128(_mm_or_ps(_mm_castsi128_ps(equal), - _mm_castsi128_ps(greater))); + return _mm_castps_si128( + _mm_or_ps(_mm_castsi128_ps(equal), _mm_castsi128_ps(greater))); } static opmask_t eq(reg_t x, reg_t y) { @@ -110,14 +97,12 @@ struct avx2_half_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi32(src, (const int *) base, index, mask, scale); + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -143,8 +128,8 @@ struct avx2_half_vector { static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), - _mm_castsi128_ps(y), - _mm_castsi128_ps(mask))); + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); } static void mask_storeu(void *mem, opmask_t mask, reg_t x) { @@ -217,7 +202,7 @@ struct avx2_half_vector { using opmask_t = __m128i; static const uint8_t numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; - + using swizzle_ops = avx2_32bit_half_swizzle_ops; static type_t type_max() @@ -237,13 +222,11 @@ struct avx2_half_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_half(mask); } - static ymmi_t - seti(int v1, int v2, int v3, int v4) + static ymmi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } - static reg_t - set(int v1, int v2, int v3, int v4) + static reg_t set(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } @@ -251,14 +234,12 @@ struct avx2_half_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi32(src, (const int *) base, index, mask, scale); + return _mm256_mask_i64gather_epi32( + src, (const int *)base, index, mask, scale); } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static opmask_t ge(reg_t x, reg_t y) { @@ -289,8 +270,8 @@ struct avx2_half_vector { static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) { return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(x), - _mm_castsi128_ps(y), - _mm_castsi128_ps(mask))); + _mm_castsi128_ps(y), + _mm_castsi128_ps(mask))); } static void mask_storeu(void *mem, opmask_t mask, reg_t x) { @@ -363,7 +344,7 @@ struct avx2_half_vector { using opmask_t = __m128i; static const uint8_t numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; - + using swizzle_ops = avx2_32bit_half_swizzle_ops; 
static type_t type_max() @@ -379,13 +360,11 @@ struct avx2_half_vector { return _mm_set1_ps(type_max()); } - static ymmi_t - seti(int v1, int v2, int v3, int v4) + static ymmi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } - static reg_t - set(float v1, float v2, float v3, float v4) + static reg_t set(float v1, float v2, float v3, float v4) { return _mm_set_ps(v1, v2, v3, v4); } @@ -424,14 +403,12 @@ struct avx2_half_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_ps(src, (const float*) base, index, _mm_castsi128_ps(mask), scale); + return _mm256_mask_i64gather_ps( + src, (const float *)base, index, _mm_castsi128_ps(mask), scale); } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -490,8 +467,7 @@ struct avx2_half_vector { template static reg_t shuffle(reg_t ymm) { - return _mm_castsi128_ps( - _mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); + return _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(ymm), mask)); } static void storeu(void *mem, reg_t x) { @@ -566,9 +542,7 @@ struct avx2_32bit_half_swizzle_ops { __m128i v1 = vtype::cast_to(reg); __m128i v2 = vtype::cast_to(other); - if constexpr (scale == 2) { - v1 = _mm_blend_epi32(v1, v2, 0b0101); - } + if constexpr (scale == 2) { v1 = _mm_blend_epi32(v1, v2, 0b0101); } else if constexpr (scale == 4) { v1 = _mm_blend_epi32(v1, v2, 0b0011); } diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 64d71363..709d98ef 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -77,10 +77,7 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4) + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) { return _mm256_set_epi64x(v1, v2, v3, v4); } @@ -106,14 +103,12 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, (const long long int *) base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { @@ -246,10 +241,7 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, - type_t v2, - type_t v3, - type_t v4) + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) { return _mm256_set_epi64x(v1, v2, v3, v4); } @@ -257,14 +249,12 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_epi64(src, (const long long int *) base, index, mask, scale); + return _mm256_mask_i64gather_epi64( + src, (const long long int *)base, index, mask, scale); } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static opmask_t gt(reg_t x, reg_t y) { @@ -427,10 +417,7 @@ struct avx2_vector { { return _mm256_set_epi64x(v1, v2, v3, v4); } - static reg_t set(type_t v1, - type_t v2, - type_t 
v3, - type_t v4) + static reg_t set(type_t v1, type_t v2, type_t v3, type_t v4) { return _mm256_set_pd(v1, v2, v3, v4); } @@ -450,16 +437,16 @@ struct avx2_vector { static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) { - return _mm256_mask_i64gather_pd( - src, (const type_t *) base, index, _mm256_castsi256_pd(mask), scale); + return _mm256_mask_i64gather_pd(src, + (const type_t *)base, + index, + _mm256_castsi256_pd(mask), + scale); ; } static reg_t i64gather(type_t *arr, arrsize_t *ind) { - return set(arr[ind[3]], - arr[ind[2]], - arr[ind[1]], - arr[ind[0]]); + return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); } static reg_t loadu(void const *mem) { diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 2336439c..6e40d2a6 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -107,8 +107,10 @@ constexpr auto avx2_compressstore_lut32_half_gen = [] { return lutPair; }(); -constexpr auto avx2_compressstore_lut32_half_perm = avx2_compressstore_lut32_half_gen[0]; -constexpr auto avx2_compressstore_lut32_half_left = avx2_compressstore_lut32_half_gen[1]; +constexpr auto avx2_compressstore_lut32_half_perm + = avx2_compressstore_lut32_half_gen[0]; +constexpr auto avx2_compressstore_lut32_half_left + = avx2_compressstore_lut32_half_gen[1]; constexpr auto avx2_compressstore_lut64_gen = [] { std::array, 16> permLut {}; @@ -281,9 +283,10 @@ void avx2_emu_mask_compressstoreu32(void *base_addr, } template -void avx2_emu_mask_compressstoreu32_half(void *base_addr, - typename avx2_half_vector::opmask_t k, - typename avx2_half_vector::reg_t reg) +void avx2_emu_mask_compressstoreu32_half( + void *base_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) { using vtype = avx2_half_vector; @@ -291,9 +294,11 @@ void avx2_emu_mask_compressstoreu32_half(void *base_addr, int32_t shortMask = convert_avx2_mask_to_int_half(k); const __m128i &perm = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask].data()); + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); const __m128i &left = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_left[shortMask].data()); + (const __m128i *)avx2_compressstore_lut32_half_left[shortMask] + .data()); typename vtype::reg_t temp = vtype::permutevar(reg, perm); @@ -346,9 +351,9 @@ int avx2_double_compressstore32(void *left_addr, template int avx2_double_compressstore32_half(void *left_addr, - void *right_addr, - typename avx2_half_vector::opmask_t k, - typename avx2_half_vector::reg_t reg) + void *right_addr, + typename avx2_half_vector::opmask_t k, + typename avx2_half_vector::reg_t reg) { using vtype = avx2_half_vector; @@ -357,7 +362,8 @@ int avx2_double_compressstore32_half(void *left_addr, int32_t shortMask = convert_avx2_mask_to_int_half(k); const __m128i &perm = _mm_loadu_si128( - (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask].data()); + (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] + .data()); typename vtype::reg_t temp = vtype::permutevar(reg, perm); diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 68aaa550..c72cbc72 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -614,7 +614,6 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) avx2_half_vector, avx2_vector>::type; - if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { diff --git 
a/src/xss-common-includes.h b/src/xss-common-includes.h index addd369a..9f793e37 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -85,8 +85,6 @@ struct avx2_vector; template struct avx2_half_vector; -enum class simd_type:int{ - AVX2, AVX512 -}; +enum class simd_type : int { AVX2, AVX512 }; #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 4bddb692..796a200e 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -20,17 +20,18 @@ struct index_64bit_vector_type<4> { }; template -typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask){ - if constexpr (keyType::vec_type == simd_type::AVX512){ - return mask; - }else if constexpr (keyType::vec_type == simd_type::AVX2){ - if constexpr (sizeof(mask) == 32){ - return mask; - }else{ +typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) +{ + if constexpr (keyType::vec_type == simd_type::AVX512) { return mask; } + else if constexpr (keyType::vec_type == simd_type::AVX2) { + if constexpr (sizeof(mask) == 32) { return mask; } + else { return _mm256_cvtepi32_epi64(mask); } - }else{ - static_assert(keyType::vec_type == simd_type::AVX512, "Should not reach here"); + } + else { + static_assert(keyType::vec_type == simd_type::AVX512, + "Should not reach here"); } } >>>>>>> d1e90bb (Support for AVX2 argsort/argselect) @@ -44,13 +45,11 @@ COEX(reg_t1 &key1, reg_t1 &key2, reg_t2 &index1, reg_t2 &index2) { reg_t1 key_t1 = vtype1::min(key1, key2); reg_t1 key_t2 = vtype1::max(key1, key2); - + auto eqMask = extend_mask(vtype1::eq(key_t1, key1)); - reg_t2 index_t1 - = vtype2::mask_mov(index2, eqMask, index1); - reg_t2 index_t2 - = vtype2::mask_mov(index1, eqMask, index2); + reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1); + reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2); key1 = key_t1; key2 = key_t2; @@ -70,7 +69,10 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, opmask_t mask) { reg_t1 tmp_keys = cmp_merge(in1, in2, mask); - indexes1 = vtype2::mask_mov(indexes2, extend_mask(vtype1::eq(tmp_keys, in1)), indexes1); + indexes1 = vtype2::mask_mov( + indexes2, + extend_mask(vtype1::eq(tmp_keys, in1)), + indexes1); return tmp_keys; // 0 -> min, 1 -> max } @@ -245,24 +247,21 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t key_zmm, index_type &index_zmm) { using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; - - const typename vtype1::opmask_t oxAA - = vtype1::seti(-1, 0, -1, 0); - const typename vtype1::opmask_t oxCC - = vtype1::seti(-1, -1, 0, 0); - + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + key_zmm = cmp_merge( key_zmm, key_swizzle::template swap_n(key_zmm), index_zmm, index_swizzle::template swap_n(index_zmm), oxAA); - key_zmm = cmp_merge( - key_zmm, - vtype1::reverse(key_zmm), - index_zmm, - vtype2::reverse(index_zmm), - oxCC); + key_zmm = cmp_merge(key_zmm, + vtype1::reverse(key_zmm), + index_zmm, + vtype2::reverse(index_zmm), + oxCC); key_zmm = cmp_merge( key_zmm, key_swizzle::template swap_n(key_zmm), @@ -314,12 +313,10 @@ X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_64bit(reg_t key_zmm, { using key_swizzle = typename vtype1::swizzle_ops; using index_swizzle = typename vtype2::swizzle_ops; - - const typename vtype1::opmask_t oxAA - = vtype1::seti(-1, 0, -1, 0); - const typename vtype1::opmask_t oxCC - = vtype1::seti(-1, -1, 0, 0); 
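The reformatted COEX above is the core of the key-value bitonic network: keys are ordered with min/max, and each index is carried along with its key by selecting on eq(min, key1). The scalar sketch below models the per-lane effect on one key/index pair; it is not the SIMD template itself, only an aid for reading the mask plumbing.

```cpp
// Scalar model of one COEX (compare-exchange) on a key/index pair: after the
// call, key1 <= key2 and each index still refers to its original key.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static void coex_scalar(int64_t &key1, int64_t &key2, uint64_t &idx1, uint64_t &idx2)
{
    int64_t lo = std::min(key1, key2);
    int64_t hi = std::max(key1, key2);
    bool key1_is_lo = (lo == key1);             // plays the role of the eq-mask
    uint64_t idx_lo = key1_is_lo ? idx1 : idx2; // mask_mov(index2, eqMask, index1)
    uint64_t idx_hi = key1_is_lo ? idx2 : idx1; // mask_mov(index1, eqMask, index2)
    key1 = lo;
    key2 = hi;
    idx1 = idx_lo;
    idx2 = idx_hi;
}

int main()
{
    int64_t k1 = 9, k2 = 4;
    uint64_t i1 = 0, i2 = 1;
    coex_scalar(k1, k2, i1, i2);
    printf("keys: %lld %lld  indices: %llu %llu\n",
           (long long)k1, (long long)k2,
           (unsigned long long)i1, (unsigned long long)i2);
    return 0;
}
```

The SIMD version applies this to every lane at once, which is why extend_mask is needed whenever the key mask and the index mask have different lane widths.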
- + + const typename vtype1::opmask_t oxAA = vtype1::seti(-1, 0, -1, 0); + const typename vtype1::opmask_t oxCC = vtype1::seti(-1, -1, 0, 0); + // 2) half_cleaner[4] key_zmm = cmp_merge( key_zmm, @@ -351,7 +348,8 @@ bitonic_merge_dispatch(typename keyType::reg_t &key, key = bitonic_merge_reg_16lanes(key, value); ======= key = bitonic_merge_zmm_64bit(key, value); - }else if constexpr (numlanes == 4){ + } + else if constexpr (numlanes == 4) { key = bitonic_merge_ymm_64bit(key, value); >>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } @@ -375,7 +373,8 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, key = sort_reg_16lanes(key, value); ======= key = sort_zmm_64bit(key, value); - }else if constexpr (numlanes == 4){ + } + else if constexpr (numlanes == 4) { key = sort_ymm_64bit(key, value); >>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } @@ -487,7 +486,7 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, kreg_t keyVecs[numVecs]; ireg_t indexVecs[numVecs]; - + // Generate masks for loading and storing typename keyType::opmask_t ioMasks[numVecs - numVecs / 2]; X86_SIMD_SORT_UNROLL_LOOP(64) @@ -508,13 +507,16 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2; i < numVecs; i++) { - indexVecs[i] = indexType::mask_loadu(indexType::zmm_max(), - extend_mask(ioMasks[i - numVecs/2]), - indices + i * indexType::numlanes); + indexVecs[i] = indexType::mask_loadu( + indexType::zmm_max(), + extend_mask(ioMasks[i - numVecs / 2]), + indices + i * indexType::numlanes); keyVecs[i] = keyType::template mask_i64gather( - keyType::zmm_max(), ioMasks[i - numVecs / 2], indexVecs[i], keys); + typename keyType::type_t)>(keyType::zmm_max(), + ioMasks[i - numVecs / 2], + indexVecs[i], + keys); } // Sort each loaded vector @@ -535,7 +537,9 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { indexType::mask_storeu( - indices + i * indexType::numlanes, extend_mask(ioMasks[i - numVecs/2]), indexVecs[i]); + indices + i * indexType::numlanes, + extend_mask(ioMasks[i - numVecs / 2]), + indexVecs[i]); } } From 55b607769b6915d53b04e770e74b6ae139f00fa1 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 5 Jan 2024 12:30:27 -0800 Subject: [PATCH 07/10] Fixed problems on 32-bit systems --- src/avx2-32bit-half.hpp | 21 +++++++++++++++++++++ src/avx2-64bit-qsort.hpp | 25 +++++++++++++++++++++++++ src/xss-network-keyvaluesort.hpp | 17 +++++++++++++++-- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp index 5a6ee5b5..22f877ac 100644 --- a/src/avx2-32bit-half.hpp +++ b/src/avx2-32bit-half.hpp @@ -100,6 +100,13 @@ struct avx2_half_vector { return _mm256_mask_i64gather_epi32( src, (const int *)base, index, mask, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_epi32( + src, (const int *)base, index, mask, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); @@ -237,6 +244,13 @@ struct avx2_half_vector { return _mm256_mask_i64gather_epi32( src, (const int *)base, index, mask, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_epi32( + 
src, (const int *)base, index, mask, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); @@ -406,6 +420,13 @@ struct avx2_half_vector { return _mm256_mask_i64gather_ps( src, (const float *)base, index, _mm_castsi128_ps(mask), scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm_mask_i32gather_ps( + src, (const float *)base, index, _mm_castsi128_ps(mask), scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 709d98ef..764bdcf4 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -106,6 +106,13 @@ struct avx2_vector { return _mm256_mask_i64gather_epi64( src, (const long long int *)base, index, mask, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm256_mask_i32gather_epi64( + src, (const long long int *)base, index, mask, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); @@ -252,6 +259,13 @@ struct avx2_vector { return _mm256_mask_i64gather_epi64( src, (const long long int *)base, index, mask, scale); } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm256_mask_i32gather_epi64( + src, (const long long int *)base, index, mask, scale); + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); @@ -444,6 +458,17 @@ struct avx2_vector { scale); ; } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m128i index, void const *base) + { + return _mm256_mask_i32gather_pd(src, + (const type_t *)base, + index, + _mm256_castsi256_pd(mask), + scale); + ; + } static reg_t i64gather(type_t *arr, arrsize_t *ind) { return set(arr[ind[3]], arr[ind[2]], arr[ind[1]], arr[ind[0]]); diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 796a200e..37544148 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -22,11 +22,24 @@ struct index_64bit_vector_type<4> { template typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) { + using inT = typename keyType::opmask_t; + using outT = typename valueType::opmask_t; + if constexpr (keyType::vec_type == simd_type::AVX512) { return mask; } else if constexpr (keyType::vec_type == simd_type::AVX2) { - if constexpr (sizeof(mask) == 32) { return mask; } - else { + if constexpr (sizeof(inT) == sizeof(outT)) { return mask; } + else if constexpr (sizeof(inT) == 32 && sizeof(outT) == 16){ + // We need to convert a mask made of 64 bit integers to 32 bit integers + // This does this by taking advantage of the fact that the only bit that matters + // is the very topmost bit, which becomes the sign bit when cast to floating point + + // TODO try and figure out if there is a better way to do this + return _mm_castps_si128(_mm256_cvtpd_ps(_mm256_castsi256_pd(mask))); + } + else if constexpr (sizeof(inT) == 16 && sizeof(outT) == 32){ return _mm256_cvtepi32_epi64(mask); + }else{ + static_assert(sizeof(inT) == -1, "should not reach here"); } } else { From 7388ed7cf6139f3afc1758a51f55b510e437f357 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 9 Jan 2024 15:23:59 -0800 Subject: 
[PATCH 08/10] Some fixes needed when rebasing --- src/xss-common-argsort.h | 2 +- src/xss-network-keyvaluesort.hpp | 31 ++++++------------------------- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index c72cbc72..5482941b 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -456,7 +456,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, arr[arg[left + 7 * size]], arr[arg[left + 8 * size]]); // pivot will never be a nan, since there are no nan's! - reg_t sort = sort_zmm_64bit(rand_vec); + reg_t sort = vtype::sort_vec(rand_vec); return ((type_t *)&sort)[4]; } else { diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 37544148..b42fd87d 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -1,23 +1,13 @@ #ifndef XSS_KEYVALUE_NETWORKS #define XSS_KEYVALUE_NETWORKS -<<<<<<< HEAD -#include "avx512-32bit-qsort.hpp" -#include "avx512-64bit-qsort.hpp" -#include "avx2-64bit-qsort.hpp" -======= #include "xss-common-includes.h" -template -struct index_64bit_vector_type; -template <> -struct index_64bit_vector_type<8> { - using type = zmm_vector; -}; -template <> -struct index_64bit_vector_type<4> { - using type = avx2_vector; -}; +#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1 +#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 template typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) @@ -44,10 +34,9 @@ typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) } else { static_assert(keyType::vec_type == simd_type::AVX512, - "Should not reach here"); + "should not reach here"); } } ->>>>>>> d1e90bb (Support for AVX2 argsort/argselect) template (key, value); } else if constexpr (numlanes == 16) { key = bitonic_merge_reg_16lanes(key, value); -======= - key = bitonic_merge_zmm_64bit(key, value); } else if constexpr (numlanes == 4) { key = bitonic_merge_ymm_64bit(key, value); ->>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } else { static_assert(numlanes == -1, "No implementation"); @@ -379,17 +364,13 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, { constexpr int numlanes = keyType::numlanes; if constexpr (numlanes == 8) { -<<<<<<< HEAD key = sort_reg_8lanes(key, value); } else if constexpr (numlanes == 16) { key = sort_reg_16lanes(key, value); -======= - key = sort_zmm_64bit(key, value); } else if constexpr (numlanes == 4) { key = sort_ymm_64bit(key, value); ->>>>>>> d1e90bb (Support for AVX2 argsort/argselect) } else { static_assert(numlanes == -1, "No implementation"); From 1834edc8a6c517b863eadc2d00fcec889f140990 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Thu, 11 Jan 2024 15:32:52 -0800 Subject: [PATCH 09/10] Code review fixes --- src/avx2-32bit-half.hpp | 41 ++++++++------------------------ src/avx2-emu-funcs.hpp | 8 +++---- src/xss-common-argsort.h | 31 +----------------------- src/xss-network-keyvaluesort.hpp | 10 ++++---- 4 files changed, 20 insertions(+), 70 deletions(-) diff --git a/src/avx2-32bit-half.hpp b/src/avx2-32bit-half.hpp index 22f877ac..52697692 100644 --- a/src/avx2-32bit-half.hpp +++ b/src/avx2-32bit-half.hpp @@ -7,7 +7,7 @@ #ifndef AVX2_HALF_32BIT 
#define AVX2_HALF_32BIT -#include "xss-common-qsort.h" +#include "xss-common-includes.h" #include "avx2-emu-funcs.hpp" /* @@ -46,7 +46,7 @@ template <> struct avx2_half_vector { using type_t = int32_t; using reg_t = __m128i; - using ymmi_t = __m128i; + using regi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; @@ -70,7 +70,7 @@ struct avx2_half_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_half(mask); } - static ymmi_t seti(int v1, int v2, int v3, int v4) + static regi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } @@ -86,8 +86,7 @@ struct avx2_half_vector { { opmask_t equal = eq(x, y); opmask_t greater = _mm_cmpgt_epi32(x, y); - return _mm_castps_si128( - _mm_or_ps(_mm_castsi128_ps(equal), _mm_castsi128_ps(greater))); + return _mm_or_si128(equal, greater); } static opmask_t eq(reg_t x, reg_t y) { @@ -150,10 +149,6 @@ struct avx2_half_vector { { return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); } - static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } static reg_t reverse(reg_t ymm) { const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); @@ -205,7 +200,7 @@ template <> struct avx2_half_vector { using type_t = uint32_t; using reg_t = __m128i; - using ymmi_t = __m128i; + using regi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; @@ -229,7 +224,7 @@ struct avx2_half_vector { auto mask = ((0x1ull << num_to_read) - 0x1ull); return convert_int_to_avx2_mask_half(mask); } - static ymmi_t seti(int v1, int v2, int v3, int v4) + static regi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } @@ -299,10 +294,6 @@ struct avx2_half_vector { { return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); } - static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx)); - } static reg_t reverse(reg_t ymm) { const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); @@ -354,7 +345,7 @@ template <> struct avx2_half_vector { using type_t = float; using reg_t = __m128; - using ymmi_t = __m128i; + using regi_t = __m128i; using opmask_t = __m128i; static const uint8_t numlanes = 4; static constexpr simd_type vec_type = simd_type::AVX2; @@ -374,7 +365,7 @@ struct avx2_half_vector { return _mm_set1_ps(type_max()); } - static ymmi_t seti(int v1, int v2, int v3, int v4) + static regi_t seti(int v1, int v2, int v3, int v4) { return _mm_set_epi32(v1, v2, v3, v4); } @@ -464,10 +455,6 @@ struct avx2_half_vector { { return _mm_permutevar_ps(ymm, idx); } - static reg_t permutevar(reg_t ymm, __m128i idx) - { - return _mm_permutevar_ps(ymm, idx); - } static reg_t reverse(reg_t ymm) { const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3); @@ -520,23 +507,15 @@ struct avx2_32bit_half_swizzle_ops { template X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) { - __m128i v = vtype::cast_to(reg); - if constexpr (scale == 2) { - __m128 vf = _mm_castsi128_ps(v); - vf = _mm_permute_ps(vf, 0b10110001); - v = _mm_castps_si128(vf); + return vtype::template shuffle<0b10110001>(reg); } else if constexpr (scale == 4) { - __m128 vf = _mm_castsi128_ps(v); - vf = _mm_permute_ps(vf, 0b01001110); - v = _mm_castps_si128(vf); + return vtype::template shuffle<0b01001110>(reg); } else { 
static_assert(scale == -1, "should not be reached"); } - - return vtype::cast_from(v); } template diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 6e40d2a6..38489626 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -277,7 +277,7 @@ void avx2_emu_mask_compressstoreu32(void *base_addr, const __m256i &left = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::mask_storeu(leftStore, left, temp); } @@ -300,7 +300,7 @@ void avx2_emu_mask_compressstoreu32_half( (const __m128i *)avx2_compressstore_lut32_half_left[shortMask] .data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::mask_storeu(leftStore, left, temp); } @@ -341,7 +341,7 @@ int avx2_double_compressstore32(void *left_addr, const __m256i &perm = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::storeu(leftStore, temp); vtype::storeu(rightStore, temp); @@ -365,7 +365,7 @@ int avx2_double_compressstore32_half(void *left_addr, (const __m128i *)avx2_compressstore_lut32_half_perm[shortMask] .data()); - typename vtype::reg_t temp = vtype::permutevar(reg, perm); + typename vtype::reg_t temp = vtype::permutexvar(perm, reg); vtype::storeu(leftStore, temp); vtype::storeu(rightStore, temp); diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 5482941b..c826a527 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -64,20 +64,6 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) }); } -/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of - * undefined template 'zmm_vector'*/ -#ifdef __APPLE__ -using argtypeAVX512 = - typename std::conditional, - zmm_vector>::type; -#else -using argtypeAVX512 = - typename std::conditional, - zmm_vector>::type; -#endif - /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
@@ -129,7 +115,7 @@ X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg, /* which elements are larger than the pivot */ typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec); typename argtype::opmask_t ge_mask - = extend_mask(ge_mask_vtype); + = resize_mask(ge_mask_vtype); auto l_store = arg + left; auto r_store = arg + right - vtype::numlanes; @@ -727,19 +713,4 @@ avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) return indices; } -/* To maintain compatibility with NumPy build */ -template -X86_SIMD_SORT_INLINE void -avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize) -{ - avx512_argselect(arr, reinterpret_cast(arg), k, arrsize); -} - -template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize) -{ - avx512_argsort(arr, reinterpret_cast(arg), arrsize); -} - #endif // XSS_COMMON_ARGSORT diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index b42fd87d..3f74633e 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -10,7 +10,7 @@ #define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 template -typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask) +typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask) { using inT = typename keyType::opmask_t; using outT = typename valueType::opmask_t; @@ -48,7 +48,7 @@ COEX(reg_t1 &key1, reg_t1 &key2, reg_t2 &index1, reg_t2 &index2) reg_t1 key_t1 = vtype1::min(key1, key2); reg_t1 key_t2 = vtype1::max(key1, key2); - auto eqMask = extend_mask(vtype1::eq(key_t1, key1)); + auto eqMask = resize_mask(vtype1::eq(key_t1, key1)); reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1); reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2); @@ -73,7 +73,7 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, reg_t1 tmp_keys = cmp_merge(in1, in2, mask); indexes1 = vtype2::mask_mov( indexes2, - extend_mask(vtype1::eq(tmp_keys, in1)), + resize_mask(vtype1::eq(tmp_keys, in1)), indexes1); return tmp_keys; // 0 -> min, 1 -> max } @@ -503,7 +503,7 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys, for (int i = numVecs / 2; i < numVecs; i++) { indexVecs[i] = indexType::mask_loadu( indexType::zmm_max(), - extend_mask(ioMasks[i - numVecs / 2]), + resize_mask(ioMasks[i - numVecs / 2]), indices + i * indexType::numlanes); keyVecs[i] = keyType::template mask_i64gather(ioMasks[i - numVecs / 2]), + resize_mask(ioMasks[i - numVecs / 2]), indexVecs[i]); } } From 1cf135b2dc2cd9e7dceef475c127a1cb6ba94462 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 15 Jan 2024 20:46:32 -0800 Subject: [PATCH 10/10] Re-write resize_mask and move constants to common file --- src/xss-common-includes.h | 19 ++++++++++++ src/xss-network-keyvaluesort.hpp | 52 +++++++++----------------------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 9f793e37..fadb81f5 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -71,6 +71,25 @@ #define X86_SIMD_SORT_UNROLL_LOOP(num) #endif +template +constexpr bool always_false = false; + +/* + * Constants used in sorting 8 elements in a ZMM registers. 
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ +// ZMM 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1 +#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 +#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 + typedef size_t arrsize_t; template diff --git a/src/xss-network-keyvaluesort.hpp b/src/xss-network-keyvaluesort.hpp index 3f74633e..03f15155 100644 --- a/src/xss-network-keyvaluesort.hpp +++ b/src/xss-network-keyvaluesort.hpp @@ -3,38 +3,25 @@ #include "xss-common-includes.h" -#define NETWORK_32BIT_1 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1 -#define NETWORK_32BIT_3 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_32BIT_5 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 -#define NETWORK_32BIT_6 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 -#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 - template typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask) { using inT = typename keyType::opmask_t; using outT = typename valueType::opmask_t; - - if constexpr (keyType::vec_type == simd_type::AVX512) { return mask; } - else if constexpr (keyType::vec_type == simd_type::AVX2) { - if constexpr (sizeof(inT) == sizeof(outT)) { return mask; } - else if constexpr (sizeof(inT) == 32 && sizeof(outT) == 16){ - // We need to convert a mask made of 64 bit integers to 32 bit integers - // This does this by taking advantage of the fact that the only bit that matters - // is the very topmost bit, which becomes the sign bit when cast to floating point - - // TODO try and figure out if there is a better way to do this - return _mm_castps_si128(_mm256_cvtpd_ps(_mm256_castsi256_pd(mask))); - } - else if constexpr (sizeof(inT) == 16 && sizeof(outT) == 32){ - return _mm256_cvtepi32_epi64(mask); - }else{ - static_assert(sizeof(inT) == -1, "should not reach here"); - } + + if constexpr (sizeof(inT) == sizeof(outT)) { //std::is_same_v) { + return mask; + } + /* convert __m256i to __m128i */ + else if constexpr (sizeof(inT) == 32 && sizeof(outT) == 16) { + return _mm_castps_si128(_mm256_cvtpd_ps(_mm256_castsi256_pd(mask))); + } + /* convert __m128i to __m256i */ + else if constexpr (sizeof(inT) == 16 && sizeof(outT) == 32) { + return _mm256_cvtepi32_epi64(mask); } else { - static_assert(keyType::vec_type == simd_type::AVX512, - "should not reach here"); + static_assert(always_false, "Error in func resize_mask"); } } @@ -78,17 +65,6 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1, return tmp_keys; // 0 -> min, 1 -> max } -/* - * Constants used in sorting 8 elements in a ZMM registers. 
Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ -// ZMM 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_64BIT_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_64BIT_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_64BIT_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_64BIT_4 3, 2, 1, 0, 7, 6, 5, 4 - template (key, value); } else { - static_assert(numlanes == -1, "No implementation"); + static_assert(always_false, "bitonic_merge_dispatch: No implementation"); UNUSED(key); UNUSED(value); } @@ -373,7 +349,7 @@ X86_SIMD_SORT_INLINE void sort_vec_dispatch(typename keyType::reg_t &key, key = sort_ymm_64bit(key, value); } else { - static_assert(numlanes == -1, "No implementation"); + static_assert(always_false, "sort_vec_dispatch: No implementation"); UNUSED(key); UNUSED(value); }
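
A minimal standalone sketch (not part of any patch above) of the mask-width trick that resize_mask relies on in patches 07 and 10: the emulated AVX2 masks only ever consult the top bit of each lane, so a __m256i mask over 64-bit lanes can be narrowed to a __m128i mask over 32-bit lanes by a double-to-float convert, which carries the sign bit across, and widened back with a sign-extending convert. Everything below (main, the lane values, the printouts) is illustrative only and assumes a compiler with AVX2 enabled (e.g. -mavx2).

// Demonstrates narrowing a 4x64-bit AVX2 mask to 4x32-bit and back.
#include <immintrin.h>
#include <cstdio>

int main()
{
    // 64-bit-lane mask: lanes 3 and 1 set (all-ones), lanes 2 and 0 clear.
    __m256i wide = _mm256_set_epi64x(-1, 0, -1, 0);

    // Narrow to a 32-bit-lane mask: only the sign (top) bit of each lane is
    // guaranteed to survive the conversion, which is all that the masked
    // AVX2 loads/stores/blends look at.
    __m128i narrow = _mm_castps_si128(
            _mm256_cvtpd_ps(_mm256_castsi256_pd(wide)));

    // Widen a 32-bit-lane mask back to 64-bit lanes by sign extension.
    __m256i rewidened = _mm256_cvtepi32_epi64(narrow);

    printf("wide      top bits: %x\n",
           _mm256_movemask_pd(_mm256_castsi256_pd(wide)));
    printf("narrow    top bits: %x\n",
           _mm_movemask_ps(_mm_castsi128_ps(narrow)));
    printf("rewidened top bits: %x\n",
           _mm256_movemask_pd(_mm256_castsi256_pd(rewidened)));
    return 0; // all three lines print 0xa, i.e. lanes 3 and 1 set
}

Likewise, a small sketch of the always_false idiom that the final patch introduces in place of static_assert(numlanes == -1, ...): a plain static_assert(false, ...) in the trailing else of an if-constexpr chain is rejected by most compilers before C++23 even when the branch is never taken, so the assertion is made dependent on the template parameter and only fires for instantiations that actually reach the unsupported branch. The types four_lanes/eight_lanes and the function network_name below are hypothetical stand-ins, not the library's vector classes.

template <typename T>
constexpr bool always_false = false;

struct four_lanes  { static constexpr int numlanes = 4; };
struct eight_lanes { static constexpr int numlanes = 8; };

template <typename vtype>
const char *network_name()
{
    if constexpr (vtype::numlanes == 8) { return "8-lane sorting network"; }
    else if constexpr (vtype::numlanes == 4) { return "4-lane sorting network"; }
    else {
        // Dependent on vtype, so this only triggers when an unsupported
        // lane count is actually instantiated.
        static_assert(always_false<vtype>, "No implementation");
        return nullptr;
    }
}
// network_name<four_lanes>() and network_name<eight_lanes>() compile as
// expected; a type with, say, numlanes == 2 hits the static_assert only at
// the point of instantiation.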