From 012a8f2682659ccbf73fdb7489008a1b757edce5 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Mon, 11 Mar 2024 09:53:07 -0700 Subject: [PATCH 01/18] Modified all of the lib/* files to support descending order --- lib/x86simdsort-avx2.cpp | 12 ++++----- lib/x86simdsort-icl.cpp | 24 +++++++++--------- lib/x86simdsort-internal.h | 18 ++++++------- lib/x86simdsort-scalar.h | 52 ++++++++++++++++++++++++++++++-------- lib/x86simdsort-skx.cpp | 12 ++++----- lib/x86simdsort-spr.cpp | 12 ++++----- lib/x86simdsort.cpp | 18 ++++++------- lib/x86simdsort.h | 6 ++--- 8 files changed, 92 insertions(+), 62 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 7700c9f4..3785014b 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -7,19 +7,19 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qsort(arr, arrsize, hasnan); \ + avx2_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qselect(arr, k, arrsize, hasnan); \ + avx2_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_partial_qsort(arr, k, arrsize, hasnan); \ + avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 09caefb5..bbde114d 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -5,34 +5,34 @@ namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan) + void 
qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan) + void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 70f13daf..6452e9fa 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -8,7 +8,7 @@ namespace xss { namespace avx512 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL 
void @@ -16,11 +16,11 @@ namespace avx512 { // quickselect template XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -33,7 +33,7 @@ namespace avx512 { namespace avx2 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -41,11 +41,11 @@ namespace avx2 { // quickselect template XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -58,7 +58,7 @@ namespace avx2 { namespace scalar { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); + XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -66,11 +66,11 @@ namespace scalar { // quickselect template XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); + qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); + partial_qsort(T 
*arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index a5348106..fc3ff9d4 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -25,35 +25,65 @@ namespace utils { namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan) + void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - std::sort(arr, arr + arrsize, compare>()); + if (reversed){ + std::sort(arr, arr + arrsize, compare>()); + }else{ + std::sort(arr, arr + arrsize, compare>()); + } } else { - std::sort(arr, arr + arrsize); + if (reversed){ + std::sort(arr, arr + arrsize, std::greater()); + }else{ + std::sort(arr, arr + arrsize, std::less()); + } } } template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - std::nth_element( - arr, arr + k, arr + arrsize, compare>()); + if (reversed){ + std::nth_element( + arr, arr + k, arr + arrsize, compare>()); + }else{ + std::nth_element( + arr, arr + k, arr + arrsize, compare>()); + } } else { - std::nth_element(arr, arr + k, arr + arrsize); + if (reversed){ + std::nth_element( + arr, arr + k, arr + arrsize, std::greater()); + }else{ + std::nth_element( + arr, arr + k, arr + arrsize, std::less()); + } } } template - void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - std::partial_sort( - arr, arr + k, arr + arrsize, compare>()); + if (reversed){ + std::partial_sort( + arr, arr + k, arr + arrsize, compare>()); + }else{ + std::partial_sort( + arr, arr + k, arr + arrsize, compare>()); + } } else { - std::partial_sort(arr, arr + k, arr + arrsize); + if (reversed){ + std::partial_sort( + arr, arr + k, arr + arrsize, std::greater()); + }else{ + 
std::partial_sort( + arr, arr + k, arr + arrsize, std::less()); + } } } template diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 11145e3a..e923ff10 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -7,19 +7,19 @@ #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qsort(arr, arrsize, hasnan); \ + avx512_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qselect(arr, k, arrsize, hasnan); \ + avx512_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_partial_qsort(arr, k, arrsize, hasnan); \ + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index e07de36f..06aea9af 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -5,19 +5,19 @@ namespace xss { namespace avx512 { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan) + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan); + avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan) + void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_qselect(arr, k, arrsize, hasnan); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan) + void partial_qsort(_Float16 
*arr, size_t k, size_t arrsize, bool hasnan, bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 8626b185..fa86e239 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -57,29 +57,29 @@ namespace x86simdsort { #define CAT(a, b) CAT_(a, b) #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ template <> \ - void qsort(TYPE *arr, size_t arrsize, bool hasnan) \ + void qsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool) \ + static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void qselect(TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + void qselect(TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ - static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool) \ + static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void partial_qsort(TYPE *arr, size_t k, size_t arrsize, bool hasnan) \ + void partial_qsort(TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ diff --git a/lib/x86simdsort.h 
b/lib/x86simdsort.h index 4dfc6d4b..1f2aa118 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -14,17 +14,17 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false); +XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template XSS_EXPORT_SYMBOL void -qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); +qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_EXPORT_SYMBOL void -partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false); +partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template From f0089e60ad244da0e9af927b785dfa418586a051 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 10:02:08 -0700 Subject: [PATCH 02/18] Add descending order tests --- tests/test-qsort-common.h | 75 +++++++++++++++++++++++++++------------ tests/test-qsort.cpp | 46 ++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 25 deletions(-) diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index 9638387f..fd338df9 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -46,31 +46,62 @@ template void IS_ARR_PARTITIONED(std::vector arr, size_t k, T true_kth, - std::string type) + std::string type, + bool descending = false) { - auto cmp_eq = compare>(); - auto cmp_less = compare>(); - auto cmp_leq = compare>(); - auto cmp_geq = compare>(); + if (!descending){ + auto cmp_eq = compare>(); + auto cmp_less = compare>(); + + auto cmp_leq = compare>(); + auto cmp_geq = compare>(); - // 1) arr[k] == sorted[k]; use memcmp to handle nan - if (!cmp_eq(arr[k], true_kth)) { - REPORT_FAIL("kth element is incorrect", arr.size(), type, k); - } - // ( 2) Elements to the left of k should be atmost arr[k] - if (k >= 1) { - T max_left - = *std::max_element(arr.begin(), 
arr.begin() + k - 1, cmp_less); - if (!cmp_geq(arr[k], max_left)) { - REPORT_FAIL("incorrect left partition", arr.size(), type, k); + // 1) arr[k] == sorted[k]; use memcmp to handle nan + if (!cmp_eq(arr[k], true_kth)) { + REPORT_FAIL("kth element is incorrect", arr.size(), type, k); } - } - // 3) Elements to the right of k should be atleast arr[k] - if (k != (size_t)(arr.size() - 1)) { - T min_right - = *std::min_element(arr.begin() + k + 1, arr.end(), cmp_less); - if (!cmp_leq(arr[k], min_right)) { - REPORT_FAIL("incorrect right partition", arr.size(), type, k); + // ( 2) Elements to the left of k should be atmost arr[k] + if (k >= 1) { + T max_left + = *std::max_element(arr.begin(), arr.begin() + k - 1, cmp_less); + if (!cmp_geq(arr[k], max_left)) { + REPORT_FAIL("incorrect left partition", arr.size(), type, k); + } + } + // 3) Elements to the right of k should be atleast arr[k] + if (k != (size_t)(arr.size() - 1)) { + T min_right + = *std::min_element(arr.begin() + k + 1, arr.end(), cmp_less); + if (!cmp_leq(arr[k], min_right)) { + REPORT_FAIL("incorrect right partition", arr.size(), type, k); + } + } + }else{ + auto cmp_eq = compare>(); + auto cmp_less = compare>(); + + auto cmp_leq = compare>(); + auto cmp_geq = compare>(); + + // 1) arr[k] == sorted[k]; use memcmp to handle nan + if (!cmp_eq(arr[k], true_kth)) { + REPORT_FAIL("kth element is incorrect", arr.size(), type, k); + } + // ( 2) Elements to the left of k should be atleast arr[k] + if (k >= 1) { + T max_left + = *std::max_element(arr.begin(), arr.begin() + k - 1, cmp_less); + if (!cmp_geq(arr[k], max_left)) { + REPORT_FAIL("incorrect left partition", arr.size(), type, k); + } + } + // 3) Elements to the right of k should be atmost arr[k] + if (k != (size_t)(arr.size() - 1)) { + T min_right + = *std::min_element(arr.begin() + k + 1, arr.end(), cmp_less); + if (!cmp_leq(arr[k], min_right)) { + REPORT_FAIL("incorrect right partition", arr.size(), type, k); + } } } } diff --git a/tests/test-qsort.cpp 
b/tests/test-qsort.cpp index d1428ef8..9eaca32c 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -32,13 +32,26 @@ TYPED_TEST_P(simdsort, test_qsort) for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { - std::vector arr = get_array(type, size); + std::vector basearr = get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); x86simdsort::qsort(arr.data(), arr.size(), hasnan); IS_SORTED(sortedarr, arr, type); + + // Descending order + arr = basearr; + sortedarr = arr; + std::sort(sortedarr.begin(), + sortedarr.end(), + compare>()); + x86simdsort::qsort(arr.data(), arr.size(), hasnan, true); + IS_SORTED(sortedarr, arr, type); + arr.clear(); sortedarr.clear(); } @@ -69,7 +82,10 @@ TYPED_TEST_P(simdsort, test_qselect) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - std::vector arr = get_array(type, size); + std::vector basearr = get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::nth_element(sortedarr.begin(), sortedarr.begin() + k, @@ -77,6 +93,17 @@ TYPED_TEST_P(simdsort, test_qselect) compare>()); x86simdsort::qselect(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); + + // Descending order + arr = basearr; + sortedarr = arr; + std::nth_element(sortedarr.begin(), + sortedarr.begin() + k, + sortedarr.end(), + compare>()); + x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); + IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); + arr.clear(); sortedarr.clear(); } @@ -110,13 +137,26 @@ TYPED_TEST_P(simdsort, test_partial_qsort) for (auto size : this->arrsize) { // k should be at least 1 size_t k = std::max((size_t)1, rand() % size); - std::vector arr = get_array(type, size); + std::vector basearr = 
get_array(type, size); + + // Ascending order + std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); + + // Descending order + arr = basearr; + sortedarr = arr; + std::sort(sortedarr.begin(), + sortedarr.end(), + compare>()); + x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan, true); + IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); + arr.clear(); sortedarr.clear(); } From 6b9d9fdd4cb89b22a52dbd6abe1bec6f5c719aaa Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 10:02:33 -0700 Subject: [PATCH 03/18] Added support for descending sorting order --- src/avx2-32bit-qsort.hpp | 14 +- src/avx2-64bit-qsort.hpp | 14 +- src/avx512-16bit-qsort.hpp | 49 ++- src/avx512-32bit-qsort.hpp | 14 +- src/avx512-64bit-common.h | 14 +- src/avx512fp16-16bit-qsort.hpp | 4 + src/xss-common-comparators.hpp | 139 ++++++++ src/xss-common-qsort.h | 199 +++++++---- src/xss-network-qsort.hpp | 55 +-- src/xss-optimal-networks.hpp | 605 ++++++++++++++++----------------- src/xss-pivot-selection.hpp | 45 +-- 11 files changed, 717 insertions(+), 435 deletions(-) create mode 100644 src/xss-common-comparators.hpp diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index c93c5c2a..ad4e99fc 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -85,7 +85,11 @@ struct avx2_vector { static reg_t zmm_max() { return _mm256_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? 
+ } + static reg_t zmm_min() + { + return _mm256_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); @@ -251,6 +255,10 @@ struct avx2_vector { { return _mm256_set1_epi32(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); @@ -405,6 +413,10 @@ struct avx2_vector { { return _mm256_set1_ps(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_ps(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1); diff --git a/src/avx2-64bit-qsort.hpp b/src/avx2-64bit-qsort.hpp index 4028655c..c633b4b9 100644 --- a/src/avx2-64bit-qsort.hpp +++ b/src/avx2-64bit-qsort.hpp @@ -67,7 +67,11 @@ struct avx2_vector { static reg_t zmm_max() { return _mm256_set1_epi64x(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm256_set1_epi64x(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); @@ -248,6 +252,10 @@ struct avx2_vector { { return _mm256_set1_epi64x(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_epi64x(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); @@ -439,6 +447,10 @@ struct avx2_vector { { return _mm256_set1_pd(type_max()); } + static reg_t zmm_min() + { + return _mm256_set1_pd(type_min()); + } static opmask_t knot_opmask(opmask_t x) { auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF); diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 8210ef40..cbeb3e27 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -46,6 +46,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static 
opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); @@ -237,6 +241,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); @@ -381,6 +389,10 @@ struct zmm_vector { { return _mm512_set1_epi16(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi16(type_min()); + } static opmask_t knot_opmask(opmask_t x) { @@ -549,41 +561,58 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) } X86_SIMD_SORT_INLINE void -avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false) +avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) { + using vtype = zmm_vector; + if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf, uint16_t>( arr, arrsize); } - qsort_, uint16_t>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + if (descending){ + qsort_, uint16_t>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + }else{ + qsort_, uint16_t>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false) { + using vtype = zmm_vector; + arrsize_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { - qselect_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + if (descending){ + qselect_, uint16_t>( + arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + }else{ + qselect_, uint16_t>( + arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + } } } X86_SIMD_SORT_INLINE void 
avx512_partial_qsort_fp16(uint16_t *arr, arrsize_t k, arrsize_t arrsize, - bool hasnan = false) + bool hasnan = false, + bool descending = false) { - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); - avx512_qsort_fp16(arr, k - 1); + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending); + avx512_qsort_fp16(arr, k - 1, descending); } #endif // AVX512_QSORT_16BIT diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 96fb965f..8b44e76e 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -58,6 +58,10 @@ struct zmm_vector { { return _mm512_set1_epi32(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi32(type_min()); + } static opmask_t knot_opmask(opmask_t x) { @@ -240,7 +244,11 @@ struct zmm_vector { static reg_t zmm_max() { return _mm512_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? + } + static reg_t zmm_min() + { + return _mm512_set1_epi32(type_min()); + } template static halfreg_t i64gather(__m512i index, void const *base) @@ -424,6 +432,10 @@ struct zmm_vector { { return _mm512_set1_ps(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_ps(type_min()); + } static opmask_t knot_opmask(opmask_t x) { diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 1cd4ca1c..68735c33 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -591,7 +591,11 @@ struct zmm_vector { static reg_t zmm_max() { return _mm512_set1_epi64(type_max()); - } // TODO: this should broadcast bits as is? 
+ } + static reg_t zmm_min() + { + return _mm512_set1_epi64(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) @@ -775,6 +779,10 @@ struct zmm_vector { { return _mm512_set1_epi64(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_epi64(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) @@ -963,6 +971,10 @@ struct zmm_vector { { return _mm512_set1_pd(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_pd(type_min()); + } static regi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 7710dc48..28dff1c0 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -47,6 +47,10 @@ struct zmm_vector<_Float16> { { return _mm512_set1_ph(type_max()); } + static reg_t zmm_min() + { + return _mm512_set1_ph(type_min()); + } static opmask_t knot_opmask(opmask_t x) { return _knot_mask32(x); diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp new file mode 100644 index 00000000..42234727 --- /dev/null +++ b/src/xss-common-comparators.hpp @@ -0,0 +1,139 @@ +#ifndef XSS_COMMON_COMPARATORS +#define XSS_COMMON_COMPARATORS + +template +type_t prev_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, -std::numeric_limits::infinity()); + } + else { + if (value > std::numeric_limits::min()) { return value - 1; } + else { + return value; + } + } +} + +template +type_t next_value(type_t value) +{ + // TODO this probably handles non-native float16 wrong + if constexpr (std::is_floating_point::value) { + return std::nextafter(value, std::numeric_limits::infinity()); + } + else { + if (value < std::numeric_limits::max()) { return value + 1; } + else { + return value; + } + } +} + +template +X86_SIMD_SORT_INLINE 
void COEX(mm_t &a, mm_t &b); + +template +struct AscendingComparator{ + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + using type_t = typename vtype::type_t; + + static inline bool STDSortComparator(const type_t &a, const type_t &b){ + return comparison_func(a, b); + } + + static inline opmask_t PartitionComparator(reg_t a, reg_t b){ + return vtype::ge(a, b); + } + + static inline void COEX(reg_t &a, reg_t &b){ + ::COEX(a, b); + } + + // Returns a vector of values that would be sorted as far right as possible + // For ascending order, this is the maximum possible value + static inline reg_t rightmostPossibleVec(){ + return vtype::zmm_max(); + } + + // Returns the value that would be leftmost of the two when sorted + // For ascending order, that is the smaller value + static inline type_t leftmost(type_t smaller, type_t larger){ + UNUSED(larger); + return smaller; + } + + // Returns the value that would be rightmost of the two when sorted + // For ascending order, that is the larger value + static inline type_t rightmost(type_t smaller, type_t larger){ + UNUSED(smaller); + return larger; + } + + // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample + // Try just doing the next largest value greater than this seemingly very common value to separate them out + static inline type_t choosePivotMedianIsSmallest(type_t median){ + return next_value(median); + } + + // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample + // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition + static inline type_t choosePivotMedianIsLargest(type_t median){ + return median; + } +}; + +template +struct DescendingComparator{ + using reg_t = typename vtype::reg_t; + using opmask_t = typename vtype::opmask_t; + using type_t = typename vtype::type_t; + 
+ static inline bool STDSortComparator(const type_t &a, const type_t &b){ + return comparison_func(b, a); + } + + static inline opmask_t PartitionComparator(reg_t a, reg_t b){ + return vtype::ge(b, a); + } + + static inline void COEX(reg_t &a, reg_t &b){ + ::COEX(b, a); + } + + // Returns a vector of values that would be sorted as far right as possible + // For descending order, this is the minimum possible value + static inline reg_t rightmostPossibleVec(){ + return vtype::zmm_min(); + } + + // Returns the value that would be leftmost of the two when sorted + // For descending order, that is the larger value + static inline type_t leftmost(type_t smaller, type_t bigger){ + UNUSED(smaller); + return bigger; + } + + // Returns the value that would be rightmost of the two when sorted + // For descending order, that is the smaller value + static inline type_t rightmost(type_t smaller, type_t bigger){ + UNUSED(bigger); + return smaller; + } + + // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample + // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition + static inline type_t choosePivotMedianIsSmallest(type_t median){ + return median; + } + + // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample + // Try just doing the next smallest value less than this seemingly very common value to separate them out + static inline type_t choosePivotMedianIsLargest(type_t median){ + return prev_value(median); + } +}; + +#endif // XSS_COMMON_COMPARATORS \ No newline at end of file diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 3d5e20ea..764488d3 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -37,6 +37,7 @@ #include "xss-common-includes.h" #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" +#include 
"xss-common-comparators.hpp" template bool is_a_nan(T elem) @@ -99,16 +100,28 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) template X86_SIMD_SORT_INLINE void -replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count) +replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count, bool descending = false) { - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { - if constexpr (std::is_floating_point_v) { - arr[ii] = std::numeric_limits::quiet_NaN(); + if (descending){ + for (arrsize_t ii = 0; nan_count > 0; ++ii) { + if constexpr (std::is_floating_point_v) { + arr[ii] = std::numeric_limits::quiet_NaN(); + } + else { + arr[ii] = 0xFFFF; + } + nan_count -= 1; } - else { - arr[ii] = 0xFFFF; + }else{ + for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + if constexpr (std::is_floating_point_v) { + arr[ii] = std::numeric_limits::quiet_NaN(); + } + else { + arr[ii] = 0xFFFF; + } + nan_count -= 1; } - nan_count -= 1; } } @@ -137,6 +150,25 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size) return size - count - 1; } +/* + * Sort all the NAN's to start of the array and return the index of the first elem + * in the array which is not a nan + */ +template +X86_SIMD_SORT_INLINE arrsize_t move_nans_to_start_of_array(T *arr, arrsize_t size) +{ + arrsize_t count = 0; + + for(arrsize_t i = 0; i < size; i++){ + if (is_a_nan(arr[i])){ + std::swap(arr[count], arr[i]); + count++; + } + } + + return count; +} + template X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b) { @@ -181,6 +213,7 @@ int avx512_double_compressstore(type_t *left_addr, // Generic function dispatches to AVX2 or AVX512 code template X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, @@ -190,10 +223,11 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, reg_t &smallest_vec, reg_t &biggest_vec) { - typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); + typename vtype::opmask_t 
right_mask + = comparator::PartitionComparator(curr_vec, pivot_vec); int amount_ge_pivot - = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); + = vtype::double_compressstore(l_store, r_store, right_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); @@ -205,7 +239,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, * Parition an array based on the pivot and returns the index of the * first element that is greater than or equal to the pivot. */ -template +template X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, arrsize_t left, arrsize_t right, @@ -217,7 +251,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, for (int32_t i = (right - left) % vtype::numlanes; i > 0; --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { + if (!comparator::STDSortComparator(arr[left], pivot)) { std::swap(arr[left], arr[--right]); } else { @@ -239,7 +273,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, arrsize_t l_store = left; arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec, pivot_vec, @@ -278,7 +312,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, } // partition the current vector and save it on both sides of the array arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, curr_vec, pivot_vec, @@ -290,7 +324,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, /* partition and save vec_left and vec_right */ arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec_left, pivot_vec, @@ -299,7 +333,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= 
vtype::numlanes; - amount_ge_pivot = partition_vec(arr + l_store, + amount_ge_pivot = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec_right, pivot_vec, @@ -314,6 +348,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, } template X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, @@ -324,19 +359,19 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, type_t *biggest) { if constexpr (num_unroll == 0) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition(arr, left, right, pivot, smallest, biggest); } /* Use regular partition for smaller arrays */ if (right - left < 3 * num_unroll * vtype::numlanes) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition(arr, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes, shortening the array */ for (int32_t i = ((right - left) % (vtype::numlanes)); i > 0; --i) { *smallest = std::min(*smallest, arr[left], comparison_func); *biggest = std::max(*biggest, arr[left], comparison_func); - if (!comparison_func(arr[left], pivot)) { + if (!comparator::STDSortComparator(arr[left], pivot)) { std::swap(arr[left], arr[--right]); } else { @@ -419,7 +454,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, curr_vec[ii], pivot_vec, @@ -434,7 +469,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec_left[ii], pivot_vec, @@ -446,7 +481,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { arrsize_t 
amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec_right[ii], pivot_vec, @@ -460,7 +495,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < vecsToPartition; ++ii) { arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, + = partition_vec(arr + l_store, arr + l_store + unpartitioned, vec_align[ii], pivot_vec, @@ -478,7 +513,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, template void sort_n(typename vtype::type_t *arr, int N); -template +template static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) { @@ -486,19 +521,19 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); + std::sort(arr + left, arr + right + 1, comparator::STDSortComparator); return; } /* * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold */ if (right + 1 - left <= vtype::network_sort_threshold) { - sort_n( + sort_n( arr + left, (int32_t)(right + 1 - left)); return; } - auto pivot_result = get_pivot_smart(arr, left, right); + auto pivot_result = get_pivot_smart(arr, left, right); type_t pivot = pivot_result.pivot; if (pivot_result.result == pivot_result_t::Sorted) { return; } @@ -507,17 +542,21 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) type_t biggest = vtype::type_min(); arrsize_t pivot_index - = partition_unrolled( + = partition_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot_result.result == pivot_result_t::Only2Values) { return; } - - if (pivot != smallest) - qsort_(arr, left, pivot_index - 1, max_iters - 1); - if (pivot != biggest) qsort_(arr, pivot_index, right, max_iters - 1); + + type_t leftmostValue = comparator::leftmost(smallest, biggest); + type_t 
rightmostValue = comparator::rightmost(smallest, biggest); + + if (pivot != leftmostValue) + qsort_(arr, left, pivot_index - 1, max_iters - 1); + if (pivot != rightmostValue) + qsort_(arr, pivot_index, right, max_iters - 1); } -template +template X86_SIMD_SORT_INLINE void qselect_(type_t *arr, arrsize_t pos, arrsize_t left, @@ -528,35 +567,44 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, * Resort to std::sort if quicksort isnt making any progress */ if (max_iters <= 0) { - std::sort(arr + left, arr + right + 1, comparison_func); + std::sort(arr + left, arr + right + 1, comparator::STDSortComparator); return; } /* * Base case: use bitonic networks to sort arrays <= vtype::network_sort_threshold */ if (right + 1 - left <= vtype::network_sort_threshold) { - sort_n( + sort_n( arr + left, (int32_t)(right + 1 - left)); return; } - type_t pivot = get_pivot(arr, left, right); + auto pivot_result = get_pivot_smart(arr, left, right); + type_t pivot = pivot_result.pivot; + + if (pivot_result.result == pivot_result_t::Sorted) { return; } + type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); arrsize_t pivot_index - = partition_unrolled( + = partition_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); - if ((pivot != smallest) && (pos < pivot_index)) - qselect_(arr, pos, left, pivot_index - 1, max_iters - 1); - else if ((pivot != biggest) && (pos >= pivot_index)) - qselect_(arr, pos, pivot_index, right, max_iters - 1); + if (pivot_result.result == pivot_result_t::Only2Values) { return; } + + type_t leftmostValue = comparator::leftmost(smallest, biggest); + type_t rightmostValue = comparator::rightmost(smallest, biggest); + + if ((pivot != leftmostValue) && (pos < pivot_index)) + qselect_(arr, pos, left, pivot_index - 1, max_iters - 1); + else if ((pivot != rightmostValue) && (pos >= pivot_index)) + qselect_(arr, pos, pivot_index, right, max_iters - 1); } // Quicksort routines: template -X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t 
arrsize, bool hasnan) +X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) { if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -564,12 +612,20 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + if (descending){ + qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + }else{ + qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + replace_inf_with_nan(arr, arrsize, nan_count, descending); } else { UNUSED(hasnan); - qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + if (descending){ + qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + }else{ + qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } } } } @@ -577,48 +633,65 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) // Quick select methods template X86_SIMD_SORT_INLINE void -xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { - arrsize_t indx_last_elem = arrsize - 1; - if constexpr (std::is_floating_point_v) { - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if (descending){ + arrsize_t index_first_elem = 0; + if constexpr (std::is_floating_point_v) { + if (UNLIKELY(hasnan)) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + } + + arrsize_t size_without_nans = arrsize - index_first_elem; + + UNUSED(hasnan); + if (index_first_elem <= k) { + qselect_, T>( + arr, k, index_first_elem, arrsize - 1, 2 * (arrsize_t)log2(size_without_nans)); + } + }else{ + arrsize_t indx_last_elem = arrsize - 1; + if constexpr (std::is_floating_point_v) { + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, 
arrsize); + } + } + UNUSED(hasnan); + if (indx_last_elem >= k) { + qselect_, T>( + arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); } - } - UNUSED(hasnan); - if (indx_last_elem >= k) { - qselect_( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); } } // Partial sort methods: template X86_SIMD_SORT_INLINE void -xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { - xss_qselect(arr, k - 1, arrsize, hasnan); - xss_qsort(arr, k - 1, hasnan); + xss_qselect(arr, k - 1, arrsize, hasnan, descending); + xss_qsort(arr, k - 1, hasnan, descending); } #define DEFINE_METHODS(ISA, VTYPE) \ template \ X86_SIMD_SORT_INLINE void ISA##_qsort( \ - T *arr, arrsize_t size, bool hasnan = false) \ + T *arr, arrsize_t size, bool hasnan = false, bool descending = false) \ { \ - xss_qsort(arr, size, hasnan); \ + xss_qsort(arr, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_qselect( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false, bool descending = false) \ { \ - xss_qselect(arr, k, size, hasnan); \ + xss_qselect(arr, k, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false, bool descending = false) \ { \ - xss_partial_qsort(arr, k, size, hasnan); \ + xss_partial_qsort(arr, k, size, hasnan, descending); \ } DEFINE_METHODS(avx512, zmm_vector) diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index d883004a..0af2597e 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -7,7 +7,7 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); -template +template X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) { if constexpr (numVecs == 1) { @@ -15,19 
+15,19 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) return; } else if constexpr (numVecs == 2) { - COEX(regs[0], regs[1]); + comparator::COEX(regs[0], regs[1]); } else if constexpr (numVecs == 4) { - optimal_sort_4(regs); + optimal_sort_4(regs); } else if constexpr (numVecs == 8) { - optimal_sort_8(regs); + optimal_sort_8(regs); } else if constexpr (numVecs == 16) { - optimal_sort_16(regs); + optimal_sort_16(regs); } else if constexpr (numVecs == 32) { - optimal_sort_32(regs); + optimal_sort_32(regs); } else { static_assert(numVecs == -1, "should not reach here"); @@ -53,7 +53,7 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) * merge_n<8> = [a,a,a,a,b,b,b,b] */ -template +template X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) { using reg_t = typename vtype::reg_t; @@ -69,7 +69,7 @@ X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) for (int i = 0; i < numVecs; i++) { reg_t &v = reg[i]; reg_t rev = swizzle::template reverse_n(v); - COEX(rev, v); + comparator::COEX(rev, v); v = swizzle::template merge_n(v, rev); } } @@ -79,15 +79,16 @@ X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) for (int i = 0; i < numVecs; i++) { reg_t &v = reg[i]; reg_t swap = swizzle::template swap_n(v); - COEX(swap, v); + comparator::COEX(swap, v); v = swizzle::template merge_n(v, swap); } } - internal_merge_n_vec(reg); + internal_merge_n_vec(reg); } } template @@ -107,27 +108,29 @@ X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) // Do compare exchanges X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = 0; i < numVecs / 2; i++) { - COEX(regs[i], regs[numVecs - 1 - i]); + comparator::COEX(regs[i], regs[numVecs - 1 - i]); } - merge_substep_n_vec(regs); - merge_substep_n_vec(regs + numVecs / 2); + merge_substep_n_vec(regs); + merge_substep_n_vec(regs + numVecs / 2); } template X86_SIMD_SORT_FINLINE void merge_step_n_vec(reg_t *regs) { // Do cross vector merges - 
merge_substep_n_vec(regs); + merge_substep_n_vec(regs); // Do internal vector merges - internal_merge_n_vec(regs); + internal_merge_n_vec(regs); } template @@ -138,30 +141,30 @@ X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs) return; } else { - merge_step_n_vec(regs); - merge_n_vec(regs); + merge_step_n_vec(regs); + merge_n_vec(regs); } } -template +template X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) { /* Run the initial sorting network to sort the columns of the [numVecs x * num_lanes] matrix */ - bitonic_sort_n_vec(vecs); + bitonic_sort_n_vec(vecs); // Merge the vectors using bitonic merging networks - merge_n_vec(vecs); + merge_n_vec(vecs); } -template +template X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) { static_assert(numVecs > 0, "numVecs should be > 0"); if constexpr (numVecs > 1) { if (N * 2 <= numVecs * vtype::numlanes) { - sort_n_vec(arr, N); + sort_n_vec(arr, N); return; } } @@ -187,10 +190,10 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { vecs[i] = vtype::mask_loadu( - vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes); + comparator::rightmostPossibleVec(), ioMasks[j], arr + i * vtype::numlanes); } - sort_vectors(vecs); + sort_vectors(vecs); // Unmasked part of the store X86_SIMD_SORT_UNROLL_LOOP(64) @@ -204,7 +207,7 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) } } -template +template X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) { constexpr int numVecs = maxN / vtype::numlanes; @@ -213,6 +216,6 @@ X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N) static_assert(powerOfTwo == true && isMultiple == true, "maxN must be vtype::numlanes times a power of 2"); - sort_n_vec(arr, N); + sort_n_vec(arr, N); } #endif diff --git a/src/xss-optimal-networks.hpp b/src/xss-optimal-networks.hpp index bffe493d..835c9224 100644 --- 
a/src/xss-optimal-networks.hpp +++ b/src/xss-optimal-networks.hpp @@ -1,323 +1,320 @@ // All of these sources files are generated from the optimal networks described in // https://bertdobbelaere.github.io/sorting_networks.html -template -X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); - -template +template X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) { - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); - COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[1], vecs[2]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) { - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - - COEX(vecs[0], vecs[4]); - COEX(vecs[1], vecs[5]); - COEX(vecs[2], vecs[6]); - COEX(vecs[3], vecs[7]); - - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[7]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[3], vecs[5]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[3], vecs[6]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[5]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], 
vecs[6]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) { - COEX(vecs[0], vecs[13]); - COEX(vecs[1], vecs[12]); - COEX(vecs[2], vecs[15]); - COEX(vecs[3], vecs[14]); - COEX(vecs[4], vecs[8]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[11]); - COEX(vecs[9], vecs[10]); - - COEX(vecs[0], vecs[5]); - COEX(vecs[1], vecs[7]); - COEX(vecs[2], vecs[9]); - COEX(vecs[3], vecs[4]); - COEX(vecs[6], vecs[13]); - COEX(vecs[8], vecs[14]); - COEX(vecs[10], vecs[15]); - COEX(vecs[11], vecs[12]); - - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[8]); - COEX(vecs[7], vecs[9]); - COEX(vecs[10], vecs[11]); - COEX(vecs[12], vecs[13]); - COEX(vecs[14], vecs[15]); - - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[10]); - COEX(vecs[5], vecs[11]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[12], vecs[14]); - COEX(vecs[13], vecs[15]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[12]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - COEX(vecs[8], vecs[10]); - COEX(vecs[9], vecs[11]); - COEX(vecs[13], vecs[14]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[2], vecs[6]); - COEX(vecs[5], vecs[8]); - COEX(vecs[7], vecs[10]); - COEX(vecs[9], vecs[13]); - COEX(vecs[11], vecs[14]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[3], vecs[6]); - COEX(vecs[9], vecs[12]); - COEX(vecs[11], vecs[13]); - - COEX(vecs[3], vecs[5]); - COEX(vecs[6], vecs[8]); - COEX(vecs[7], vecs[9]); - COEX(vecs[10], vecs[12]); - - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[12]); - - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[0], vecs[13]); + comparator::COEX(vecs[1], vecs[12]); + comparator::COEX(vecs[2], vecs[15]); + comparator::COEX(vecs[3], vecs[14]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[11]); + 
comparator::COEX(vecs[9], vecs[10]); + + comparator::COEX(vecs[0], vecs[5]); + comparator::COEX(vecs[1], vecs[7]); + comparator::COEX(vecs[2], vecs[9]); + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[6], vecs[13]); + comparator::COEX(vecs[8], vecs[14]); + comparator::COEX(vecs[10], vecs[15]); + comparator::COEX(vecs[11], vecs[12]); + + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + comparator::COEX(vecs[4], vecs[10]); + comparator::COEX(vecs[5], vecs[11]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[13], vecs[14]); + + comparator::COEX(vecs[1], vecs[4]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[5], vecs[8]); + comparator::COEX(vecs[7], vecs[10]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[11], vecs[14]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[6]); + comparator::COEX(vecs[9], vecs[12]); + comparator::COEX(vecs[11], vecs[13]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[8]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + + 
comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) -{ - COEX(vecs[0], vecs[1]); - COEX(vecs[2], vecs[3]); - COEX(vecs[4], vecs[5]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[10], vecs[11]); - COEX(vecs[12], vecs[13]); - COEX(vecs[14], vecs[15]); - COEX(vecs[16], vecs[17]); - COEX(vecs[18], vecs[19]); - COEX(vecs[20], vecs[21]); - COEX(vecs[22], vecs[23]); - COEX(vecs[24], vecs[25]); - COEX(vecs[26], vecs[27]); - COEX(vecs[28], vecs[29]); - COEX(vecs[30], vecs[31]); - - COEX(vecs[0], vecs[2]); - COEX(vecs[1], vecs[3]); - COEX(vecs[4], vecs[6]); - COEX(vecs[5], vecs[7]); - COEX(vecs[8], vecs[10]); - COEX(vecs[9], vecs[11]); - COEX(vecs[12], vecs[14]); - COEX(vecs[13], vecs[15]); - COEX(vecs[16], vecs[18]); - COEX(vecs[17], vecs[19]); - COEX(vecs[20], vecs[22]); - COEX(vecs[21], vecs[23]); - COEX(vecs[24], vecs[26]); - COEX(vecs[25], vecs[27]); - COEX(vecs[28], vecs[30]); - COEX(vecs[29], vecs[31]); - - COEX(vecs[0], vecs[4]); - COEX(vecs[1], vecs[5]); - COEX(vecs[2], vecs[6]); - COEX(vecs[3], vecs[7]); - COEX(vecs[8], vecs[12]); - COEX(vecs[9], vecs[13]); - COEX(vecs[10], vecs[14]); - COEX(vecs[11], vecs[15]); - COEX(vecs[16], vecs[20]); - COEX(vecs[17], vecs[21]); - COEX(vecs[18], vecs[22]); - COEX(vecs[19], vecs[23]); - COEX(vecs[24], vecs[28]); - COEX(vecs[25], vecs[29]); - COEX(vecs[26], vecs[30]); - COEX(vecs[27], vecs[31]); - - COEX(vecs[0], vecs[8]); - COEX(vecs[1], vecs[9]); - COEX(vecs[2], vecs[10]); - COEX(vecs[3], vecs[11]); - COEX(vecs[4], vecs[12]); - COEX(vecs[5], vecs[13]); - COEX(vecs[6], vecs[14]); - COEX(vecs[7], vecs[15]); - COEX(vecs[16], vecs[24]); - COEX(vecs[17], vecs[25]); - COEX(vecs[18], vecs[26]); - COEX(vecs[19], vecs[27]); - COEX(vecs[20], vecs[28]); - COEX(vecs[21], vecs[29]); - COEX(vecs[22], vecs[30]); - COEX(vecs[23], vecs[31]); - - COEX(vecs[0], vecs[16]); - COEX(vecs[1], vecs[8]); - COEX(vecs[2], vecs[4]); - 
COEX(vecs[3], vecs[12]); - COEX(vecs[5], vecs[10]); - COEX(vecs[6], vecs[9]); - COEX(vecs[7], vecs[14]); - COEX(vecs[11], vecs[13]); - COEX(vecs[15], vecs[31]); - COEX(vecs[17], vecs[24]); - COEX(vecs[18], vecs[20]); - COEX(vecs[19], vecs[28]); - COEX(vecs[21], vecs[26]); - COEX(vecs[22], vecs[25]); - COEX(vecs[23], vecs[30]); - COEX(vecs[27], vecs[29]); - - COEX(vecs[1], vecs[2]); - COEX(vecs[3], vecs[5]); - COEX(vecs[4], vecs[8]); - COEX(vecs[6], vecs[22]); - COEX(vecs[7], vecs[11]); - COEX(vecs[9], vecs[25]); - COEX(vecs[10], vecs[12]); - COEX(vecs[13], vecs[14]); - COEX(vecs[17], vecs[18]); - COEX(vecs[19], vecs[21]); - COEX(vecs[20], vecs[24]); - COEX(vecs[23], vecs[27]); - COEX(vecs[26], vecs[28]); - COEX(vecs[29], vecs[30]); - - COEX(vecs[1], vecs[17]); - COEX(vecs[2], vecs[18]); - COEX(vecs[3], vecs[19]); - COEX(vecs[4], vecs[20]); - COEX(vecs[5], vecs[10]); - COEX(vecs[7], vecs[23]); - COEX(vecs[8], vecs[24]); - COEX(vecs[11], vecs[27]); - COEX(vecs[12], vecs[28]); - COEX(vecs[13], vecs[29]); - COEX(vecs[14], vecs[30]); - COEX(vecs[21], vecs[26]); - - COEX(vecs[3], vecs[17]); - COEX(vecs[4], vecs[16]); - COEX(vecs[5], vecs[21]); - COEX(vecs[6], vecs[18]); - COEX(vecs[7], vecs[9]); - COEX(vecs[8], vecs[20]); - COEX(vecs[10], vecs[26]); - COEX(vecs[11], vecs[23]); - COEX(vecs[13], vecs[25]); - COEX(vecs[14], vecs[28]); - COEX(vecs[15], vecs[27]); - COEX(vecs[22], vecs[24]); - - COEX(vecs[1], vecs[4]); - COEX(vecs[3], vecs[8]); - COEX(vecs[5], vecs[16]); - COEX(vecs[7], vecs[17]); - COEX(vecs[9], vecs[21]); - COEX(vecs[10], vecs[22]); - COEX(vecs[11], vecs[19]); - COEX(vecs[12], vecs[20]); - COEX(vecs[14], vecs[24]); - COEX(vecs[15], vecs[26]); - COEX(vecs[23], vecs[28]); - COEX(vecs[27], vecs[30]); - - COEX(vecs[2], vecs[5]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[18]); - COEX(vecs[11], vecs[17]); - COEX(vecs[12], vecs[16]); - COEX(vecs[13], vecs[22]); - COEX(vecs[14], vecs[20]); - COEX(vecs[15], vecs[19]); - COEX(vecs[23], vecs[24]); - 
COEX(vecs[26], vecs[29]); - - COEX(vecs[2], vecs[4]); - COEX(vecs[6], vecs[12]); - COEX(vecs[9], vecs[16]); - COEX(vecs[10], vecs[11]); - COEX(vecs[13], vecs[17]); - COEX(vecs[14], vecs[18]); - COEX(vecs[15], vecs[22]); - COEX(vecs[19], vecs[25]); - COEX(vecs[20], vecs[21]); - COEX(vecs[27], vecs[29]); - - COEX(vecs[5], vecs[6]); - COEX(vecs[8], vecs[12]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[13]); - COEX(vecs[14], vecs[16]); - COEX(vecs[15], vecs[17]); - COEX(vecs[18], vecs[20]); - COEX(vecs[19], vecs[23]); - COEX(vecs[21], vecs[22]); - COEX(vecs[25], vecs[26]); - - COEX(vecs[3], vecs[5]); - COEX(vecs[6], vecs[7]); - COEX(vecs[8], vecs[9]); - COEX(vecs[10], vecs[12]); - COEX(vecs[11], vecs[14]); - COEX(vecs[13], vecs[16]); - COEX(vecs[15], vecs[18]); - COEX(vecs[17], vecs[20]); - COEX(vecs[19], vecs[21]); - COEX(vecs[22], vecs[23]); - COEX(vecs[24], vecs[25]); - COEX(vecs[26], vecs[28]); - - COEX(vecs[3], vecs[4]); - COEX(vecs[5], vecs[6]); - COEX(vecs[7], vecs[8]); - COEX(vecs[9], vecs[10]); - COEX(vecs[11], vecs[12]); - COEX(vecs[13], vecs[14]); - COEX(vecs[15], vecs[16]); - COEX(vecs[17], vecs[18]); - COEX(vecs[19], vecs[20]); - COEX(vecs[21], vecs[22]); - COEX(vecs[23], vecs[24]); - COEX(vecs[25], vecs[26]); - COEX(vecs[27], vecs[28]); +{ + comparator::COEX(vecs[0], vecs[1]); + comparator::COEX(vecs[2], vecs[3]); + comparator::COEX(vecs[4], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[12], vecs[13]); + comparator::COEX(vecs[14], vecs[15]); + comparator::COEX(vecs[16], vecs[17]); + comparator::COEX(vecs[18], vecs[19]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[27]); + comparator::COEX(vecs[28], vecs[29]); + comparator::COEX(vecs[30], vecs[31]); + + comparator::COEX(vecs[0], vecs[2]); + comparator::COEX(vecs[1], vecs[3]); + 
comparator::COEX(vecs[4], vecs[6]); + comparator::COEX(vecs[5], vecs[7]); + comparator::COEX(vecs[8], vecs[10]); + comparator::COEX(vecs[9], vecs[11]); + comparator::COEX(vecs[12], vecs[14]); + comparator::COEX(vecs[13], vecs[15]); + comparator::COEX(vecs[16], vecs[18]); + comparator::COEX(vecs[17], vecs[19]); + comparator::COEX(vecs[20], vecs[22]); + comparator::COEX(vecs[21], vecs[23]); + comparator::COEX(vecs[24], vecs[26]); + comparator::COEX(vecs[25], vecs[27]); + comparator::COEX(vecs[28], vecs[30]); + comparator::COEX(vecs[29], vecs[31]); + + comparator::COEX(vecs[0], vecs[4]); + comparator::COEX(vecs[1], vecs[5]); + comparator::COEX(vecs[2], vecs[6]); + comparator::COEX(vecs[3], vecs[7]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[13]); + comparator::COEX(vecs[10], vecs[14]); + comparator::COEX(vecs[11], vecs[15]); + comparator::COEX(vecs[16], vecs[20]); + comparator::COEX(vecs[17], vecs[21]); + comparator::COEX(vecs[18], vecs[22]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[24], vecs[28]); + comparator::COEX(vecs[25], vecs[29]); + comparator::COEX(vecs[26], vecs[30]); + comparator::COEX(vecs[27], vecs[31]); + + comparator::COEX(vecs[0], vecs[8]); + comparator::COEX(vecs[1], vecs[9]); + comparator::COEX(vecs[2], vecs[10]); + comparator::COEX(vecs[3], vecs[11]); + comparator::COEX(vecs[4], vecs[12]); + comparator::COEX(vecs[5], vecs[13]); + comparator::COEX(vecs[6], vecs[14]); + comparator::COEX(vecs[7], vecs[15]); + comparator::COEX(vecs[16], vecs[24]); + comparator::COEX(vecs[17], vecs[25]); + comparator::COEX(vecs[18], vecs[26]); + comparator::COEX(vecs[19], vecs[27]); + comparator::COEX(vecs[20], vecs[28]); + comparator::COEX(vecs[21], vecs[29]); + comparator::COEX(vecs[22], vecs[30]); + comparator::COEX(vecs[23], vecs[31]); + + comparator::COEX(vecs[0], vecs[16]); + comparator::COEX(vecs[1], vecs[8]); + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[3], vecs[12]); + comparator::COEX(vecs[5], 
vecs[10]); + comparator::COEX(vecs[6], vecs[9]); + comparator::COEX(vecs[7], vecs[14]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[15], vecs[31]); + comparator::COEX(vecs[17], vecs[24]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[28]); + comparator::COEX(vecs[21], vecs[26]); + comparator::COEX(vecs[22], vecs[25]); + comparator::COEX(vecs[23], vecs[30]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[1], vecs[2]); + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[4], vecs[8]); + comparator::COEX(vecs[6], vecs[22]); + comparator::COEX(vecs[7], vecs[11]); + comparator::COEX(vecs[9], vecs[25]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[21]); + comparator::COEX(vecs[20], vecs[24]); + comparator::COEX(vecs[23], vecs[27]); + comparator::COEX(vecs[26], vecs[28]); + comparator::COEX(vecs[29], vecs[30]); + + comparator::COEX(vecs[1], vecs[17]); + comparator::COEX(vecs[2], vecs[18]); + comparator::COEX(vecs[3], vecs[19]); + comparator::COEX(vecs[4], vecs[20]); + comparator::COEX(vecs[5], vecs[10]); + comparator::COEX(vecs[7], vecs[23]); + comparator::COEX(vecs[8], vecs[24]); + comparator::COEX(vecs[11], vecs[27]); + comparator::COEX(vecs[12], vecs[28]); + comparator::COEX(vecs[13], vecs[29]); + comparator::COEX(vecs[14], vecs[30]); + comparator::COEX(vecs[21], vecs[26]); + + comparator::COEX(vecs[3], vecs[17]); + comparator::COEX(vecs[4], vecs[16]); + comparator::COEX(vecs[5], vecs[21]); + comparator::COEX(vecs[6], vecs[18]); + comparator::COEX(vecs[7], vecs[9]); + comparator::COEX(vecs[8], vecs[20]); + comparator::COEX(vecs[10], vecs[26]); + comparator::COEX(vecs[11], vecs[23]); + comparator::COEX(vecs[13], vecs[25]); + comparator::COEX(vecs[14], vecs[28]); + comparator::COEX(vecs[15], vecs[27]); + comparator::COEX(vecs[22], vecs[24]); + + comparator::COEX(vecs[1], vecs[4]); + 
comparator::COEX(vecs[3], vecs[8]); + comparator::COEX(vecs[5], vecs[16]); + comparator::COEX(vecs[7], vecs[17]); + comparator::COEX(vecs[9], vecs[21]); + comparator::COEX(vecs[10], vecs[22]); + comparator::COEX(vecs[11], vecs[19]); + comparator::COEX(vecs[12], vecs[20]); + comparator::COEX(vecs[14], vecs[24]); + comparator::COEX(vecs[15], vecs[26]); + comparator::COEX(vecs[23], vecs[28]); + comparator::COEX(vecs[27], vecs[30]); + + comparator::COEX(vecs[2], vecs[5]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[18]); + comparator::COEX(vecs[11], vecs[17]); + comparator::COEX(vecs[12], vecs[16]); + comparator::COEX(vecs[13], vecs[22]); + comparator::COEX(vecs[14], vecs[20]); + comparator::COEX(vecs[15], vecs[19]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[26], vecs[29]); + + comparator::COEX(vecs[2], vecs[4]); + comparator::COEX(vecs[6], vecs[12]); + comparator::COEX(vecs[9], vecs[16]); + comparator::COEX(vecs[10], vecs[11]); + comparator::COEX(vecs[13], vecs[17]); + comparator::COEX(vecs[14], vecs[18]); + comparator::COEX(vecs[15], vecs[22]); + comparator::COEX(vecs[19], vecs[25]); + comparator::COEX(vecs[20], vecs[21]); + comparator::COEX(vecs[27], vecs[29]); + + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[8], vecs[12]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[13]); + comparator::COEX(vecs[14], vecs[16]); + comparator::COEX(vecs[15], vecs[17]); + comparator::COEX(vecs[18], vecs[20]); + comparator::COEX(vecs[19], vecs[23]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[25], vecs[26]); + + comparator::COEX(vecs[3], vecs[5]); + comparator::COEX(vecs[6], vecs[7]); + comparator::COEX(vecs[8], vecs[9]); + comparator::COEX(vecs[10], vecs[12]); + comparator::COEX(vecs[11], vecs[14]); + comparator::COEX(vecs[13], vecs[16]); + comparator::COEX(vecs[15], vecs[18]); + comparator::COEX(vecs[17], vecs[20]); + comparator::COEX(vecs[19], vecs[21]); + 
comparator::COEX(vecs[22], vecs[23]); + comparator::COEX(vecs[24], vecs[25]); + comparator::COEX(vecs[26], vecs[28]); + + comparator::COEX(vecs[3], vecs[4]); + comparator::COEX(vecs[5], vecs[6]); + comparator::COEX(vecs[7], vecs[8]); + comparator::COEX(vecs[9], vecs[10]); + comparator::COEX(vecs[11], vecs[12]); + comparator::COEX(vecs[13], vecs[14]); + comparator::COEX(vecs[15], vecs[16]); + comparator::COEX(vecs[17], vecs[18]); + comparator::COEX(vecs[19], vecs[20]); + comparator::COEX(vecs[21], vecs[22]); + comparator::COEX(vecs[23], vecs[24]); + comparator::COEX(vecs[25], vecs[26]); + comparator::COEX(vecs[27], vecs[28]); } diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 59dc0489..f6ed6eeb 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -2,6 +2,7 @@ #define XSS_PIVOT_SELECTION #include "xss-network-qsort.hpp" +#include "xss-common-comparators.hpp" enum class pivot_result_t : int { Normal, Sorted, Only2Values }; @@ -19,21 +20,6 @@ struct pivot_results { } }; -template -type_t next_value(type_t value) -{ - // TODO this probably handles non-native float16 wrong - if constexpr (std::is_floating_point::value) { - return std::nextafter(value, std::numeric_limits::infinity()); - } - else { - if (value < std::numeric_limits::max()) { return value + 1; } - else { - return value; - } - } -} - template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); @@ -98,14 +84,14 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, return data[vtype::numlanes / 2]; } -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_near_constant(type_t *arr, type_t commonValue, const arrsize_t left, const arrsize_t right); -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) { @@ -127,7 +113,9 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) } // Sort the samples - sort_vectors(vecs); + // Note that this intentionally 
uses the AscendingComparator + // instead of the provided comparator + sort_vectors, numVecs>(vecs); type_t samples[N]; for (int i = 0; i < numVecs; i++) { @@ -141,21 +129,21 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) if (smallest == largest) { // We have a very unlucky sample, or the array is constant / near constant // Run a special function meant to deal with this situation - return get_pivot_near_constant(arr, median, left, right); + return get_pivot_near_constant(arr, median, left, right); } else if (median != smallest && median != largest) { // We have a normal sample; use it's median return pivot_results(median); } else if (median == smallest) { - // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample - // Try just doing the next largest value greater than this seemingly very common value to seperate them out - return pivot_results(next_value(median)); + // We will either return the median or the next value larger than the median, + // depending on the comparator (see xss-common-comparators.hpp for more details) + return pivot_results(comparator::choosePivotMedianIsSmallest(median)); } else if (median == largest) { - // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample - // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - return pivot_results(median); + // We will either return the median or the next value smaller than the median, + // depending on the comparator (see xss-common-comparators.hpp for more details) + return pivot_results(comparator::choosePivotMedianIsLargest(median)); } else { // Should be unreachable @@ -167,7 +155,7 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) } // Handles the case where we seem to have a near-constant array, since our sample of the array was 
constant -template +template X86_SIMD_SORT_INLINE pivot_results get_pivot_near_constant(type_t *arr, type_t commonValue, @@ -228,9 +216,10 @@ get_pivot_near_constant(type_t *arr, if (index == right + 1) { // The array contains only 2 values // We must pick the larger one, else the right partition is empty - // We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the larger value + // (note that larger is determined using the provided comparator, so it might actually be the smaller one) + // We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the chosen value // TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere - type_t pivot = std::max(value1, commonValue, comparison_func); + type_t pivot = std::max(value1, commonValue, comparator::STDSortComparator); return pivot_results(pivot, pivot_result_t::Only2Values); } From 80729a81c15119399a704157b46e89774f1dfd42 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 10:24:11 -0700 Subject: [PATCH 04/18] Changed some things to be force inlined --- src/xss-common-comparators.hpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp index 42234727..c7d525a1 100644 --- a/src/xss-common-comparators.hpp +++ b/src/xss-common-comparators.hpp @@ -40,47 +40,47 @@ struct AscendingComparator{ using opmask_t = typename vtype::opmask_t; using type_t = typename vtype::type_t; - static inline bool STDSortComparator(const type_t &a, const type_t &b){ + X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, const type_t &b){ return comparison_func(a, b); } - static inline opmask_t PartitionComparator(reg_t a, reg_t b){ + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b){ return vtype::ge(a, b); } - static inline void 
COEX(reg_t &a, reg_t &b){ + X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b){ ::COEX(a, b); } // Returns a vector of values that would be sorted as far right as possible // For ascending order, this is the maximum possible value - static inline reg_t rightmostPossibleVec(){ + X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec(){ return vtype::zmm_max(); } // Returns the value that would be leftmost of the two when sorted // For ascending order, that is the smaller value - static inline type_t leftmost(type_t smaller, type_t larger){ + X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger){ UNUSED(larger); return smaller; } // Returns the value that would be rightmost of the two when sorted // For ascending order, that is the larger value - static inline type_t rightmost(type_t smaller, type_t larger){ + X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger){ UNUSED(smaller); return larger; } // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample // Try just doing the next largest value greater than this seemingly very common value to seperate them out - static inline type_t choosePivotMedianIsSmallest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median){ return next_value(median); } // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - static inline type_t choosePivotMedianIsLargest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median){ return median; } }; @@ -91,47 +91,47 @@ struct DescendingComparator{ using opmask_t = typename vtype::opmask_t; using type_t = typename vtype::type_t; - static inline bool STDSortComparator(const type_t &a, const type_t &b){ + X86_SIMD_SORT_FINLINE bool 
STDSortComparator(const type_t &a, const type_t &b){ return comparison_func(b, a); } - static inline opmask_t PartitionComparator(reg_t a, reg_t b){ + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b){ return vtype::ge(b, a); } - static inline void COEX(reg_t &a, reg_t &b){ + X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b){ ::COEX(b, a); } // Returns a vector of values that would be sorted as far right as possible // For descending order, this is the minimum possible value - static inline reg_t rightmostPossibleVec(){ + X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec(){ return vtype::zmm_min(); } // Returns the value that would be leftmost of the two when sorted // For descending order, that is the larger value - static inline type_t leftmost(type_t smaller, type_t bigger){ + X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t bigger){ UNUSED(smaller); return bigger; } // Returns the value that would be rightmost of the two when sorted // For descending order, that is the smaller value - static inline type_t rightmost(type_t smaller, type_t bigger){ + X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t bigger){ UNUSED(bigger); return smaller; } // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - static inline type_t choosePivotMedianIsSmallest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median){ return median; } // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample // Try just doing the next smallest value less than this seemingly very common value to seperate them out - static inline type_t choosePivotMedianIsLargest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median){ return 
prev_value(median); } }; From edae7959b33839d5021e9b5bd7072412e0582258 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 11:14:25 -0700 Subject: [PATCH 05/18] Undid change to qselect pivot algorithm for performance --- src/xss-common-qsort.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 764488d3..0c25b99e 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -579,10 +579,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, return; } - auto pivot_result = get_pivot_smart(arr, left, right); - type_t pivot = pivot_result.pivot; - - if (pivot_result.result == pivot_result_t::Sorted) { return; } + type_t pivot = get_pivot(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); @@ -590,8 +587,6 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, arrsize_t pivot_index = partition_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); - - if (pivot_result.result == pivot_result_t::Only2Values) { return; } type_t leftmostValue = comparator::leftmost(smallest, biggest); type_t rightmostValue = comparator::rightmost(smallest, biggest); From f6da098e2c2585dc2142eb51296554d55f35280f Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 11:36:42 -0700 Subject: [PATCH 06/18] clang-format --- src/avx512-16bit-qsort.hpp | 33 +++-- src/xss-common-comparators.hpp | 86 ++++++++----- src/xss-common-qsort.h | 220 +++++++++++++++++++-------------- src/xss-network-qsort.hpp | 29 +++-- src/xss-optimal-networks.hpp | 18 ++- src/xss-pivot-selection.hpp | 12 +- tests/test-qsort.cpp | 18 +-- 7 files changed, 252 insertions(+), 164 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index cbeb3e27..3ed6d447 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -560,22 +560,24 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan(uint16_t elem) return ((elem & 0x7c00u) == 0x7c00u) && 
((elem & 0x03ffu) != 0); } -X86_SIMD_SORT_INLINE void -avx512_qsort_fp16(uint16_t *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx512_qsort_fp16(uint16_t *arr, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { using vtype = zmm_vector; - + if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf, uint16_t>( arr, arrsize); } - if (descending){ + if (descending) { qsort_, uint16_t>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - - }else{ + } + else { qsort_, uint16_t>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -590,18 +592,27 @@ X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr, bool descending = false) { using vtype = zmm_vector; - + arrsize_t indx_last_elem = arrsize - 1; if (UNLIKELY(hasnan)) { indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } if (indx_last_elem >= k) { - if (descending){ + if (descending) { qselect_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); - }else{ + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); + } + else { qselect_, uint16_t>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); } } } diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp index c7d525a1..48a1d020 100644 --- a/src/xss-common-comparators.hpp +++ b/src/xss-common-comparators.hpp @@ -35,103 +35,121 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); template -struct AscendingComparator{ +struct AscendingComparator { using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; using type_t = typename vtype::type_t; - - X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, const type_t &b){ + + X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, + const type_t &b) + { return comparison_func(a, b); } - - 
X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b){ + + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) + { return vtype::ge(a, b); } - - X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b){ + + X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) + { ::COEX(a, b); } - + // Returns a vector of values that would be sorted as far right as possible // For ascending order, this is the maximum possible value - X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec(){ + X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() + { return vtype::zmm_max(); } - + // Returns the value that would be leftmost of the two when sorted // For ascending order, that is the smaller value - X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger){ + X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger) + { UNUSED(larger); return smaller; } - + // Returns the value that would be rightmost of the two when sorted // For ascending order, that is the larger value - X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger){ + X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger) + { UNUSED(smaller); return larger; } - + // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample // Try just doing the next largest value greater than this seemingly very common value to seperate them out - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) + { return next_value(median); } - + // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median){ + X86_SIMD_SORT_FINLINE type_t 
choosePivotMedianIsLargest(type_t median) + { return median; } }; template -struct DescendingComparator{ +struct DescendingComparator { using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; using type_t = typename vtype::type_t; - - X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, const type_t &b){ + + X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, + const type_t &b) + { return comparison_func(b, a); } - - X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b){ + + X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) + { return vtype::ge(b, a); } - - X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b){ + + X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) + { ::COEX(b, a); } - + // Returns a vector of values that would be sorted as far right as possible // For descending order, this is the minimum possible value - X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec(){ + X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() + { return vtype::zmm_min(); } - + // Returns the value that would be leftmost of the two when sorted // For descending order, that is the larger value - X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t bigger){ + X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t bigger) + { UNUSED(smaller); return bigger; } - + // Returns the value that would be rightmost of the two when sorted // For descending order, that is the smaller value - X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t bigger){ + X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t bigger) + { UNUSED(bigger); return smaller; } - + // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median){ + 
X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) + { return median; } - + // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample // Try just doing the next smallest value less than this seemingly very common value to seperate them out - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median){ + X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median) + { return prev_value(median); } }; diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 0c25b99e..435abd56 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -99,10 +99,12 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size) } template -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count, bool descending = false) +X86_SIMD_SORT_INLINE void replace_inf_with_nan(type_t *arr, + arrsize_t size, + arrsize_t nan_count, + bool descending = false) { - if (descending){ + if (descending) { for (arrsize_t ii = 0; nan_count > 0; ++ii) { if constexpr (std::is_floating_point_v) { arr[ii] = std::numeric_limits::quiet_NaN(); @@ -112,7 +114,8 @@ replace_inf_with_nan(type_t *arr, arrsize_t size, arrsize_t nan_count, bool desc } nan_count -= 1; } - }else{ + } + else { for (arrsize_t ii = size - 1; nan_count > 0; --ii) { if constexpr (std::is_floating_point_v) { arr[ii] = std::numeric_limits::quiet_NaN(); @@ -155,17 +158,18 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size) * in the array which is not a nan */ template -X86_SIMD_SORT_INLINE arrsize_t move_nans_to_start_of_array(T *arr, arrsize_t size) +X86_SIMD_SORT_INLINE arrsize_t move_nans_to_start_of_array(T *arr, + arrsize_t size) { arrsize_t count = 0; - - for(arrsize_t i = 0; i < size; i++){ - if (is_a_nan(arr[i])){ + + for (arrsize_t i = 0; i < size; i++) { + if (is_a_nan(arr[i])) { std::swap(arr[count], arr[i]); count++; } } - 
+ return count; } @@ -226,8 +230,8 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, typename vtype::opmask_t right_mask = comparator::PartitionComparator(curr_vec, pivot_vec); - int amount_ge_pivot - = vtype::double_compressstore(l_store, r_store, right_mask, curr_vec); + int amount_ge_pivot = vtype::double_compressstore( + l_store, r_store, right_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); @@ -272,13 +276,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, arrsize_t unpartitioned = right - left - vtype::numlanes; arrsize_t l_store = left; - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec, - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); *smallest = vtype::reducemin(min_vec); *biggest = vtype::reducemax(max_vec); @@ -311,13 +315,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, left += vtype::numlanes; } // partition the current vector and save it on both sides of the array - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - curr_vec, - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + curr_vec, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -325,20 +329,21 @@ X86_SIMD_SORT_INLINE arrsize_t partition(type_t *arr, /* partition and save vec_left and vec_right */ arrsize_t amount_ge_pivot = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_left, - pivot_vec, - min_vec, - max_vec); + arr + l_store + unpartitioned, + vec_left, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= 
vtype::numlanes; - amount_ge_pivot = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_right, - pivot_vec, - min_vec, - max_vec); + amount_ge_pivot + = partition_vec(arr + l_store, + arr + l_store + unpartitioned, + vec_right, + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; @@ -453,13 +458,13 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, * */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - curr_vec[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + curr_vec[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -468,25 +473,25 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, /* partition and save vec_left[num_unroll] and vec_right[num_unroll] */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_left[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_left[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < num_unroll; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_right[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_right[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -494,13 +499,13 @@ X86_SIMD_SORT_INLINE 
arrsize_t partition_unrolled(type_t *arr, /* partition and save vec_align[vecsToPartition] */ X86_SIMD_SORT_UNROLL_LOOP(8) for (int ii = 0; ii < vecsToPartition; ++ii) { - arrsize_t amount_ge_pivot - = partition_vec(arr + l_store, - arr + l_store + unpartitioned, - vec_align[ii], - pivot_vec, - min_vec, - max_vec); + arrsize_t amount_ge_pivot = partition_vec( + arr + l_store, + arr + l_store + unpartitioned, + vec_align[ii], + pivot_vec, + min_vec, + max_vec); l_store += (vtype::numlanes - amount_ge_pivot); unpartitioned -= vtype::numlanes; } @@ -533,7 +538,8 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) return; } - auto pivot_result = get_pivot_smart(arr, left, right); + auto pivot_result + = get_pivot_smart(arr, left, right); type_t pivot = pivot_result.pivot; if (pivot_result.result == pivot_result_t::Sorted) { return; } @@ -546,7 +552,7 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) arr, left, right + 1, pivot, &smallest, &biggest); if (pivot_result.result == pivot_result_t::Only2Values) { return; } - + type_t leftmostValue = comparator::leftmost(smallest, biggest); type_t rightmostValue = comparator::rightmost(smallest, biggest); @@ -587,19 +593,22 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, arrsize_t pivot_index = partition_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); - + type_t leftmostValue = comparator::leftmost(smallest, biggest); type_t rightmostValue = comparator::rightmost(smallest, biggest); if ((pivot != leftmostValue) && (pos < pivot_index)) - qselect_(arr, pos, left, pivot_index - 1, max_iters - 1); + qselect_( + arr, pos, left, pivot_index - 1, max_iters - 1); else if ((pivot != rightmostValue) && (pos >= pivot_index)) - qselect_(arr, pos, pivot_index, right, max_iters - 1); + qselect_( + arr, pos, pivot_index, right, max_iters - 1); } // Quicksort routines: template -X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) 
+X86_SIMD_SORT_INLINE void +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) { if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -607,19 +616,25 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (descending){ - qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - }else{ - qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + if (descending) { + qsort_, T>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + else { + qsort_, T>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } replace_inf_with_nan(arr, arrsize, nan_count, descending); } else { UNUSED(hasnan); - if (descending){ - qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - }else{ - qsort_, T>(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + if (descending) { + qsort_, T>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + else { + qsort_, T>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } } @@ -627,25 +642,30 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool // Quick select methods template -X86_SIMD_SORT_INLINE void -xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +X86_SIMD_SORT_INLINE void xss_qselect( + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { - if (descending){ + if (descending) { arrsize_t index_first_elem = 0; if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } } - + arrsize_t size_without_nans = arrsize - index_first_elem; - + UNUSED(hasnan); if (index_first_elem <= k) { qselect_, T>( - arr, k, index_first_elem, arrsize - 1, 2 * (arrsize_t)log2(size_without_nans)); + arr, + k, + index_first_elem, + arrsize - 1, + 2 * (arrsize_t)log2(size_without_nans)); } - }else{ + } + else { arrsize_t indx_last_elem = arrsize 
- 1; if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -655,15 +675,19 @@ xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending UNUSED(hasnan); if (indx_last_elem >= k) { qselect_, T>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); } } } // Partial sort methods: template -X86_SIMD_SORT_INLINE void -xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +X86_SIMD_SORT_INLINE void xss_partial_qsort( + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { xss_qselect(arr, k - 1, arrsize, hasnan, descending); xss_qsort(arr, k - 1, hasnan, descending); @@ -671,20 +695,28 @@ xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool desc #define DEFINE_METHODS(ISA, VTYPE) \ template \ - X86_SIMD_SORT_INLINE void ISA##_qsort( \ - T *arr, arrsize_t size, bool hasnan = false, bool descending = false) \ + X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ xss_qsort(arr, size, hasnan, descending); \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_qselect( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false, bool descending = false) \ + X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \ + arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ xss_qselect(arr, k, size, hasnan, descending); \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ - T *arr, arrsize_t k, arrsize_t size, bool hasnan = false, bool descending = false) \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \ + arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ xss_partial_qsort(arr, k, size, hasnan, descending); \ } diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index 0af2597e..dd299507 100644 --- 
a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -7,7 +7,10 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); -template +template X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) { if constexpr (numVecs == 1) { @@ -53,7 +56,11 @@ X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) * merge_n<8> = [a,a,a,a,b,b,b,b] */ -template +template X86_SIMD_SORT_FINLINE void internal_merge_n_vec(typename vtype::reg_t *reg) { using reg_t = typename vtype::reg_t; @@ -112,7 +119,8 @@ X86_SIMD_SORT_FINLINE void merge_substep_n_vec(reg_t *regs) } merge_substep_n_vec(regs); - merge_substep_n_vec(regs + numVecs / 2); + merge_substep_n_vec(regs + + numVecs / 2); } template +template X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) { /* Run the initial sorting network to sort the columns of the [numVecs x @@ -158,7 +169,10 @@ X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs) merge_n_vec(vecs); } -template +template X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) { static_assert(numVecs > 0, "numVecs should be > 0"); @@ -189,8 +203,9 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) // Masked part of the load X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - vecs[i] = vtype::mask_loadu( - comparator::rightmostPossibleVec(), ioMasks[j], arr + i * vtype::numlanes); + vecs[i] = vtype::mask_loadu(comparator::rightmostPossibleVec(), + ioMasks[j], + arr + i * vtype::numlanes); } sort_vectors(vecs); diff --git a/src/xss-optimal-networks.hpp b/src/xss-optimal-networks.hpp index 835c9224..e722b1f1 100644 --- a/src/xss-optimal-networks.hpp +++ b/src/xss-optimal-networks.hpp @@ -1,7 +1,9 @@ // All of these sources files are generated from the optimal networks described in // https://bertdobbelaere.github.io/sorting_networks.html -template +template X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) { comparator::COEX(vecs[0], vecs[2]); @@ -13,7 +15,9 @@ 
X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs) comparator::COEX(vecs[1], vecs[2]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) { comparator::COEX(vecs[0], vecs[2]); @@ -42,7 +46,9 @@ X86_SIMD_SORT_FINLINE void optimal_sort_8(reg_t *vecs) comparator::COEX(vecs[5], vecs[6]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) { comparator::COEX(vecs[0], vecs[13]); @@ -116,9 +122,11 @@ X86_SIMD_SORT_FINLINE void optimal_sort_16(reg_t *vecs) comparator::COEX(vecs[8], vecs[9]); } -template +template X86_SIMD_SORT_FINLINE void optimal_sort_32(reg_t *vecs) -{ +{ comparator::COEX(vecs[0], vecs[1]); comparator::COEX(vecs[2], vecs[3]); comparator::COEX(vecs[4], vecs[5]); diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index f6ed6eeb..5d955616 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -129,7 +129,8 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) if (smallest == largest) { // We have a very unlucky sample, or the array is constant / near constant // Run a special function meant to deal with this situation - return get_pivot_near_constant(arr, median, left, right); + return get_pivot_near_constant( + arr, median, left, right); } else if (median != smallest && median != largest) { // We have a normal sample; use it's median @@ -138,12 +139,14 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) else if (median == smallest) { // We will either return the median or the next value larger than the median, // depending on the comparator (see xss-common-comparators.hpp for more details) - return pivot_results(comparator::choosePivotMedianIsSmallest(median)); + return pivot_results( + comparator::choosePivotMedianIsSmallest(median)); } else if (median == largest) { // We will either return the median or the next value smaller than the median, // depending on the comparator (see xss-common-comparators.hpp for 
more details) - return pivot_results(comparator::choosePivotMedianIsLargest(median)); + return pivot_results( + comparator::choosePivotMedianIsLargest(median)); } else { // Should be unreachable @@ -219,7 +222,8 @@ get_pivot_near_constant(type_t *arr, // (note that larger is determined using the provided comparator, so it might actually be the smaller one) // We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the chosen value // TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere - type_t pivot = std::max(value1, commonValue, comparator::STDSortComparator); + type_t pivot + = std::max(value1, commonValue, comparator::STDSortComparator); return pivot_results(pivot, pivot_result_t::Only2Values); } diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 9eaca32c..51af2696 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -33,7 +33,7 @@ TYPED_TEST_P(simdsort, test_qsort) bool hasnan = (type == "rand_with_nan") ? 
true : false; for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); - + // Ascending order std::vector arr = basearr; std::vector sortedarr = arr; @@ -42,7 +42,7 @@ TYPED_TEST_P(simdsort, test_qsort) compare>()); x86simdsort::qsort(arr.data(), arr.size(), hasnan); IS_SORTED(sortedarr, arr, type); - + // Descending order arr = basearr; sortedarr = arr; @@ -51,7 +51,7 @@ TYPED_TEST_P(simdsort, test_qsort) compare>()); x86simdsort::qsort(arr.data(), arr.size(), hasnan, true); IS_SORTED(sortedarr, arr, type); - + arr.clear(); sortedarr.clear(); } @@ -83,7 +83,7 @@ TYPED_TEST_P(simdsort, test_qselect) for (auto size : this->arrsize) { size_t k = rand() % size; std::vector basearr = get_array(type, size); - + // Ascending order std::vector arr = basearr; std::vector sortedarr = arr; @@ -93,7 +93,7 @@ TYPED_TEST_P(simdsort, test_qselect) compare>()); x86simdsort::qselect(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); - + // Descending order arr = basearr; sortedarr = arr; @@ -103,7 +103,7 @@ TYPED_TEST_P(simdsort, test_qselect) compare>()); x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); - + arr.clear(); sortedarr.clear(); } @@ -138,7 +138,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort) // k should be at least 1 size_t k = std::max((size_t)1, rand() % size); std::vector basearr = get_array(type, size); - + // Ascending order std::vector arr = basearr; std::vector sortedarr = arr; @@ -147,7 +147,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort) compare>()); x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); - + // Descending order arr = basearr; sortedarr = arr; @@ -156,7 +156,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort) compare>()); x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan, true); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); - + arr.clear(); sortedarr.clear(); } 
From 06697014d449df0b042276bc20acd974938455cf Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 12:14:10 -0700 Subject: [PATCH 07/18] Added reverse benchmark for quicksort --- benchmarks/bench-qsort.hpp | 42 +++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/benchmarks/bench-qsort.hpp b/benchmarks/bench-qsort.hpp index f95b05ba..a6e1c794 100644 --- a/benchmarks/bench-qsort.hpp +++ b/benchmarks/bench-qsort.hpp @@ -36,9 +36,49 @@ static void simdsort(benchmark::State &state, Args &&...args) } } +template +static void scalar_revsort(benchmark::State &state, Args &&...args) +{ + // Get args + auto args_tuple = std::make_tuple(std::move(args)...); + size_t arrsize = std::get<0>(args_tuple); + std::string arrtype = std::get<1>(args_tuple); + // set up array + std::vector arr = get_array(arrtype, arrsize); + std::vector arr_bkp = arr; + // benchmark + for (auto _ : state) { + std::sort(arr.rbegin(), arr.rend()); + state.PauseTiming(); + arr = arr_bkp; + state.ResumeTiming(); + } +} + +template +static void simd_revsort(benchmark::State &state, Args &&...args) +{ + // Get args + auto args_tuple = std::make_tuple(std::move(args)...); + size_t arrsize = std::get<0>(args_tuple); + std::string arrtype = std::get<1>(args_tuple); + // set up array + std::vector arr = get_array(arrtype, arrsize); + std::vector arr_bkp = arr; + // benchmark + for (auto _ : state) { + x86simdsort::qsort(arr.data(), arrsize, false, true); + state.PauseTiming(); + arr = arr_bkp; + state.ResumeTiming(); + } +} + #define BENCH_BOTH_QSORT(type) \ BENCH_SORT(simdsort, type) \ - BENCH_SORT(scalarsort, type) + BENCH_SORT(scalarsort, type) \ + BENCH_SORT(simd_revsort, type) \ + BENCH_SORT(scalar_revsort, type) BENCH_BOTH_QSORT(uint64_t) BENCH_BOTH_QSORT(int64_t) From e7e82980fb7bdf97dfca2bbd3707a8771bdb1718 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 12 Mar 2024 15:21:56 -0700 Subject: [PATCH 08/18] Fixed fp16 code issues 
--- src/avx512fp16-16bit-qsort.hpp | 96 ++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 23 deletions(-) diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 28dff1c0..53ad856e 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -179,51 +179,101 @@ X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<_Float16>(_Float16 elem) } template <> -X86_SIMD_SORT_INLINE_ONLY void -replace_inf_with_nan(_Float16 *arr, arrsize_t size, arrsize_t nan_count) +X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, + arrsize_t size, + arrsize_t nan_count, + bool descending) { Fp16Bits val; val.i_ = 0x7c01; - for (arrsize_t ii = size - 1; nan_count > 0; --ii) { - arr[ii] = val.f_; - nan_count -= 1; + + if (descending) { + for (arrsize_t ii = 0; nan_count > 0; ++ii) { + arr[ii] = val.f_; + nan_count -= 1; + } + } + else { + for (arrsize_t ii = size - 1; nan_count > 0; --ii) { + arr[ii] = val.f_; + nan_count -= 1; + } } } /* Specialized template function for _Float16 qsort_*/ template <> X86_SIMD_SORT_INLINE_ONLY void -avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan) +avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) { + using vtype = zmm_vector<_Float16>; + if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf, _Float16>( - arr, arrsize); + nan_count = replace_nan_with_inf(arr, arrsize); + } + if (descending) { + qsort_, _Float16>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + } + else { + qsort_, _Float16>( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } - qsort_, _Float16>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } template <> -X86_SIMD_SORT_INLINE_ONLY void -avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE_ONLY void 
avx512_qselect(_Float16 *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool descending) { - arrsize_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + using vtype = zmm_vector<_Float16>; + + if (descending) { + arrsize_t index_first_elem = 0; + if (UNLIKELY(hasnan)) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + + arrsize_t size_without_nans = arrsize - index_first_elem; + + if (index_first_elem <= k) { + qselect_, _Float16>( + arr, + k, + index_first_elem, + arrsize - 1, + 2 * (arrsize_t)log2(size_without_nans)); + } } - if (indx_last_elem >= k) { - qselect_, _Float16>( - arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); + else { + arrsize_t indx_last_elem = arrsize - 1; + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + + if (indx_last_elem >= k) { + qselect_, _Float16>( + arr, + k, + 0, + indx_last_elem, + 2 * (arrsize_t)log2(indx_last_elem)); + } } } template <> -X86_SIMD_SORT_INLINE_ONLY void -avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) +X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool descending) { - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1, hasnan); + avx512_qselect(arr, k - 1, arrsize, hasnan, descending); + avx512_qsort(arr, k - 1, hasnan, descending); } #endif // AVX512FP16_QSORT_16BIT From d032c31606607c725799e8da0f8b18e9d57764a4 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 13 Mar 2024 13:42:07 -0700 Subject: [PATCH 09/18] additional clang-format --- lib/x86simdsort-avx2.cpp | 6 ++-- lib/x86simdsort-icl.cpp | 24 ++++++++++++--- lib/x86simdsort-internal.h | 51 +++++++++++++++++++++--------- lib/x86simdsort-scalar.h | 63 ++++++++++++++++++++++---------------- lib/x86simdsort-skx.cpp | 6 ++-- lib/x86simdsort-spr.cpp | 12 ++++++-- 
lib/x86simdsort.cpp | 9 ++++-- lib/x86simdsort.h | 17 +++++++--- tests/test-qsort-common.h | 25 +++++++-------- 9 files changed, 142 insertions(+), 71 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 3785014b..345653d9 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -12,12 +12,14 @@ avx2_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ avx2_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index bbde114d..20095369 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -10,12 +10,20 @@ namespace avx512 { avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void qselect(uint16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void partial_qsort(uint16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } @@ -25,12 +33,20 @@ namespace avx512 { avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void qselect(int16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_qselect(arr, k, arrsize, hasnan, descending); } 
template <> - void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void partial_qsort(int16_t *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 6452e9fa..dad32b91 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -8,19 +8,26 @@ namespace xss { namespace avx512 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -33,19 +40,26 @@ namespace avx512 { namespace avx2 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + 
size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector @@ -58,19 +72,26 @@ namespace avx2 { namespace scalar { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); // quickselect template - XSS_HIDE_SYMBOL void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template - XSS_HIDE_SYMBOL void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index fc3ff9d4..284d4c28 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -28,16 +28,16 @@ namespace scalar { void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - if (reversed){ + if (reversed) { std::sort(arr, arr + arrsize, compare>()); - }else{ + } + else { std::sort(arr, arr + arrsize, compare>()); } } else { - if (reversed){ - std::sort(arr, arr + arrsize, std::greater()); - }else{ + if (reversed) { std::sort(arr, arr + arrsize, std::greater()); } + else { std::sort(arr, arr + arrsize, 
std::less()); } } @@ -46,43 +46,54 @@ namespace scalar { void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - if (reversed){ - std::nth_element( - arr, arr + k, arr + arrsize, compare>()); - }else{ - std::nth_element( - arr, arr + k, arr + arrsize, compare>()); + if (reversed) { + std::nth_element(arr, + arr + k, + arr + arrsize, + compare>()); + } + else { + std::nth_element(arr, + arr + k, + arr + arrsize, + compare>()); } } else { - if (reversed){ + if (reversed) { std::nth_element( arr, arr + k, arr + arrsize, std::greater()); - }else{ - std::nth_element( - arr, arr + k, arr + arrsize, std::less()); + } + else { + std::nth_element(arr, arr + k, arr + arrsize, std::less()); } } } template - void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void + partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { if (hasnan) { - if (reversed){ - std::partial_sort( - arr, arr + k, arr + arrsize, compare>()); - }else{ - std::partial_sort( - arr, arr + k, arr + arrsize, compare>()); + if (reversed) { + std::partial_sort(arr, + arr + k, + arr + arrsize, + compare>()); + } + else { + std::partial_sort(arr, + arr + k, + arr + arrsize, + compare>()); } } else { - if (reversed){ + if (reversed) { std::partial_sort( arr, arr + k, arr + arrsize, std::greater()); - }else{ - std::partial_sort( - arr, arr + k, arr + arrsize, std::less()); + } + else { + std::partial_sort(arr, arr + k, arr + arrsize, std::less()); } } } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index e923ff10..4a1c2a9f 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -12,12 +12,14 @@ avx512_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ avx512_qselect(arr, k, arrsize, hasnan, descending); \ } \ 
template <> \ - void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index 06aea9af..d19de1ed 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -10,12 +10,20 @@ namespace avx512 { avx512_qsort(arr, size, hasnan, descending); } template <> - void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void qselect(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> - void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, bool descending) + void partial_qsort(_Float16 *arr, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index fa86e239..21c8b34f 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -68,16 +68,19 @@ namespace x86simdsort { static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void qselect(TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ - static void (*internal_partial_qsort##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_partial_qsort##TYPE)( \ + TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void partial_qsort(TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ 
(*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ } diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 1f2aa118..42d5247f 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -14,17 +14,24 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void +qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template -XSS_EXPORT_SYMBOL void -qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void qselect(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template -XSS_EXPORT_SYMBOL void -partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void partial_qsort(T *arr, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index fd338df9..af1578ed 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -49,10 +49,10 @@ void IS_ARR_PARTITIONED(std::vector arr, std::string type, bool descending = false) { - if (!descending){ + if (!descending) { auto cmp_eq = compare>(); auto cmp_less = compare>(); - + auto cmp_leq = compare>(); auto cmp_geq = compare>(); @@ -62,24 +62,25 @@ void IS_ARR_PARTITIONED(std::vector arr, } // ( 2) Elements to the left of k should be atmost arr[k] if (k >= 1) { - T max_left - = *std::max_element(arr.begin(), arr.begin() + k - 1, cmp_less); + T max_left = *std::max_element( + arr.begin(), arr.begin() + k - 1, cmp_less); if (!cmp_geq(arr[k], max_left)) { REPORT_FAIL("incorrect left partition", arr.size(), type, k); } } // 3) Elements to the right of k should be atleast arr[k] if (k != (size_t)(arr.size() - 1)) { - T min_right - = *std::min_element(arr.begin() + k + 1, 
arr.end(), cmp_less); + T min_right = *std::min_element( + arr.begin() + k + 1, arr.end(), cmp_less); if (!cmp_leq(arr[k], min_right)) { REPORT_FAIL("incorrect right partition", arr.size(), type, k); } } - }else{ + } + else { auto cmp_eq = compare>(); auto cmp_less = compare>(); - + auto cmp_leq = compare>(); auto cmp_geq = compare>(); @@ -89,16 +90,16 @@ void IS_ARR_PARTITIONED(std::vector arr, } // ( 2) Elements to the left of k should be atleast arr[k] if (k >= 1) { - T max_left - = *std::max_element(arr.begin(), arr.begin() + k - 1, cmp_less); + T max_left = *std::max_element( + arr.begin(), arr.begin() + k - 1, cmp_less); if (!cmp_geq(arr[k], max_left)) { REPORT_FAIL("incorrect left partition", arr.size(), type, k); } } // 3) Elements to the right of k should be atmost arr[k] if (k != (size_t)(arr.size() - 1)) { - T min_right - = *std::min_element(arr.begin() + k + 1, arr.end(), cmp_less); + T min_right = *std::min_element( + arr.begin() + k + 1, arr.end(), cmp_less); if (!cmp_leq(arr[k], min_right)) { REPORT_FAIL("incorrect right partition", arr.size(), type, k); } From a63c07fca4f9ced6360c8f7d1487b3aae6950615 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 19 Mar 2024 11:34:35 -0700 Subject: [PATCH 10/18] Add utility function to get the compare function object --- lib/x86simdsort-scalar.h | 88 ++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 58 deletions(-) diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 284d4c28..6afc7287 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -4,8 +4,10 @@ namespace xss { namespace utils { - /* O(1) permute array in place: stolen from - * http://www.davidespataro.it/apply-a-permutation-to-a-vector */ + /* + * O(1) permute array in place: stolen from + * http://www.davidespataro.it/apply-a-permutation-to-a-vector + */ template void apply_permutation_in_place(T *arr, std::vector arg) { @@ -21,81 +23,51 @@ namespace utils { arg[curr] = curr; } } 
-} // namespace utils - -namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + decltype(auto) get_cmp_func(bool hasnan, bool reverse) { + std::function cmp; if (hasnan) { - if (reversed) { - std::sort(arr, arr + arrsize, compare>()); - } + if (reverse == true) { cmp = compare>(); } else { - std::sort(arr, arr + arrsize, compare>()); + cmp = compare>(); } } else { - if (reversed) { std::sort(arr, arr + arrsize, std::greater()); } + if (reverse == true) { cmp = std::greater(); } else { - std::sort(arr, arr + arrsize, std::less()); + cmp = std::less(); } } + return cmp; } +} // namespace utils + +namespace scalar { + template + void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + { + std::sort(arr, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); + } + template void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { - if (hasnan) { - if (reversed) { - std::nth_element(arr, - arr + k, - arr + arrsize, - compare>()); - } - else { - std::nth_element(arr, - arr + k, - arr + arrsize, - compare>()); - } - } - else { - if (reversed) { - std::nth_element( - arr, arr + k, arr + arrsize, std::greater()); - } - else { - std::nth_element(arr, arr + k, arr + arrsize, std::less()); - } - } + std::nth_element(arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); } template void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { - if (hasnan) { - if (reversed) { - std::partial_sort(arr, - arr + k, - arr + arrsize, - compare>()); - } - else { - std::partial_sort(arr, - arr + k, - arr + arrsize, - compare>()); - } - } - else { - if (reversed) { - std::partial_sort( - arr, arr + k, arr + arrsize, std::greater()); - } - else { - std::partial_sort(arr, arr + k, arr + arrsize, std::less()); - } - } + std::partial_sort(arr, + arr + k, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); } template std::vector argsort(T *arr, size_t arrsize, 
bool hasnan) From 098032c65f8c417bc80265f11857d04e94c16c60 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 19 Mar 2024 16:10:48 -0700 Subject: [PATCH 11/18] Split ascending and descending tests --- tests/test-qsort.cpp | 69 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 51af2696..72e173a4 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -27,7 +27,7 @@ class simdsort : public ::testing::Test { TYPED_TEST_SUITE_P(simdsort); -TYPED_TEST_P(simdsort, test_qsort) +TYPED_TEST_P(simdsort, test_qsort_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; @@ -43,9 +43,22 @@ TYPED_TEST_P(simdsort, test_qsort) x86simdsort::qsort(arr.data(), arr.size(), hasnan); IS_SORTED(sortedarr, arr, type); + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_qsort_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + std::vector basearr = get_array(type, size); + // Descending order - arr = basearr; - sortedarr = arr; + std::vector arr = basearr; + std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); @@ -76,7 +89,7 @@ TYPED_TEST_P(simdsort, test_argsort) } } -TYPED_TEST_P(simdsort, test_qselect) +TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; @@ -94,9 +107,23 @@ TYPED_TEST_P(simdsort, test_qselect) x86simdsort::qselect(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type); + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_qselect_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? 
true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + std::vector basearr = get_array(type, size); + // Descending order - arr = basearr; - sortedarr = arr; + std::vector arr = basearr; + std::vector sortedarr = arr; std::nth_element(sortedarr.begin(), sortedarr.begin() + k, sortedarr.end(), @@ -130,7 +157,7 @@ TYPED_TEST_P(simdsort, test_argselect) } } -TYPED_TEST_P(simdsort, test_partial_qsort) +TYPED_TEST_P(simdsort, test_partial_qsort_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; @@ -148,9 +175,24 @@ TYPED_TEST_P(simdsort, test_partial_qsort) x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); + arr.clear(); + sortedarr.clear(); + } + } +} + +TYPED_TEST_P(simdsort, test_partial_qsort_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + // k should be at least 1 + size_t k = std::max((size_t)1, rand() % size); + std::vector basearr = get_array(type, size); + // Descending order - arr = basearr; - sortedarr = arr; + std::vector arr = basearr; + std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); @@ -197,11 +239,14 @@ TYPED_TEST_P(simdsort, test_comparator) } REGISTER_TYPED_TEST_SUITE_P(simdsort, - test_qsort, + test_qsort_ascending, + test_qsort_descending, test_argsort, test_argselect, - test_qselect, - test_partial_qsort, + test_qselect_ascending, + test_qselect_descending, + test_partial_qsort_ascending, + test_partial_qsort_descending, test_comparator); using QSortTestTypes = testing::Types Date: Thu, 21 Mar 2024 09:42:24 -0700 Subject: [PATCH 12/18] Changed from boolean descending flag to enum --- benchmarks/bench-qsort.hpp | 5 +++- lib/x86simdsort-avx2.cpp | 24 +++++++++++------ lib/x86simdsort-icl.cpp | 26 ++++++++++--------- lib/x86simdsort-internal.h | 24 +++++++++++------ 
lib/x86simdsort-orders.h | 10 ++++++++ lib/x86simdsort-scalar.h | 30 +++++++++++++--------- lib/x86simdsort-skx.cpp | 24 +++++++++++------ lib/x86simdsort-spr.cpp | 14 +++++----- lib/x86simdsort.cpp | 32 +++++++++++++++-------- lib/x86simdsort.h | 14 +++++++--- src/avx512fp16-16bit-qsort.hpp | 17 ++++++------ src/xss-common-qsort.h | 47 ++++++++++++++++++++-------------- tests/test-qsort.cpp | 20 ++++++++++++--- 13 files changed, 186 insertions(+), 101 deletions(-) create mode 100644 lib/x86simdsort-orders.h diff --git a/benchmarks/bench-qsort.hpp b/benchmarks/bench-qsort.hpp index a6e1c794..7a105428 100644 --- a/benchmarks/bench-qsort.hpp +++ b/benchmarks/bench-qsort.hpp @@ -67,7 +67,10 @@ static void simd_revsort(benchmark::State &state, Args &&...args) std::vector arr_bkp = arr; // benchmark for (auto _ : state) { - x86simdsort::qsort(arr.data(), arrsize, false, true); + x86simdsort::qsort(arr.data(), + arrsize, + false, + x86simdsort::sort_order::sort_descending); state.PauseTiming(); arr = arr_bkp; state.ResumeTiming(); diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 345653d9..165241e1 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -5,23 +5,31 @@ #include "xss-common-argsort.h" #include "x86simdsort-internal.h" +using x86simdsort::sort_order; + #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, size_t arrsize, bool hasnan, sort_order order) \ { \ - avx2_qsort(arr, arrsize, hasnan, descending); \ + avx2_qsort(arr, arrsize, hasnan, order); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - avx2_qselect(arr, k, arrsize, hasnan, descending); \ + avx2_qselect(arr, k, arrsize, hasnan, order); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t 
arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ + avx2_partial_qsort(arr, k, arrsize, hasnan, order); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index 20095369..d051643f 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -2,53 +2,55 @@ #include "avx512-16bit-qsort.hpp" #include "x86simdsort-internal.h" +using x86simdsort::sort_order; + namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(uint16_t *arr, size_t size, bool hasnan, sort_order order) { - avx512_qsort(arr, size, hasnan, descending); + avx512_qsort(arr, size, hasnan, order); } template <> void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_qselect(arr, k, arrsize, hasnan, descending); + avx512_qselect(arr, k, arrsize, hasnan, order); } template <> void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_partial_qsort(arr, k, arrsize, hasnan, descending); + avx512_partial_qsort(arr, k, arrsize, hasnan, order); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) + void qsort(int16_t *arr, size_t size, bool hasnan, sort_order order) { - avx512_qsort(arr, size, hasnan, descending); + avx512_qsort(arr, size, hasnan, order); } template <> void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_qselect(arr, k, arrsize, hasnan, descending); + avx512_qselect(arr, k, arrsize, hasnan, order); } template <> void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - 
avx512_partial_qsort(arr, k, arrsize, hasnan, descending); + avx512_partial_qsort(arr, k, arrsize, hasnan, order); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index dad32b91..94e50b82 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -4,12 +4,16 @@ #include #include +using x86simdsort::sort_order; + namespace xss { namespace avx512 { // quicksort template - XSS_HIDE_SYMBOL void - qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void qsort(T *arr, + size_t arrsize, + bool hasnan = false, + sort_order order = sort_order::sort_ascending); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -20,14 +24,15 @@ namespace avx512 { size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + sort_order order = sort_order::sort_ascending); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + sort_order order + = sort_order::sort_ascending); // argsort template XSS_HIDE_SYMBOL std::vector @@ -40,8 +45,10 @@ namespace avx512 { namespace avx2 { // quicksort template - XSS_HIDE_SYMBOL void - qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void qsort(T *arr, + size_t arrsize, + bool hasnan = false, + sort_order order = sort_order::sort_ascending); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -52,14 +59,15 @@ namespace avx2 { size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + sort_order order = sort_order::sort_ascending); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + sort_order order + = sort_order::sort_ascending); // argsort template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-orders.h b/lib/x86simdsort-orders.h new file mode 100644 
index 00000000..02fbc005 --- /dev/null +++ b/lib/x86simdsort-orders.h @@ -0,0 +1,10 @@ +#ifndef XSS_ORDERS +#define XSS_ORDERS + +namespace x86simdsort { + +enum class sort_order : int { sort_ascending, sort_descending }; + +} // namespace x86simdsort + +#endif \ No newline at end of file diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 6afc7287..0e635275 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -2,6 +2,8 @@ #include #include +using x86simdsort::sort_order; + namespace xss { namespace utils { /* @@ -24,17 +26,21 @@ namespace utils { } } template - decltype(auto) get_cmp_func(bool hasnan, bool reverse) + decltype(auto) get_cmp_func(bool hasnan, sort_order order) { std::function cmp; if (hasnan) { - if (reverse == true) { cmp = compare>(); } + if (order == sort_order::sort_descending) { + cmp = compare>(); + } else { cmp = compare>(); } } else { - if (reverse == true) { cmp = std::greater(); } + if (order == sort_order::sort_descending) { + cmp = std::greater(); + } else { cmp = std::less(); } @@ -45,29 +51,29 @@ namespace utils { namespace scalar { template - void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + void qsort(T *arr, size_t arrsize, bool hasnan, sort_order order) { - std::sort(arr, - arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + std::sort( + arr, arr + arrsize, xss::utils::get_cmp_func(hasnan, order)); } template - void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void + qselect(T *arr, size_t k, size_t arrsize, bool hasnan, sort_order order) { std::nth_element(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, reversed)); + xss::utils::get_cmp_func(hasnan, order)); } template - void - partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) + void partial_qsort( + T *arr, size_t k, size_t arrsize, bool hasnan, sort_order order) { std::partial_sort(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, 
reversed)); + xss::utils::get_cmp_func(hasnan, order)); } template std::vector argsort(T *arr, size_t arrsize, bool hasnan) diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 4a1c2a9f..2b92c6c6 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -5,23 +5,31 @@ #include "avx512-64bit-qsort.hpp" #include "x86simdsort-internal.h" +using x86simdsort::sort_order; + #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(type *arr, size_t arrsize, bool hasnan, sort_order order) \ { \ - avx512_qsort(arr, arrsize, hasnan, descending); \ + avx512_qsort(arr, arrsize, hasnan, order); \ } \ template <> \ - void qselect( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - avx512_qselect(arr, k, arrsize, hasnan, descending); \ + avx512_qselect(arr, k, arrsize, hasnan, order); \ } \ template <> \ - void partial_qsort( \ - type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(type *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ + avx512_partial_qsort(arr, k, arrsize, hasnan, order); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index d19de1ed..e3b0946e 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -2,30 +2,32 @@ #include "avx512fp16-16bit-qsort.hpp" #include "x86simdsort-internal.h" +using x86simdsort::sort_order; + namespace xss { namespace avx512 { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) + void qsort(_Float16 *arr, size_t size, bool hasnan, sort_order order) { - avx512_qsort(arr, size, hasnan, descending); + avx512_qsort(arr, size, hasnan, 
order); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_qselect(arr, k, arrsize, hasnan, descending); + avx512_qselect(arr, k, arrsize, hasnan, order); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_partial_qsort(arr, k, arrsize, hasnan, descending); + avx512_partial_qsort(arr, k, arrsize, hasnan, order); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 21c8b34f..dd39b561 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -5,6 +5,8 @@ #include #include +using x86simdsort::sort_order; + static int check_cpu_feature_support(std::string_view cpufeature) { const char *disable_avx512 = std::getenv("XSS_DISABLE_AVX512"); @@ -57,32 +59,40 @@ namespace x86simdsort { #define CAT(a, b) CAT_(a, b) #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, sort_order) \ + = NULL; \ template <> \ - void qsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + void qsort(TYPE *arr, size_t arrsize, bool hasnan, sort_order order) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, order); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ + static void (*internal_qselect##TYPE)( \ + TYPE *, size_t, size_t, bool, sort_order) \ = NULL; \ template <> \ - void qselect( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void qselect(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, order); \ } #define 
DECLARE_INTERNAL_partial_qsort(TYPE) \ static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, bool) \ + TYPE *, size_t, size_t, bool, sort_order) \ = NULL; \ template <> \ - void partial_qsort( \ - TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void partial_qsort(TYPE *arr, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + sort_order order) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, order); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 42d5247f..369d4153 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -5,6 +5,7 @@ #include #include #include +#include "x86simdsort-orders.h" #define XSS_EXPORT_SYMBOL __attribute__((visibility("default"))) #define XSS_HIDE_SYMBOL __attribute__((visibility("hidden"))) @@ -14,8 +15,11 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void -qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void qsort(T *arr, + size_t arrsize, + bool hasnan = false, + x86simdsort::sort_order order + = x86simdsort::sort_order::sort_ascending); // quickselect template @@ -23,7 +27,8 @@ XSS_EXPORT_SYMBOL void qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + x86simdsort::sort_order order + = x86simdsort::sort_order::sort_ascending); // partial sort template @@ -31,7 +36,8 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - bool descending = false); + x86simdsort::sort_order order + = x86simdsort::sort_order::sort_ascending); // argsort template diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 53ad856e..7c9da83c 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -203,7 +203,7 @@ X86_SIMD_SORT_INLINE_ONLY void 
replace_inf_with_nan(_Float16 *arr, /* Specialized template function for _Float16 qsort_*/ template <> X86_SIMD_SORT_INLINE_ONLY void -avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) +avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, sort_order order) { using vtype = zmm_vector<_Float16>; @@ -212,7 +212,7 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (descending) { + if (order == sort_order::sort_descending) { qsort_, _Float16>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -220,7 +220,8 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) qsort_, _Float16>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } - replace_inf_with_nan(arr, arrsize, nan_count, descending); + replace_inf_with_nan( + arr, arrsize, nan_count, order == sort_order::sort_descending); } } @@ -229,11 +230,11 @@ X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - bool descending) + sort_order order) { using vtype = zmm_vector<_Float16>; - if (descending) { + if (order == sort_order::sort_descending) { arrsize_t index_first_elem = 0; if (UNLIKELY(hasnan)) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); @@ -271,9 +272,9 @@ X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - bool descending) + sort_order order) { - avx512_qselect(arr, k - 1, arrsize, hasnan, descending); - avx512_qsort(arr, k - 1, hasnan, descending); + avx512_qselect(arr, k - 1, arrsize, hasnan, order); + avx512_qsort(arr, k - 1, hasnan, order); } #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 435abd56..ce315c27 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,6 +38,9 @@ #include "xss-pivot-selection.hpp" #include 
"xss-network-qsort.hpp" #include "xss-common-comparators.hpp" +#include "../lib/x86simdsort-orders.h" + +using x86simdsort::sort_order; template bool is_a_nan(T elem) @@ -608,7 +611,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, // Quicksort routines: template X86_SIMD_SORT_INLINE void -xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, sort_order order) { if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -616,7 +619,7 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (descending) { + if (order == sort_order::sort_descending) { qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -624,11 +627,14 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } - replace_inf_with_nan(arr, arrsize, nan_count, descending); + replace_inf_with_nan(arr, + arrsize, + nan_count, + order == sort_order::sort_descending); } else { UNUSED(hasnan); - if (descending) { + if (order == sort_order::sort_descending) { qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -643,9 +649,9 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) // Quick select methods template X86_SIMD_SORT_INLINE void xss_qselect( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, sort_order order) { - if (descending) { + if (order == sort_order::sort_descending) { arrsize_t index_first_elem = 0; if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -687,10 +693,10 @@ X86_SIMD_SORT_INLINE void xss_qselect( // Partial sort methods: template X86_SIMD_SORT_INLINE void xss_partial_qsort( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, sort_order 
order) { - xss_qselect(arr, k - 1, arrsize, hasnan, descending); - xss_qsort(arr, k - 1, hasnan, descending); + xss_qselect(arr, k - 1, arrsize, hasnan, order); + xss_qsort(arr, k - 1, hasnan, order); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -698,27 +704,30 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort( X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + sort_order order \ + = sort_order::sort_ascending) \ { \ - xss_qsort(arr, size, hasnan, descending); \ + xss_qsort(arr, size, hasnan, order); \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \ arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - bool descending = false) \ + sort_order order \ + = sort_order::sort_ascending) \ { \ - xss_qselect(arr, k, size, hasnan, descending); \ + xss_qselect(arr, k, size, hasnan, order); \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \ - arrsize_t k, \ - arrsize_t size, \ - bool hasnan = false, \ - bool descending = false) \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ + T *arr, \ + arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + sort_order order = sort_order::sort_ascending) \ { \ - xss_partial_qsort(arr, k, size, hasnan, descending); \ + xss_partial_qsort(arr, k, size, hasnan, order); \ } DEFINE_METHODS(avx512, zmm_vector) diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 72e173a4..a92da1e3 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -55,14 +55,17 @@ TYPED_TEST_P(simdsort, test_qsort_descending) bool hasnan = (type == "rand_with_nan") ? 
true : false; for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); - + // Descending order std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); - x86simdsort::qsort(arr.data(), arr.size(), hasnan, true); + x86simdsort::qsort(arr.data(), + arr.size(), + hasnan, + x86simdsort::sort_order::sort_descending); IS_SORTED(sortedarr, arr, type); arr.clear(); @@ -128,7 +131,11 @@ TYPED_TEST_P(simdsort, test_qselect_descending) sortedarr.begin() + k, sortedarr.end(), compare>()); - x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); + x86simdsort::qselect(arr.data(), + k, + arr.size(), + hasnan, + x86simdsort::sort_order::sort_descending); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); arr.clear(); @@ -196,7 +203,12 @@ TYPED_TEST_P(simdsort, test_partial_qsort_descending) std::sort(sortedarr.begin(), sortedarr.end(), compare>()); - x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan, true); + x86simdsort::partial_qsort( + arr.data(), + k, + arr.size(), + hasnan, + x86simdsort::sort_order::sort_descending); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); arr.clear(); From e5b3603186c5437b85b4adb4addb2c9ada18e979 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 22 Mar 2024 12:50:32 -0700 Subject: [PATCH 13/18] Remove duplicate code in IS_ARR_PARTITIONED --- tests/test-qsort-common.h | 76 ++++++++++++++------------------------- 1 file changed, 27 insertions(+), 49 deletions(-) diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index af1578ed..998537be 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -49,60 +49,38 @@ void IS_ARR_PARTITIONED(std::vector arr, std::string type, bool descending = false) { - if (!descending) { - auto cmp_eq = compare>(); - auto cmp_less = compare>(); - - auto cmp_leq = compare>(); - auto cmp_geq = compare>(); + std::function cmp_eq, cmp_less, cmp_leq, cmp_geq; + cmp_eq = compare>(); - 
// 1) arr[k] == sorted[k]; use memcmp to handle nan - if (!cmp_eq(arr[k], true_kth)) { - REPORT_FAIL("kth element is incorrect", arr.size(), type, k); - } - // ( 2) Elements to the left of k should be atmost arr[k] - if (k >= 1) { - T max_left = *std::max_element( - arr.begin(), arr.begin() + k - 1, cmp_less); - if (!cmp_geq(arr[k], max_left)) { - REPORT_FAIL("incorrect left partition", arr.size(), type, k); - } - } - // 3) Elements to the right of k should be atleast arr[k] - if (k != (size_t)(arr.size() - 1)) { - T min_right = *std::min_element( - arr.begin() + k + 1, arr.end(), cmp_less); - if (!cmp_leq(arr[k], min_right)) { - REPORT_FAIL("incorrect right partition", arr.size(), type, k); - } - } + if (!descending) { + cmp_less = compare>(); + cmp_leq = compare>(); + cmp_geq = compare>(); } else { - auto cmp_eq = compare>(); - auto cmp_less = compare>(); - - auto cmp_leq = compare>(); - auto cmp_geq = compare>(); + cmp_less = compare>(); + cmp_leq = compare>(); + cmp_geq = compare>(); + } - // 1) arr[k] == sorted[k]; use memcmp to handle nan - if (!cmp_eq(arr[k], true_kth)) { - REPORT_FAIL("kth element is incorrect", arr.size(), type, k); - } - // ( 2) Elements to the left of k should be atleast arr[k] - if (k >= 1) { - T max_left = *std::max_element( - arr.begin(), arr.begin() + k - 1, cmp_less); - if (!cmp_geq(arr[k], max_left)) { - REPORT_FAIL("incorrect left partition", arr.size(), type, k); - } + // 1) arr[k] == sorted[k]; use memcmp to handle nan + if (!cmp_eq(arr[k], true_kth)) { + REPORT_FAIL("kth element is incorrect", arr.size(), type, k); + } + // ( 2) Elements to the left of k should be atmost arr[k] + if (k >= 1) { + T max_left = *std::max_element( + arr.begin(), arr.begin() + k - 1, cmp_less); + if (!cmp_geq(arr[k], max_left)) { + REPORT_FAIL("incorrect left partition", arr.size(), type, k); } - // 3) Elements to the right of k should be atmost arr[k] - if (k != (size_t)(arr.size() - 1)) { - T min_right = *std::min_element( - arr.begin() + k + 1, 
arr.end(), cmp_less); - if (!cmp_leq(arr[k], min_right)) { - REPORT_FAIL("incorrect right partition", arr.size(), type, k); - } + } + // 3) Elements to the right of k should be atleast arr[k] + if (k != (size_t)(arr.size() - 1)) { + T min_right = *std::min_element( + arr.begin() + k + 1, arr.end(), cmp_less); + if (!cmp_leq(arr[k], min_right)) { + REPORT_FAIL("incorrect right partition", arr.size(), type, k); } } } From 491860912cbc12b3ab189b68c204b12bdc759b7e Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 26 Mar 2024 09:31:23 -0700 Subject: [PATCH 14/18] Revert "Changed from boolean descending flag to enum" This reverts commit 9d508c1986402c8e8b6eac803e919ce49f41a1c0. --- benchmarks/bench-qsort.hpp | 5 +--- lib/x86simdsort-avx2.cpp | 24 ++++++----------- lib/x86simdsort-icl.cpp | 26 +++++++++---------- lib/x86simdsort-internal.h | 24 ++++++----------- lib/x86simdsort-orders.h | 10 -------- lib/x86simdsort-scalar.h | 30 +++++++++------------- lib/x86simdsort-skx.cpp | 24 ++++++----------- lib/x86simdsort-spr.cpp | 14 +++++----- lib/x86simdsort.cpp | 32 ++++++++--------------- lib/x86simdsort.h | 14 +++------- src/avx512fp16-16bit-qsort.hpp | 17 ++++++------ src/xss-common-qsort.h | 47 ++++++++++++++-------------------- tests/test-qsort.cpp | 20 +++------------ 13 files changed, 101 insertions(+), 186 deletions(-) delete mode 100644 lib/x86simdsort-orders.h diff --git a/benchmarks/bench-qsort.hpp b/benchmarks/bench-qsort.hpp index 7a105428..a6e1c794 100644 --- a/benchmarks/bench-qsort.hpp +++ b/benchmarks/bench-qsort.hpp @@ -67,10 +67,7 @@ static void simd_revsort(benchmark::State &state, Args &&...args) std::vector arr_bkp = arr; // benchmark for (auto _ : state) { - x86simdsort::qsort(arr.data(), - arrsize, - false, - x86simdsort::sort_order::sort_descending); + x86simdsort::qsort(arr.data(), arrsize, false, true); state.PauseTiming(); arr = arr_bkp; state.ResumeTiming(); diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 
165241e1..345653d9 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -5,31 +5,23 @@ #include "xss-common-argsort.h" #include "x86simdsort-internal.h" -using x86simdsort::sort_order; - #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, sort_order order) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qsort(arr, arrsize, hasnan, order); \ + avx2_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - sort_order order) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_qselect(arr, k, arrsize, hasnan, order); \ + avx2_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - sort_order order) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx2_partial_qsort(arr, k, arrsize, hasnan, order); \ + avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-icl.cpp b/lib/x86simdsort-icl.cpp index d051643f..20095369 100644 --- a/lib/x86simdsort-icl.cpp +++ b/lib/x86simdsort-icl.cpp @@ -2,55 +2,53 @@ #include "avx512-16bit-qsort.hpp" #include "x86simdsort-internal.h" -using x86simdsort::sort_order; - namespace xss { namespace avx512 { template <> - void qsort(uint16_t *arr, size_t size, bool hasnan, sort_order order) + void qsort(uint16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan, order); + avx512_qsort(arr, size, hasnan, descending); } template <> void qselect(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan, order); + avx512_qselect(arr, k, arrsize, 
hasnan, descending); } template <> void partial_qsort(uint16_t *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan, order); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } template <> - void qsort(int16_t *arr, size_t size, bool hasnan, sort_order order) + void qsort(int16_t *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan, order); + avx512_qsort(arr, size, hasnan, descending); } template <> void qselect(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan, order); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> void partial_qsort(int16_t *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan, order); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 94e50b82..dad32b91 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -4,16 +4,12 @@ #include #include -using x86simdsort::sort_order; - namespace xss { namespace avx512 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, - size_t arrsize, - bool hasnan = false, - sort_order order = sort_order::sort_ascending); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -24,15 +20,14 @@ namespace avx512 { size_t k, size_t arrsize, bool hasnan = false, - sort_order order = sort_order::sort_ascending); + bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - sort_order order - = sort_order::sort_ascending); + bool descending = false); // argsort template XSS_HIDE_SYMBOL 
std::vector @@ -45,10 +40,8 @@ namespace avx512 { namespace avx2 { // quicksort template - XSS_HIDE_SYMBOL void qsort(T *arr, - size_t arrsize, - bool hasnan = false, - sort_order order = sort_order::sort_ascending); + XSS_HIDE_SYMBOL void + qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template XSS_EXPORT_SYMBOL void @@ -59,15 +52,14 @@ namespace avx2 { size_t k, size_t arrsize, bool hasnan = false, - sort_order order = sort_order::sort_ascending); + bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - sort_order order - = sort_order::sort_ascending); + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-orders.h b/lib/x86simdsort-orders.h deleted file mode 100644 index 02fbc005..00000000 --- a/lib/x86simdsort-orders.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef XSS_ORDERS -#define XSS_ORDERS - -namespace x86simdsort { - -enum class sort_order : int { sort_ascending, sort_descending }; - -} // namespace x86simdsort - -#endif \ No newline at end of file diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 0e635275..6afc7287 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -2,8 +2,6 @@ #include #include -using x86simdsort::sort_order; - namespace xss { namespace utils { /* @@ -26,21 +24,17 @@ namespace utils { } } template - decltype(auto) get_cmp_func(bool hasnan, sort_order order) + decltype(auto) get_cmp_func(bool hasnan, bool reverse) { std::function cmp; if (hasnan) { - if (order == sort_order::sort_descending) { - cmp = compare>(); - } + if (reverse == true) { cmp = compare>(); } else { cmp = compare>(); } } else { - if (order == sort_order::sort_descending) { - cmp = std::greater(); - } + if (reverse == true) { cmp = std::greater(); } else { cmp = std::less(); } @@ -51,29 +45,29 @@ namespace utils { namespace scalar { template - 
void qsort(T *arr, size_t arrsize, bool hasnan, sort_order order) + void qsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { - std::sort( - arr, arr + arrsize, xss::utils::get_cmp_func(hasnan, order)); + std::sort(arr, + arr + arrsize, + xss::utils::get_cmp_func(hasnan, reversed)); } template - void - qselect(T *arr, size_t k, size_t arrsize, bool hasnan, sort_order order) + void qselect(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { std::nth_element(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, order)); + xss::utils::get_cmp_func(hasnan, reversed)); } template - void partial_qsort( - T *arr, size_t k, size_t arrsize, bool hasnan, sort_order order) + void + partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan, bool reversed) { std::partial_sort(arr, arr + k, arr + arrsize, - xss::utils::get_cmp_func(hasnan, order)); + xss::utils::get_cmp_func(hasnan, reversed)); } template std::vector argsort(T *arr, size_t arrsize, bool hasnan) diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 2b92c6c6..4a1c2a9f 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -5,31 +5,23 @@ #include "avx512-64bit-qsort.hpp" #include "x86simdsort-internal.h" -using x86simdsort::sort_order; - #define DEFINE_ALL_METHODS(type) \ template <> \ - void qsort(type *arr, size_t arrsize, bool hasnan, sort_order order) \ + void qsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qsort(arr, arrsize, hasnan, order); \ + avx512_qsort(arr, arrsize, hasnan, descending); \ } \ template <> \ - void qselect(type *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - sort_order order) \ + void qselect( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_qselect(arr, k, arrsize, hasnan, order); \ + avx512_qselect(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - void partial_qsort(type *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - 
sort_order order) \ + void partial_qsort( \ + type *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - avx512_partial_qsort(arr, k, arrsize, hasnan, order); \ + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index e3b0946e..d19de1ed 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -2,32 +2,30 @@ #include "avx512fp16-16bit-qsort.hpp" #include "x86simdsort-internal.h" -using x86simdsort::sort_order; - namespace xss { namespace avx512 { template <> - void qsort(_Float16 *arr, size_t size, bool hasnan, sort_order order) + void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan, order); + avx512_qsort(arr, size, hasnan, descending); } template <> void qselect(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_qselect(arr, k, arrsize, hasnan, order); + avx512_qselect(arr, k, arrsize, hasnan, descending); } template <> void partial_qsort(_Float16 *arr, size_t k, size_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan, order); + avx512_partial_qsort(arr, k, arrsize, hasnan, descending); } } // namespace avx512 } // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index dd39b561..21c8b34f 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -5,8 +5,6 @@ #include #include -using x86simdsort::sort_order; - static int check_cpu_feature_support(std::string_view cpufeature) { const char *disable_avx512 = std::getenv("XSS_DISABLE_AVX512"); @@ -59,40 +57,32 @@ namespace x86simdsort { #define CAT(a, b) CAT_(a, b) #define DECLARE_INTERNAL_qsort(TYPE) \ - static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, sort_order) \ - = NULL; \ + static void (*internal_qsort##TYPE)(TYPE *, size_t, bool, bool) = 
NULL; \ template <> \ - void qsort(TYPE *arr, size_t arrsize, bool hasnan, sort_order order) \ + void qsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qsort##TYPE)(arr, arrsize, hasnan, order); \ + (*internal_qsort##TYPE)(arr, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_qselect(TYPE) \ - static void (*internal_qselect##TYPE)( \ - TYPE *, size_t, size_t, bool, sort_order) \ + static void (*internal_qselect##TYPE)(TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void qselect(TYPE *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - sort_order order) \ + void qselect( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, order); \ + (*internal_qselect##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_partial_qsort(TYPE) \ static void (*internal_partial_qsort##TYPE)( \ - TYPE *, size_t, size_t, bool, sort_order) \ + TYPE *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void partial_qsort(TYPE *arr, \ - size_t k, \ - size_t arrsize, \ - bool hasnan, \ - sort_order order) \ + void partial_qsort( \ + TYPE *arr, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, order); \ + (*internal_partial_qsort##TYPE)(arr, k, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_argsort(TYPE) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 369d4153..42d5247f 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -5,7 +5,6 @@ #include #include #include -#include "x86simdsort-orders.h" #define XSS_EXPORT_SYMBOL __attribute__((visibility("default"))) #define XSS_HIDE_SYMBOL __attribute__((visibility("hidden"))) @@ -15,11 +14,8 @@ namespace x86simdsort { // quicksort template -XSS_EXPORT_SYMBOL void qsort(T *arr, - size_t arrsize, - bool hasnan = false, - x86simdsort::sort_order order - = 
x86simdsort::sort_order::sort_ascending); +XSS_EXPORT_SYMBOL void +qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template @@ -27,8 +23,7 @@ XSS_EXPORT_SYMBOL void qselect(T *arr, size_t k, size_t arrsize, bool hasnan = false, - x86simdsort::sort_order order - = x86simdsort::sort_order::sort_ascending); + bool descending = false); // partial sort template @@ -36,8 +31,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, size_t k, size_t arrsize, bool hasnan = false, - x86simdsort::sort_order order - = x86simdsort::sort_order::sort_ascending); + bool descending = false); // argsort template diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 7c9da83c..53ad856e 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -203,7 +203,7 @@ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, /* Specialized template function for _Float16 qsort_*/ template <> X86_SIMD_SORT_INLINE_ONLY void -avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, sort_order order) +avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) { using vtype = zmm_vector<_Float16>; @@ -212,7 +212,7 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, sort_order order) if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (order == sort_order::sort_descending) { + if (descending) { qsort_, _Float16>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -220,8 +220,7 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, sort_order order) qsort_, _Float16>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } - replace_inf_with_nan( - arr, arrsize, nan_count, order == sort_order::sort_descending); + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } @@ -230,11 +229,11 @@ X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - sort_order order) + bool 
descending) { using vtype = zmm_vector<_Float16>; - if (order == sort_order::sort_descending) { + if (descending) { arrsize_t index_first_elem = 0; if (UNLIKELY(hasnan)) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); @@ -272,9 +271,9 @@ X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, - sort_order order) + bool descending) { - avx512_qselect(arr, k - 1, arrsize, hasnan, order); - avx512_qsort(arr, k - 1, hasnan, order); + avx512_qselect(arr, k - 1, arrsize, hasnan, descending); + avx512_qsort(arr, k - 1, hasnan, descending); } #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index ce315c27..435abd56 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,9 +38,6 @@ #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" #include "xss-common-comparators.hpp" -#include "../lib/x86simdsort-orders.h" - -using x86simdsort::sort_order; template bool is_a_nan(T elem) @@ -611,7 +608,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, // Quicksort routines: template X86_SIMD_SORT_INLINE void -xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, sort_order order) +xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) { if (arrsize > 1) { if constexpr (std::is_floating_point_v) { @@ -619,7 +616,7 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, sort_order order) if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (order == sort_order::sort_descending) { + if (descending) { qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -627,14 +624,11 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, sort_order order) qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } - replace_inf_with_nan(arr, - arrsize, - nan_count, - order == sort_order::sort_descending); + replace_inf_with_nan(arr, arrsize, nan_count, descending); } else { UNUSED(hasnan); - if (order == 
sort_order::sort_descending) { + if (descending) { qsort_, T>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } @@ -649,9 +643,9 @@ xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, sort_order order) // Quick select methods template X86_SIMD_SORT_INLINE void xss_qselect( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, sort_order order) + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { - if (order == sort_order::sort_descending) { + if (descending) { arrsize_t index_first_elem = 0; if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -693,10 +687,10 @@ X86_SIMD_SORT_INLINE void xss_qselect( // Partial sort methods: template X86_SIMD_SORT_INLINE void xss_partial_qsort( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, sort_order order) + T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { - xss_qselect(arr, k - 1, arrsize, hasnan, order); - xss_qsort(arr, k - 1, hasnan, order); + xss_qselect(arr, k - 1, arrsize, hasnan, descending); + xss_qsort(arr, k - 1, hasnan, descending); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -704,30 +698,27 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort( X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, \ arrsize_t size, \ bool hasnan = false, \ - sort_order order \ - = sort_order::sort_ascending) \ + bool descending = false) \ { \ - xss_qsort(arr, size, hasnan, order); \ + xss_qsort(arr, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \ arrsize_t k, \ arrsize_t size, \ bool hasnan = false, \ - sort_order order \ - = sort_order::sort_ascending) \ + bool descending = false) \ { \ - xss_qselect(arr, k, size, hasnan, order); \ + xss_qselect(arr, k, size, hasnan, descending); \ } \ template \ - X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ - T *arr, \ - arrsize_t k, \ - arrsize_t size, \ - bool hasnan = false, \ - sort_order order = sort_order::sort_ascending) \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \ + 
arrsize_t k, \ + arrsize_t size, \ + bool hasnan = false, \ + bool descending = false) \ { \ - xss_partial_qsort(arr, k, size, hasnan, order); \ + xss_partial_qsort(arr, k, size, hasnan, descending); \ } DEFINE_METHODS(avx512, zmm_vector) diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index a92da1e3..72e173a4 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -55,17 +55,14 @@ TYPED_TEST_P(simdsort, test_qsort_descending) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); - + // Descending order std::vector arr = basearr; std::vector sortedarr = arr; std::sort(sortedarr.begin(), sortedarr.end(), compare>()); - x86simdsort::qsort(arr.data(), - arr.size(), - hasnan, - x86simdsort::sort_order::sort_descending); + x86simdsort::qsort(arr.data(), arr.size(), hasnan, true); IS_SORTED(sortedarr, arr, type); arr.clear(); @@ -131,11 +128,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending) sortedarr.begin() + k, sortedarr.end(), compare>()); - x86simdsort::qselect(arr.data(), - k, - arr.size(), - hasnan, - x86simdsort::sort_order::sort_descending); + x86simdsort::qselect(arr.data(), k, arr.size(), hasnan, true); IS_ARR_PARTITIONED(arr, k, sortedarr[k], type, true); arr.clear(); @@ -203,12 +196,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_descending) std::sort(sortedarr.begin(), sortedarr.end(), compare>()); - x86simdsort::partial_qsort( - arr.data(), - k, - arr.size(), - hasnan, - x86simdsort::sort_order::sort_descending); + x86simdsort::partial_qsort(arr.data(), k, arr.size(), hasnan, true); IS_ARR_PARTIALSORTED(arr, k, sortedarr, type); arr.clear(); From 6cb9046a91afe4e5dc2aab5bec68d95324eac33e Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 26 Mar 2024 11:51:02 -0700 Subject: [PATCH 15/18] Simplified logic for ascending/descending comparators --- lib/x86simdsort-spr.cpp | 15 +++- src/avx512fp16-16bit-qsort.hpp | 87 ++++++++++------------- 
src/xss-common-qsort.h | 125 +++++++++++++++------------------ 3 files changed, 106 insertions(+), 121 deletions(-) diff --git a/lib/x86simdsort-spr.cpp b/lib/x86simdsort-spr.cpp index d19de1ed..b09a8393 100644 --- a/lib/x86simdsort-spr.cpp +++ b/lib/x86simdsort-spr.cpp @@ -7,7 +7,10 @@ namespace avx512 { template <> void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending) { - avx512_qsort(arr, size, hasnan, descending); + if (descending) { avx512_qsort(arr, size, hasnan); } + else { + avx512_qsort(arr, size, hasnan); + } } template <> void qselect(_Float16 *arr, @@ -16,7 +19,10 @@ namespace avx512 { bool hasnan, bool descending) { - avx512_qselect(arr, k, arrsize, hasnan, descending); + if (descending) { avx512_qselect(arr, k, arrsize, hasnan); } + else { + avx512_qselect(arr, k, arrsize, hasnan); + } } template <> void partial_qsort(_Float16 *arr, @@ -25,7 +31,10 @@ namespace avx512 { bool hasnan, bool descending) { - avx512_partial_qsort(arr, k, arrsize, hasnan, descending); + if (descending) { avx512_partial_qsort(arr, k, arrsize, hasnan); } + else { + avx512_partial_qsort(arr, k, arrsize, hasnan); + } } } // namespace avx512 } // namespace xss diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 53ad856e..0100cfea 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -201,79 +201,64 @@ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan(_Float16 *arr, } } /* Specialized template function for _Float16 qsort_*/ -template <> +template X86_SIMD_SORT_INLINE_ONLY void -avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan, bool descending) +avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan) { using vtype = zmm_vector<_Float16>; + using comparator = + typename std::conditional, + AscendingComparator>::type; if (arrsize > 1) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { - nan_count = replace_nan_with_inf(arr, arrsize); - } - if (descending) { - qsort_, _Float16>( - arr, 0, 
arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - else { - qsort_, _Float16>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + nan_count = replace_nan_with_inf(arr, arrsize); } + + qsort_( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } -template <> -X86_SIMD_SORT_INLINE_ONLY void avx512_qselect(_Float16 *arr, - arrsize_t k, - arrsize_t arrsize, - bool hasnan, - bool descending) +template +X86_SIMD_SORT_INLINE_ONLY void +avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { using vtype = zmm_vector<_Float16>; + using comparator = + typename std::conditional, + AscendingComparator>::type; - if (descending) { - arrsize_t index_first_elem = 0; - if (UNLIKELY(hasnan)) { + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; + + if (UNLIKELY(hasnan)) { + if constexpr (descending) { index_first_elem = move_nans_to_start_of_array(arr, arrsize); } - - arrsize_t size_without_nans = arrsize - index_first_elem; - - if (index_first_elem <= k) { - qselect_, _Float16>( - arr, - k, - index_first_elem, - arrsize - 1, - 2 * (arrsize_t)log2(size_without_nans)); + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); } } - else { - arrsize_t indx_last_elem = arrsize - 1; - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - if (indx_last_elem >= k) { - qselect_, _Float16>( - arr, - k, - 0, - indx_last_elem, - 2 * (arrsize_t)log2(indx_last_elem)); - } + if (index_first_elem <= k && index_last_elem >= k) { + qselect_(arr, + k, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); } } -template <> -X86_SIMD_SORT_INLINE_ONLY void avx512_partial_qsort(_Float16 *arr, - arrsize_t k, - arrsize_t arrsize, - bool hasnan, - bool descending) +template +X86_SIMD_SORT_INLINE_ONLY void +avx512_partial_qsort(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - avx512_qselect(arr, k - 1, arrsize, 
hasnan, descending); - avx512_qsort(arr, k - 1, hasnan, descending); + avx512_qselect(arr, k - 1, arrsize, hasnan); + avx512_qsort(arr, k - 1, hasnan); } #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 435abd56..4aa81260 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -606,91 +606,71 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, } // Quicksort routines: -template -X86_SIMD_SORT_INLINE void -xss_qsort(T *arr, arrsize_t arrsize, bool hasnan, bool descending) +template +X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) { + using comparator = + typename std::conditional, + AscendingComparator>::type; + if (arrsize > 1) { + arrsize_t nan_count = 0; if constexpr (std::is_floating_point_v) { - arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf(arr, arrsize); } - if (descending) { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - else { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - replace_inf_with_nan(arr, arrsize, nan_count, descending); - } - else { - UNUSED(hasnan); - if (descending) { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - else { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } } + + UNUSED(hasnan); + qsort_( + arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + replace_inf_with_nan(arr, arrsize, nan_count, descending); } } // Quick select methods -template -X86_SIMD_SORT_INLINE void xss_qselect( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +template +X86_SIMD_SORT_INLINE void +xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - if (descending) { - arrsize_t index_first_elem = 0; - if constexpr (std::is_floating_point_v) { - if (UNLIKELY(hasnan)) { - index_first_elem = move_nans_to_start_of_array(arr, arrsize); - } - } + using comparator = + typename std::conditional, + 
AscendingComparator>::type; - arrsize_t size_without_nans = arrsize - index_first_elem; + arrsize_t index_first_elem = 0; + arrsize_t index_last_elem = arrsize - 1; - UNUSED(hasnan); - if (index_first_elem <= k) { - qselect_, T>( - arr, - k, - index_first_elem, - arrsize - 1, - 2 * (arrsize_t)log2(size_without_nans)); - } - } - else { - arrsize_t indx_last_elem = arrsize - 1; - if constexpr (std::is_floating_point_v) { - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + if constexpr (std::is_floating_point_v) { + if (UNLIKELY(hasnan)) { + if constexpr (descending) { + index_first_elem = move_nans_to_start_of_array(arr, arrsize); + } + else { + index_last_elem = move_nans_to_end_of_array(arr, arrsize); } - } - UNUSED(hasnan); - if (indx_last_elem >= k) { - qselect_, T>( - arr, - k, - 0, - indx_last_elem, - 2 * (arrsize_t)log2(indx_last_elem)); } } + + UNUSED(hasnan); + if (index_first_elem <= k && index_last_elem >= k) { + qselect_(arr, + k, + index_first_elem, + index_last_elem, + 2 * (arrsize_t)log2(arrsize)); + } } // Partial sort methods: -template -X86_SIMD_SORT_INLINE void xss_partial_qsort( - T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +template +X86_SIMD_SORT_INLINE void +xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - xss_qselect(arr, k - 1, arrsize, hasnan, descending); - xss_qsort(arr, k - 1, hasnan, descending); + xss_qselect(arr, k - 1, arrsize, hasnan); + xss_qsort(arr, k - 1, hasnan); } #define DEFINE_METHODS(ISA, VTYPE) \ @@ -700,7 +680,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort( bool hasnan = false, \ bool descending = false) \ { \ - xss_qsort(arr, size, hasnan, descending); \ + if (descending) { xss_qsort(arr, size, hasnan); } \ + else { \ + xss_qsort(arr, size, hasnan); \ + } \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_qselect(T *arr, \ @@ -709,7 +692,10 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort( bool hasnan = false, \ bool descending = false) 
\ { \ - xss_qselect(arr, k, size, hasnan, descending); \ + if (descending) { xss_qselect(arr, k, size, hasnan); } \ + else { \ + xss_qselect(arr, k, size, hasnan); \ + } \ } \ template \ X86_SIMD_SORT_INLINE void ISA##_partial_qsort(T *arr, \ @@ -718,7 +704,12 @@ X86_SIMD_SORT_INLINE void xss_partial_qsort( bool hasnan = false, \ bool descending = false) \ { \ - xss_partial_qsort(arr, k, size, hasnan, descending); \ + if (descending) { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } \ + else { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } \ } DEFINE_METHODS(avx512, zmm_vector) From f19214a74109b2e2ae56db6c7e06d798ab37ea37 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 26 Mar 2024 12:09:54 -0700 Subject: [PATCH 16/18] formatting after rebase --- src/xss-common-qsort.h | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 4aa81260..d98a32b6 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -364,12 +364,14 @@ X86_SIMD_SORT_INLINE arrsize_t partition_unrolled(type_t *arr, type_t *biggest) { if constexpr (num_unroll == 0) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition( + arr, left, right, pivot, smallest, biggest); } /* Use regular partition for smaller arrays */ if (right - left < 3 * num_unroll * vtype::numlanes) { - return partition(arr, left, right, pivot, smallest, biggest); + return partition( + arr, left, right, pivot, smallest, biggest); } /* make array length divisible by vtype::numlanes, shortening the array */ @@ -547,9 +549,10 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index - = partition_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); + arrsize_t pivot_index = partition_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); if (pivot_result.result == 
pivot_result_t::Only2Values) { return; } @@ -590,9 +593,10 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - arrsize_t pivot_index - = partition_unrolled( - arr, left, right + 1, pivot, &smallest, &biggest); + arrsize_t pivot_index = partition_unrolled( + arr, left, right + 1, pivot, &smallest, &biggest); type_t leftmostValue = comparator::leftmost(smallest, biggest); type_t rightmostValue = comparator::rightmost(smallest, biggest); From b629ba5ef5dcbe5f6e69baa44579347f3a31ebe5 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 26 Mar 2024 13:05:10 -0700 Subject: [PATCH 17/18] clang-format --- tests/test-qsort-common.h | 8 ++++---- tests/test-qsort.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test-qsort-common.h b/tests/test-qsort-common.h index 998537be..4fdb87fc 100644 --- a/tests/test-qsort-common.h +++ b/tests/test-qsort-common.h @@ -69,16 +69,16 @@ void IS_ARR_PARTITIONED(std::vector arr, } // ( 2) Elements to the left of k should be atmost arr[k] if (k >= 1) { - T max_left = *std::max_element( - arr.begin(), arr.begin() + k - 1, cmp_less); + T max_left + = *std::max_element(arr.begin(), arr.begin() + k - 1, cmp_less); if (!cmp_geq(arr[k], max_left)) { REPORT_FAIL("incorrect left partition", arr.size(), type, k); } } // 3) Elements to the right of k should be atleast arr[k] if (k != (size_t)(arr.size() - 1)) { - T min_right = *std::min_element( - arr.begin() + k + 1, arr.end(), cmp_less); + T min_right + = *std::min_element(arr.begin() + k + 1, arr.end(), cmp_less); if (!cmp_leq(arr[k], min_right)) { REPORT_FAIL("incorrect right partition", arr.size(), type, k); } diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 72e173a4..5d4ba587 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -55,7 +55,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending) bool hasnan = (type == "rand_with_nan") ? 
true : false; for (auto size : this->arrsize) { std::vector basearr = get_array(type, size); - + // Descending order std::vector arr = basearr; std::vector sortedarr = arr; From b0d09297f6bfae5e48aa54c340c6c6b572a63103 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 27 Mar 2024 15:04:54 -0700 Subject: [PATCH 18/18] Condense AscendingComparator and DescendingComparator into one class --- src/avx512-16bit-qsort.hpp | 8 +-- src/avx512fp16-16bit-qsort.hpp | 8 +-- src/xss-common-comparators.hpp | 116 ++++++++++++--------------------- src/xss-common-qsort.h | 8 +-- src/xss-pivot-selection.hpp | 2 +- 5 files changed, 56 insertions(+), 86 deletions(-) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 3ed6d447..15c7c91e 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -574,11 +574,11 @@ X86_SIMD_SORT_INLINE void avx512_qsort_fp16(uint16_t *arr, arr, arrsize); } if (descending) { - qsort_, uint16_t>( + qsort_, uint16_t>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } else { - qsort_, uint16_t>( + qsort_, uint16_t>( arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } replace_inf_with_nan(arr, arrsize, nan_count, descending); @@ -599,7 +599,7 @@ X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr, } if (indx_last_elem >= k) { if (descending) { - qselect_, uint16_t>( + qselect_, uint16_t>( arr, k, 0, @@ -607,7 +607,7 @@ X86_SIMD_SORT_INLINE void avx512_qselect_fp16(uint16_t *arr, 2 * (arrsize_t)log2(indx_last_elem)); } else { - qselect_, uint16_t>( + qselect_, uint16_t>( arr, k, 0, diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 0100cfea..130e28a8 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -208,8 +208,8 @@ avx512_qsort(_Float16 *arr, arrsize_t arrsize, bool hasnan) using vtype = zmm_vector<_Float16>; using comparator = typename std::conditional, - AscendingComparator>::type; + Comparator, + Comparator>::type; if 
(arrsize > 1) { arrsize_t nan_count = 0; @@ -231,8 +231,8 @@ avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) using vtype = zmm_vector<_Float16>; using comparator = typename std::conditional, - AscendingComparator>::type; + Comparator, + Comparator>::type; arrsize_t index_first_elem = 0; arrsize_t index_last_elem = arrsize - 1; diff --git a/src/xss-common-comparators.hpp b/src/xss-common-comparators.hpp index 48a1d020..bd742cd4 100644 --- a/src/xss-common-comparators.hpp +++ b/src/xss-common-comparators.hpp @@ -34,8 +34,8 @@ type_t next_value(type_t value) template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); -template -struct AscendingComparator { +template +struct Comparator { using reg_t = typename vtype::reg_t; using opmask_t = typename vtype::opmask_t; using type_t = typename vtype::type_t; @@ -43,115 +43,85 @@ struct AscendingComparator { X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, const type_t &b) { - return comparison_func(a, b); + if constexpr (descend) { return comparison_func(b, a); } + else { + return comparison_func(a, b); + } } X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) { - return vtype::ge(a, b); + if constexpr (descend) { return vtype::ge(b, a); } + else { + return vtype::ge(a, b); + } } X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) { - ::COEX(a, b); + if constexpr (descend) { ::COEX(b, a); } + else { + ::COEX(a, b); + } } // Returns a vector of values that would be sorted as far right as possible // For ascending order, this is the maximum possible value X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() { - return vtype::zmm_max(); + if constexpr (descend) { return vtype::zmm_min(); } + else { + return vtype::zmm_max(); + } } // Returns the value that would be leftmost of the two when sorted // For ascending order, that is the smaller value X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t larger) { - UNUSED(larger); - return smaller; + if constexpr 
(descend) { + UNUSED(smaller); + return larger; + } + else { + UNUSED(larger); + return smaller; + } } // Returns the value that would be rightmost of the two when sorted // For ascending order, that is the larger value X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t larger) { - UNUSED(smaller); - return larger; + if constexpr (descend) { + UNUSED(larger); + return smaller; + } + else { + UNUSED(smaller); + return larger; + } } // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample // Try just doing the next largest value greater than this seemingly very common value to seperate them out X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) { - return next_value(median); + if constexpr (descend) { return median; } + else { + return next_value(median); + } } // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median) { - return median; - } -}; - -template -struct DescendingComparator { - using reg_t = typename vtype::reg_t; - using opmask_t = typename vtype::opmask_t; - using type_t = typename vtype::type_t; - - X86_SIMD_SORT_FINLINE bool STDSortComparator(const type_t &a, - const type_t &b) - { - return comparison_func(b, a); - } - - X86_SIMD_SORT_FINLINE opmask_t PartitionComparator(reg_t a, reg_t b) - { - return vtype::ge(b, a); - } - - X86_SIMD_SORT_FINLINE void COEX(reg_t &a, reg_t &b) - { - ::COEX(b, a); - } - - // Returns a vector of values that would be sorted as far right as possible - // For descending order, this is the minimum possible value - X86_SIMD_SORT_FINLINE reg_t rightmostPossibleVec() - { - return vtype::zmm_min(); - } - - // Returns the value that would be leftmost of the two when sorted 
- // For descending order, that is the larger value - X86_SIMD_SORT_FINLINE type_t leftmost(type_t smaller, type_t bigger) - { - UNUSED(smaller); - return bigger; - } - - // Returns the value that would be rightmost of the two when sorted - // For descending order, that is the smaller value - X86_SIMD_SORT_FINLINE type_t rightmost(type_t smaller, type_t bigger) - { - UNUSED(bigger); - return smaller; - } - - // If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample - // Thus, median probably is a fine pivot, since it will move all of this common value into its own partition - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsSmallest(type_t median) - { - return median; - } - - // If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample - // Try just doing the next smallest value less than this seemingly very common value to seperate them out - X86_SIMD_SORT_FINLINE type_t choosePivotMedianIsLargest(type_t median) - { - return prev_value(median); + if constexpr (descend) { return prev_value(median); } + else { + return median; + } } }; -#endif // XSS_COMMON_COMPARATORS \ No newline at end of file +#endif // XSS_COMMON_COMPARATORS diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index d98a32b6..02522b50 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -615,8 +615,8 @@ X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) { using comparator = typename std::conditional, - AscendingComparator>::type; + Comparator, + Comparator>::type; if (arrsize > 1) { arrsize_t nan_count = 0; @@ -641,8 +641,8 @@ xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { using comparator = typename std::conditional, - AscendingComparator>::type; + Comparator, + Comparator>::type; arrsize_t index_first_elem = 0; arrsize_t index_last_elem = arrsize - 1; diff --git 
a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 5d955616..6ce0b887 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -115,7 +115,7 @@ get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right) // Sort the samples // Note that this intentionally uses the AscendingComparator // instead of the provided comparator - sort_vectors, numVecs>(vecs); + sort_vectors, numVecs>(vecs); type_t samples[N]; for (int i = 0; i < numVecs; i++) {