diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 1efcf1e9..b5202f46 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -433,11 +433,11 @@ void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index bfd4a151..a0dd7f7e 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -715,7 +715,10 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) } template <> -void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(int32_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_32bit_, int32_t>( @@ -724,7 +727,10 @@ void avx512_qselect(int32_t *arr, int64_t k, int64_t arrsize, bool hasn } template <> -void avx512_qselect(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(uint32_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_32bit_, uint32_t>( @@ -737,11 +743,11 @@ void avx512_qselect(float *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_32bit_, float>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 80c6ce4a..3626ab63 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -8,8 +8,28 @@ #define AVX512_ARGSORT_64BIT #include "avx512-64bit-common.h" -#include "avx512-common-argsort.h" #include "avx512-64bit-keyvalue-networks.hpp" +#include "avx512-common-argsort.h" + +template +void std_argselect_withnan( + T *arr, int64_t *arg, int64_t k, int64_t left, int64_t right) +{ + std::nth_element(arg + left, + arg + k, + arg + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); +} /* argsort using std::sort */ template @@ -18,9 +38,15 @@ void std_argsort_withnan(T *arr, int64_t *arg, int64_t left, int64_t right) { std::sort(arg + left, arg + right, [arr](int64_t left, int64_t right) -> bool { - if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) {return arr[left] < arr[right];} - else if (std::isnan(arr[left])) {return false;} - else {return true;} + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } }); } @@ -284,7 +310,42 @@ inline void argsort_64bit_(type_t *arr, } template -bool has_nan(type_t* arr, int64_t arrsize) +static void argselect_64bit_(type_t *arr, + int64_t *arg, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) +{ + /* + * Resort to std::sort if quicksort isn't making any progress + */ + if 
(max_iters <= 0) { + std_argsort(arr, arg, left, right + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 64 + */ + if (right + 1 - left <= 64) { + argsort_64_64bit(arr, arg + left, (int32_t)(right + 1 - left)); + return; + } + type_t pivot = get_pivot_64bit(arr, arg, left, right); + type_t smallest = vtype::type_max(); + type_t biggest = vtype::type_min(); + int64_t pivot_index = partition_avx512_unrolled( + arr, arg, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) + argselect_64bit_( + arr, arg, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != biggest) && (pos >= pivot_index)) + argselect_64bit_( + arr, arg, pos, pivot_index, right, max_iters - 1); +} + +template +bool has_nan(type_t *arr, int64_t arrsize) { using opmask_t = typename vtype::opmask_t; using zmm_t = typename vtype::zmm_t; @@ -299,7 +360,7 @@ bool has_nan(type_t* arr, int64_t arrsize) else { in = vtype::loadu(arr); } - opmask_t nanmask = vtype::template fpclass<0x01|0x80>(in); + opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); arr += vtype::numlanes; arrsize -= vtype::numlanes; if (nanmask != 0x00) { @@ -310,8 +371,9 @@ bool has_nan(type_t* arr, int64_t arrsize) return found_nan; } +/* argsort methods for 32-bit and 64-bit dtypes */ template -void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -320,7 +382,7 @@ void avx512_argsort(T* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(double *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -333,9 +395,8 @@ void avx512_argsort(double* arr, int64_t *arg, int64_t arrsize) } } - template <> -void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(int32_t *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -344,7 +405,7 @@ void avx512_argsort(int32_t* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(uint32_t *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { argsort_64bit_>( @@ -353,7 +414,7 @@ void avx512_argsort(uint32_t* arr, int64_t *arg, int64_t arrsize) } template <> -void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize) +void avx512_argsort(float *arr, int64_t *arg, int64_t arrsize) { if (arrsize > 1) { if (has_nan>(arr, arrsize)) { @@ -367,7 +428,7 @@ void avx512_argsort(float* arr, int64_t *arg, int64_t arrsize) } template -std::vector avx512_argsort(T* arr, int64_t arrsize) +std::vector avx512_argsort(T *arr, int64_t arrsize) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); @@ -375,4 +436,69 @@ std::vector avx512_argsort(T* arr, int64_t arrsize) return indices; } +/* argselect methods for 32-bit and 64-bit dtypes */ +template +void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(double *arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + if (has_nan>(arr, arrsize)) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + } + else { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +template <> +void avx512_argselect(int32_t 
*arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(uint32_t *arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } +} + +template <> +void avx512_argselect(float *arr, int64_t *arg, int64_t k, int64_t arrsize) +{ + if (arrsize > 1) { + if (has_nan>(arr, arrsize)) { + std_argselect_withnan(arr, arg, k, 0, arrsize); + } + else { + argselect_64bit_>( + arr, arg, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +template +std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize) +{ + std::vector indices(arrsize); + std::iota(indices.begin(), indices.end(), 0); + avx512_argselect(arr, indices.data(), k, arrsize); + return indices; +} + #endif // AVX512_ARGSORT_64BIT diff --git a/src/avx512-64bit-keyvalue-networks.hpp b/src/avx512-64bit-keyvalue-networks.hpp index af3a2a98..b930a42b 100644 --- a/src/avx512-64bit-keyvalue-networks.hpp +++ b/src/avx512-64bit-keyvalue-networks.hpp @@ -136,14 +136,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, typename vtype1::opmask_t movmask1 = vtype1::eq(key_zmm_t1, key_zmm[0]); typename vtype1::opmask_t movmask2 = vtype1::eq(key_zmm_t2, key_zmm[1]); - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm3r, movmask1, index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], movmask1, index_zmm3r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm2r, movmask2, index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], movmask2, index_zmm2r); + index_type index_zmm_t1 + = vtype2::mask_mov(index_zmm3r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm3r); + index_type index_zmm_t2 + = vtype2::mask_mov(index_zmm2r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm2r); // 2) Recursive half clearer: 16 zmm_t key_zmm_t3 = vtype1::permutexvar(rev_index1, key_zmm_m2); @@ -159,14 +159,14 @@ X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_64bit(zmm_t *key_zmm, movmask1 = vtype1::eq(key_zmm0, key_zmm_t1); movmask2 = vtype1::eq(key_zmm2, key_zmm_t3); - index_type index_zmm0 = vtype2::mask_mov( - index_zmm_t2, movmask1, index_zmm_t1); - index_type index_zmm1 = vtype2::mask_mov( - index_zmm_t1, movmask1, index_zmm_t2); - index_type index_zmm2 = vtype2::mask_mov( - index_zmm_t4, movmask2, index_zmm_t3); - index_type index_zmm3 = vtype2::mask_mov( - index_zmm_t3, movmask2, index_zmm_t4); + index_type index_zmm0 + = vtype2::mask_mov(index_zmm_t2, movmask1, index_zmm_t1); + index_type index_zmm1 + = vtype2::mask_mov(index_zmm_t1, movmask1, index_zmm_t2); + index_type index_zmm2 + = vtype2::mask_mov(index_zmm_t4, movmask2, index_zmm_t3); + index_type index_zmm3 + = vtype2::mask_mov(index_zmm_t3, movmask2, index_zmm_t4); key_zmm[0] = bitonic_merge_zmm_64bit(key_zmm0, index_zmm0); key_zmm[1] = bitonic_merge_zmm_64bit(key_zmm1, index_zmm1); @@ -212,22 +212,22 @@ X86_SIMD_SORT_INLINE void bitonic_merge_eight_zmm_64bit(zmm_t *key_zmm, typename vtype1::opmask_t movmask3 = vtype1::eq(key_zmm_t3, key_zmm[2]); typename vtype1::opmask_t movmask4 = vtype1::eq(key_zmm_t4, key_zmm[3]); - index_type index_zmm_t1 = vtype2::mask_mov( - index_zmm7r, movmask1, index_zmm[0]); - index_type index_zmm_m1 = vtype2::mask_mov( - index_zmm[0], movmask1, 
index_zmm7r); - index_type index_zmm_t2 = vtype2::mask_mov( - index_zmm6r, movmask2, index_zmm[1]); - index_type index_zmm_m2 = vtype2::mask_mov( - index_zmm[1], movmask2, index_zmm6r); - index_type index_zmm_t3 = vtype2::mask_mov( - index_zmm5r, movmask3, index_zmm[2]); - index_type index_zmm_m3 = vtype2::mask_mov( - index_zmm[2], movmask3, index_zmm5r); - index_type index_zmm_t4 = vtype2::mask_mov( - index_zmm4r, movmask4, index_zmm[3]); - index_type index_zmm_m4 = vtype2::mask_mov( - index_zmm[3], movmask4, index_zmm4r); + index_type index_zmm_t1 + = vtype2::mask_mov(index_zmm7r, movmask1, index_zmm[0]); + index_type index_zmm_m1 + = vtype2::mask_mov(index_zmm[0], movmask1, index_zmm7r); + index_type index_zmm_t2 + = vtype2::mask_mov(index_zmm6r, movmask2, index_zmm[1]); + index_type index_zmm_m2 + = vtype2::mask_mov(index_zmm[1], movmask2, index_zmm6r); + index_type index_zmm_t3 + = vtype2::mask_mov(index_zmm5r, movmask3, index_zmm[2]); + index_type index_zmm_m3 + = vtype2::mask_mov(index_zmm[2], movmask3, index_zmm5r); + index_type index_zmm_t4 + = vtype2::mask_mov(index_zmm4r, movmask4, index_zmm[3]); + index_type index_zmm_m4 + = vtype2::mask_mov(index_zmm[3], movmask4, index_zmm4r); zmm_t key_zmm_t5 = vtype1::permutexvar(rev_index1, key_zmm_m4); zmm_t key_zmm_t6 = vtype1::permutexvar(rev_index1, key_zmm_m3); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index aa5d7958..d59a1788 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -784,7 +784,10 @@ static void qselect_64bit_(type_t *arr, } template <> -void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(int64_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_64bit_, int64_t>( @@ -793,7 +796,10 @@ void avx512_qselect(int64_t *arr, int64_t k, int64_t arrsize, bool hasn } template <> -void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(uint64_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { if (arrsize > 1) { qselect_64bit_, uint64_t>( @@ -802,15 +808,18 @@ void avx512_qselect(uint64_t *arr, int64_t k, int64_t arrsize, bool ha } template <> -void avx512_qselect(double *arr, int64_t k, int64_t arrsize, bool hasnan) +void avx512_qselect(double *arr, + int64_t k, + int64_t arrsize, + bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_64bit_, double>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/src/avx512-common-argsort.h b/src/avx512-common-argsort.h index e0dcaccc..0ae50c49 100644 --- a/src/avx512-common-argsort.h +++ b/src/avx512-common-argsort.h @@ -21,6 +21,11 @@ void avx512_argsort(T *arr, int64_t *arg, int64_t arrsize); template std::vector avx512_argsort(T *arr, int64_t arrsize); +template +void avx512_argselect(T *arr, int64_t *arg, int64_t k, int64_t arrsize); + +template +std::vector avx512_argselect(T *arr, int64_t k, int64_t arrsize); /* * Parition one ZMM register based on the pivot and returns the index of the * last element that is less than equal to the pivot. 
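
Note: the avx512-common-argsort.h hunk above declares the new public argselect entry points. The following is a minimal usage sketch only, not part of the patch; it assumes an AVX512-capable CPU at runtime, the library's avx512-64bit-argsort.hpp header from this change, and made-up input values.

#include "avx512-64bit-argsort.hpp"
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical input data; requires AVX512 support on the running CPU.
    std::vector<double> data {9.0, 3.0, 7.0, 1.0, 5.0, 8.0, 2.0, 6.0};
    int64_t k = 3;
    // Convenience overload added by this patch: it allocates the index
    // array, fills it with 0..n-1 and dispatches to the
    // avx512_argselect(arr, arg, k, n) overload. Afterwards arg[k] indexes
    // the k-th smallest key; entries before k reference keys no larger than
    // data[arg[k]], entries after k reference keys no smaller, with no
    // further ordering guaranteed on either side.
    std::vector<int64_t> arg
            = avx512_argselect(data.data(), k, (int64_t)data.size());
    std::printf("k = %lld -> data[%lld] = %f\n",
                (long long)k, (long long)arg[k], data[arg[k]]);
    return 0;
}

For float and double inputs the patch first checks has_nan and falls back to std_argselect_withnan, whose comparator never reports a NaN key as "less", so indices of NaN entries end up at or after position k rather than corrupting the partition.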
diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 8e87ac29..5bb4c6c0 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -157,11 +157,11 @@ void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan) { int64_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { qselect_16bit_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } diff --git a/tests/test-argselect.hpp b/tests/test-argselect.hpp new file mode 100644 index 00000000..298000d4 --- /dev/null +++ b/tests/test-argselect.hpp @@ -0,0 +1,47 @@ +/******************************************* + * * Copyright (C) 2023 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +template +class avx512argselect : public ::testing::Test { +}; + +TYPED_TEST_SUITE_P(avx512argselect); + +TYPED_TEST_P(avx512argselect, test_random) +{ + if (cpu_has_avx512bw()) { + const int arrsize = 1024; + auto arr = get_uniform_rand_array(arrsize); + std::vector sorted_inx; + if (std::is_floating_point::value) { + arr[0] = std::numeric_limits::quiet_NaN(); + arr[1] = std::numeric_limits::quiet_NaN(); + } + sorted_inx = std_argsort(arr); + std::vector kth; + for (int64_t ii = 0; ii < arrsize - 3; ++ii) { + kth.push_back(ii); + } + for (auto &k : kth) { + std::vector inx + = avx512_argselect(arr.data(), k, arr.size()); + auto true_kth = arr[sorted_inx[k]]; + EXPECT_EQ(true_kth, arr[inx[k]]) << "Failed at index k = " << k; + if (k >= 1) + EXPECT_GE(true_kth, std_max_element(arr, inx, 0, k - 1)) + << "failed at k = " << k; + if (k != arrsize - 1) + EXPECT_LE(true_kth, + std_min_element(arr, inx, k + 1, arrsize - 1)) + << "failed at k = " << k; + EXPECT_UNIQUE(inx) + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +REGISTER_TYPED_TEST_SUITE_P(avx512argselect, test_random); diff --git a/tests/test-argsort-common.h b/tests/test-argsort-common.h new file mode 100644 index 00000000..2e293620 --- /dev/null +++ b/tests/test-argsort-common.h @@ -0,0 +1,81 @@ +#include "avx512-64bit-argsort.hpp" +#include "cpuinfo.h" +#include "rand_array.h" +#include +#include +#include + +template +std::vector std_argsort(const std::vector &arr) +{ + std::vector indices(arr.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), + indices.end(), + [&arr](int64_t left, int64_t right) -> bool { + if ((!std::isnan(arr[left])) && (!std::isnan(arr[right]))) { + return arr[left] < arr[right]; + } + else if (std::isnan(arr[left])) { + return false; + } + else { + return true; + } + }); + + return indices; +} + +template +T std_min_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) +{ + std::vector::iterator res = std::min_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) -> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] < arr[b]; + } + else if (std::isnan(arr[a])) { + return false; + } + else { + return true; + } + }); + return arr[*res]; +} + +template +T std_max_element(std::vector arr, + std::vector arg, + int64_t left, + int64_t right) +{ + std::vector::iterator res = std::max_element( + arg.begin() + left, + arg.begin() + right, + [arr](int64_t a, int64_t b) 
-> bool { + if ((!std::isnan(arr[a])) && (!std::isnan(arr[b]))) { + return arr[a] > arr[b]; + } + else if (std::isnan(arr[a])) { + return true; + } + else { + return false; + } + }); + return arr[*res]; +} + +#define EXPECT_UNIQUE(sorted_arg) \ + std::sort(sorted_arg.begin(), sorted_arg.end()); \ + std::vector expected_arg(sorted_arg.size()); \ + std::iota(expected_arg.begin(), expected_arg.end(), 0); \ + EXPECT_EQ(sorted_arg, expected_arg) \ + << "Indices aren't unique. Array size = " << sorted_arg.size(); diff --git a/tests/test-argsort.cpp b/tests/test-argsort.cpp index 8048d751..41ce5ca4 100644 --- a/tests/test-argsort.cpp +++ b/tests/test-argsort.cpp @@ -1,305 +1,9 @@ -/******************************************* - * * Copyright (C) 2023 Intel Corporation - * * SPDX-License-Identifier: BSD-3-Clause - * *******************************************/ +#include "test-argsort-common.h" +#include "test-argsort.hpp" +#include "test-argselect.hpp" -#include "avx512-64bit-argsort.hpp" -#include "cpuinfo.h" -#include "rand_array.h" -#include -#include -#include - -template -class avx512argsort : public ::testing::Test { -}; -TYPED_TEST_SUITE_P(avx512argsort); - -template -std::vector std_argsort(const std::vector &array) -{ - std::vector indices(array.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), - indices.end(), - [&array](int left, int right) -> bool { - // sort indices according to corresponding array sizeent - return array[left] < array[right]; - }); - - return indices; -} - -#define EXPECT_UNIQUE(sorted_arg) \ - std::sort(sorted_arg.begin(), sorted_arg.end()); \ - std::vector expected_arg(sorted_arg.size()); \ - std::iota(expected_arg.begin(), expected_arg.end(), 0); \ - EXPECT_EQ(sorted_arg, expected_arg) << "Indices aren't unique. 
Array size = " << sorted_arg.size(); - -TYPED_TEST_P(avx512argsort, test_random) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* Random array */ - arr = get_uniform_rand_array(size); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_constant) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* constant array */ - auto elem = get_uniform_rand_array(1)[0]; - for (int64_t jj = 0; jj < size; ++jj) { - arr.push_back(elem); - } - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_small_range) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - /* array with a smaller range of values */ - arr = get_uniform_rand_array(size, 20, 1); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size = " << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_sorted) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - std::sort(arr.begin(), arr.end()); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_reverse) -{ - if (cpu_has_avx512bw()) { - std::vector arrsizes; - for (int64_t ii = 0; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - std::sort(arr.begin(), arr.end()); - std::reverse(arr.begin(), arr.end()); - std::vector inx1 = std_argsort(arr); - std::vector inx2 - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1, sort2; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx1[jj]]); - 
sort2.push_back(arr[inx2[jj]]); - } - EXPECT_EQ(sort1, sort2) << "Array size =" << size; - EXPECT_UNIQUE(inx2) - arr.clear(); - } - } - else { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } -} - -TYPED_TEST_P(avx512argsort, test_array_with_nan) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - if (!std::is_floating_point::value) { - GTEST_SKIP() << "Skipping this test, it is meant for float/double"; - } - std::vector arrsizes; - for (int64_t ii = 2; ii <= 1024; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - arr[0] = std::numeric_limits::quiet_NaN(); - arr[1] = std::numeric_limits::quiet_NaN(); - std::vector inx - = avx512_argsort(arr.data(), arr.size()); - std::vector sort1; - for (size_t jj = 0; jj < size; ++jj) { - sort1.push_back(arr[inx[jj]]); - } - if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { - FAIL() << "NAN's aren't sorted to the end"; - } - if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { - FAIL() << "Array isn't sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - std::vector arrsizes; - for (int64_t ii = 1; ii <= 256; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - if (std::numeric_limits::has_infinity) { - arr[size - 1] = std::numeric_limits::infinity(); - } - else { - arr[size - 1] = std::numeric_limits::max(); - } - std::vector inx = avx512_argsort(arr.data(), arr.size()); - std::vector sorted; - for (size_t jj = 0; jj < size; ++jj) { - sorted.push_back(arr[inx[jj]]); - } - if (!std::is_sorted(sorted.begin(), sorted.end())) { - EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -TYPED_TEST_P(avx512argsort, test_all_inf_array) -{ - if (!cpu_has_avx512bw()) { - GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; - } - std::vector arrsizes; - for (int64_t ii = 1; ii <= 256; ++ii) { - arrsizes.push_back(ii); - } - std::vector arr; - for (auto &size : arrsizes) { - arr = get_uniform_rand_array(size); - if (std::numeric_limits::has_infinity) { - for (int64_t jj = 1; jj <= size; ++jj) { - if (rand() % 0x1) { - arr.push_back(std::numeric_limits::infinity()); - } - } - } - else { - for (int64_t jj = 1; jj <= size; ++jj) { - if (rand() % 0x1) { - arr.push_back(std::numeric_limits::max()); - } - } - } - std::vector inx = avx512_argsort(arr.data(), arr.size()); - std::vector sorted; - for (size_t jj = 0; jj < size; ++jj) { - sorted.push_back(arr[inx[jj]]); - } - if (!std::is_sorted(sorted.begin(), sorted.end())) { - EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; - } - EXPECT_UNIQUE(inx) - arr.clear(); - } -} - -REGISTER_TYPED_TEST_SUITE_P(avx512argsort, - test_random, - test_reverse, - test_constant, - test_sorted, - test_small_range, - test_all_inf_array, - test_array_with_nan, - test_max_value_at_end_of_array); - -using ArgSortTestTypes +using ArgTestTypes = testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argsort, ArgSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argsort, ArgTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512argselect, ArgTestTypes); diff --git a/tests/test-argsort.hpp b/tests/test-argsort.hpp new file mode 100644 index 
00000000..f7a4a23f --- /dev/null +++ b/tests/test-argsort.hpp @@ -0,0 +1,272 @@ +/******************************************* + * * Copyright (C) 2023 Intel Corporation + * * SPDX-License-Identifier: BSD-3-Clause + * *******************************************/ + +template +class avx512argsort : public ::testing::Test { +}; +TYPED_TEST_SUITE_P(avx512argsort); + +TYPED_TEST_P(avx512argsort, test_random) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* Random array */ + arr = get_uniform_rand_array(size); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_constant) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* constant array */ + auto elem = get_uniform_rand_array(1)[0]; + for (int64_t jj = 0; jj < size; ++jj) { + arr.push_back(elem); + } + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_small_range) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + /* array with a smaller range of values */ + arr = get_uniform_rand_array(size, 20, 1); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size = " << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_sorted) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_reverse) +{ + if (cpu_has_avx512bw()) { + std::vector arrsizes; + for (int64_t ii = 0; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = 
get_uniform_rand_array(size); + std::sort(arr.begin(), arr.end()); + std::reverse(arr.begin(), arr.end()); + std::vector inx1 = std_argsort(arr); + std::vector inx2 + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1, sort2; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx1[jj]]); + sort2.push_back(arr[inx2[jj]]); + } + EXPECT_EQ(sort1, sort2) << "Array size =" << size; + EXPECT_UNIQUE(inx2) + arr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } +} + +TYPED_TEST_P(avx512argsort, test_array_with_nan) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + if (!std::is_floating_point::value) { + GTEST_SKIP() << "Skipping this test, it is meant for float/double"; + } + std::vector arrsizes; + for (int64_t ii = 2; ii <= 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + arr[0] = std::numeric_limits::quiet_NaN(); + arr[1] = std::numeric_limits::quiet_NaN(); + std::vector inx + = avx512_argsort(arr.data(), arr.size()); + std::vector sort1; + for (size_t jj = 0; jj < size; ++jj) { + sort1.push_back(arr[inx[jj]]); + } + if ((!std::isnan(sort1[size - 1])) || (!std::isnan(sort1[size - 2]))) { + FAIL() << "NAN's aren't sorted to the end"; + } + if (!std::is_sorted(sort1.begin(), sort1.end() - 2)) { + FAIL() << "Array isn't sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +TYPED_TEST_P(avx512argsort, test_max_value_at_end_of_array) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + std::vector arrsizes; + for (int64_t ii = 1; ii <= 256; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + if (std::numeric_limits::has_infinity) { + arr[size - 1] = std::numeric_limits::infinity(); + } + else { + arr[size - 1] = std::numeric_limits::max(); + } + std::vector inx = avx512_argsort(arr.data(), arr.size()); + std::vector sorted; + for (size_t jj = 0; jj < size; ++jj) { + sorted.push_back(arr[inx[jj]]); + } + if (!std::is_sorted(sorted.begin(), sorted.end())) { + EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +TYPED_TEST_P(avx512argsort, test_all_inf_array) +{ + if (!cpu_has_avx512bw()) { + GTEST_SKIP() << "Skipping this test, it requires avx512bw ISA"; + } + std::vector arrsizes; + for (int64_t ii = 1; ii <= 256; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + for (auto &size : arrsizes) { + arr = get_uniform_rand_array(size); + if (std::numeric_limits::has_infinity) { + for (int64_t jj = 1; jj <= size; ++jj) { + if (rand() % 0x1) { + arr.push_back(std::numeric_limits::infinity()); + } + } + } + else { + for (int64_t jj = 1; jj <= size; ++jj) { + if (rand() % 0x1) { + arr.push_back(std::numeric_limits::max()); + } + } + } + std::vector inx = avx512_argsort(arr.data(), arr.size()); + std::vector sorted; + for (size_t jj = 0; jj < size; ++jj) { + sorted.push_back(arr[inx[jj]]); + } + if (!std::is_sorted(sorted.begin(), sorted.end())) { + EXPECT_TRUE(false) << "Array of size " << size << "is not sorted"; + } + EXPECT_UNIQUE(inx) + arr.clear(); + } +} + +REGISTER_TYPED_TEST_SUITE_P(avx512argsort, + test_random, + test_reverse, + test_constant, + test_sorted, + test_small_range, + test_all_inf_array, + test_array_with_nan, + test_max_value_at_end_of_array); diff --git a/tests/test-qsort-fp.hpp 
b/tests/test-qsort-fp.hpp index 9000fb38..438305b1 100644 --- a/tests/test-qsort-fp.hpp +++ b/tests/test-qsort-fp.hpp @@ -26,15 +26,17 @@ TYPED_TEST_P(avx512_sort_fp, test_random_nan) /* Random array */ arr = get_uniform_rand_array(size); for (auto ii = 1; ii <= num_nans; ++ii) { - arr[size-ii] = std::numeric_limits::quiet_NaN(); + arr[size - ii] = std::numeric_limits::quiet_NaN(); } sortedarr = arr; - std::sort(sortedarr.begin(), sortedarr.end()-3); + std::sort(sortedarr.begin(), sortedarr.end() - 3); std::random_shuffle(arr.begin(), arr.end()); avx512_qsort(arr.data(), arr.size()); for (auto ii = 1; ii <= num_nans; ++ii) { - if (!std::isnan(arr[size-ii])) { - ASSERT_TRUE(false) << "NAN's aren't sorted to the end. Arr size = " << size; + if (!std::isnan(arr[size - ii])) { + ASSERT_TRUE(false) + << "NAN's aren't sorted to the end. Arr size = " + << size; } } if (!std::is_sorted(arr.begin(), arr.end() - num_nans)) { diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index eb3d5f77..a35d8e8c 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -1,7 +1,7 @@ +#include "test-qsort.hpp" #include "test-partial-qsort.hpp" #include "test-qselect.hpp" #include "test-qsort-fp.hpp" -#include "test-qsort.hpp" using QSortTestTypes = testing::Types