diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 6e0743d6..0c819946 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -290,10 +290,11 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_16bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_16bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index e9e97aa1..c4061ddf 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -648,7 +648,7 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) qsort_32bit_(arr, left, pivot_index - 1, max_iters - 1); @@ -657,10 +657,11 @@ qsort_32bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_32bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_32bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -680,7 +681,7 @@ qselect_32bit_(type_t *arr, int64_t pos, type_t pivot = get_pivot_32bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) qselect_32bit_(arr, pos, left, pivot_index - 1, max_iters - 1); diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index dfb5376f..1cbcd388 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -172,6 +172,161 @@ X86_SIMD_SORT_INLINE void bitonic_merge_sixteen_zmm_64bit(zmm_t *zmm) zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); } +template +X86_SIMD_SORT_INLINE void bitonic_merge_32_zmm_64bit(zmm_t *zmm) +{ + const __m512i rev_index = _mm512_set_epi64(NETWORK_64BIT_2); + zmm_t zmm16r = vtype::permutexvar(rev_index, zmm[16]); + zmm_t zmm17r = vtype::permutexvar(rev_index, zmm[17]); + zmm_t zmm18r = vtype::permutexvar(rev_index, zmm[18]); + zmm_t zmm19r = vtype::permutexvar(rev_index, zmm[19]); + zmm_t zmm20r = vtype::permutexvar(rev_index, zmm[20]); + zmm_t zmm21r = vtype::permutexvar(rev_index, zmm[21]); + zmm_t zmm22r = vtype::permutexvar(rev_index, zmm[22]); + zmm_t zmm23r = vtype::permutexvar(rev_index, zmm[23]); + zmm_t zmm24r = vtype::permutexvar(rev_index, zmm[24]); + zmm_t zmm25r = vtype::permutexvar(rev_index, zmm[25]); + zmm_t zmm26r = vtype::permutexvar(rev_index, zmm[26]); + zmm_t zmm27r = vtype::permutexvar(rev_index, zmm[27]); + zmm_t zmm28r = vtype::permutexvar(rev_index, zmm[28]); + zmm_t zmm29r = vtype::permutexvar(rev_index, zmm[29]); + zmm_t zmm30r = vtype::permutexvar(rev_index, zmm[30]); + zmm_t zmm31r = vtype::permutexvar(rev_index, zmm[31]); + zmm_t zmm_t1 = vtype::min(zmm[0], zmm31r); + zmm_t zmm_t2 = vtype::min(zmm[1], zmm30r); + zmm_t zmm_t3 = vtype::min(zmm[2], zmm29r); + zmm_t zmm_t4 = vtype::min(zmm[3], zmm28r); + zmm_t zmm_t5 = vtype::min(zmm[4], zmm27r); + zmm_t zmm_t6 = vtype::min(zmm[5], zmm26r); + zmm_t zmm_t7 = vtype::min(zmm[6], zmm25r); + zmm_t zmm_t8 = vtype::min(zmm[7], zmm24r); + zmm_t zmm_t9 = vtype::min(zmm[8], zmm23r); + zmm_t zmm_t10 = vtype::min(zmm[9], zmm22r); + zmm_t zmm_t11 = vtype::min(zmm[10], zmm21r); + zmm_t zmm_t12 = vtype::min(zmm[11], zmm20r); + zmm_t zmm_t13 = vtype::min(zmm[12], zmm19r); + zmm_t zmm_t14 = vtype::min(zmm[13], zmm18r); + zmm_t zmm_t15 = vtype::min(zmm[14], zmm17r); + zmm_t zmm_t16 = vtype::min(zmm[15], zmm16r); + zmm_t zmm_t17 = vtype::permutexvar(rev_index, vtype::max(zmm[15], zmm16r)); + zmm_t zmm_t18 = vtype::permutexvar(rev_index, vtype::max(zmm[14], zmm17r)); + zmm_t zmm_t19 = vtype::permutexvar(rev_index, vtype::max(zmm[13], zmm18r)); + zmm_t zmm_t20 = vtype::permutexvar(rev_index, vtype::max(zmm[12], zmm19r)); + zmm_t zmm_t21 = vtype::permutexvar(rev_index, vtype::max(zmm[11], zmm20r)); + zmm_t zmm_t22 = vtype::permutexvar(rev_index, vtype::max(zmm[10], zmm21r)); + zmm_t zmm_t23 = vtype::permutexvar(rev_index, vtype::max(zmm[9], zmm22r)); + zmm_t zmm_t24 = vtype::permutexvar(rev_index, vtype::max(zmm[8], zmm23r)); + zmm_t zmm_t25 = vtype::permutexvar(rev_index, vtype::max(zmm[7], zmm24r)); + zmm_t zmm_t26 = vtype::permutexvar(rev_index, vtype::max(zmm[6], zmm25r)); + zmm_t zmm_t27 = vtype::permutexvar(rev_index, vtype::max(zmm[5], zmm26r)); + zmm_t zmm_t28 = vtype::permutexvar(rev_index, vtype::max(zmm[4], zmm27r)); + zmm_t zmm_t29 = vtype::permutexvar(rev_index, vtype::max(zmm[3], zmm28r)); + zmm_t zmm_t30 = vtype::permutexvar(rev_index, vtype::max(zmm[2], zmm29r)); + zmm_t zmm_t31 = vtype::permutexvar(rev_index, vtype::max(zmm[1], zmm30r)); + zmm_t zmm_t32 = vtype::permutexvar(rev_index, vtype::max(zmm[0], zmm31r)); + // Recusive half clear 16 zmm regs + COEX(zmm_t1, zmm_t9); + COEX(zmm_t2, zmm_t10); + COEX(zmm_t3, zmm_t11); + COEX(zmm_t4, zmm_t12); + COEX(zmm_t5, zmm_t13); + COEX(zmm_t6, zmm_t14); + COEX(zmm_t7, zmm_t15); + COEX(zmm_t8, zmm_t16); + COEX(zmm_t17, zmm_t25); + COEX(zmm_t18, zmm_t26); + COEX(zmm_t19, zmm_t27); + COEX(zmm_t20, zmm_t28); + COEX(zmm_t21, zmm_t29); + COEX(zmm_t22, zmm_t30); + COEX(zmm_t23, zmm_t31); + COEX(zmm_t24, zmm_t32); + // + COEX(zmm_t1, zmm_t5); + COEX(zmm_t2, zmm_t6); + COEX(zmm_t3, zmm_t7); + COEX(zmm_t4, zmm_t8); + COEX(zmm_t9, zmm_t13); + COEX(zmm_t10, zmm_t14); + COEX(zmm_t11, zmm_t15); + COEX(zmm_t12, zmm_t16); + COEX(zmm_t17, zmm_t21); + COEX(zmm_t18, zmm_t22); + COEX(zmm_t19, zmm_t23); + COEX(zmm_t20, zmm_t24); + COEX(zmm_t25, zmm_t29); + COEX(zmm_t26, zmm_t30); + COEX(zmm_t27, zmm_t31); + COEX(zmm_t28, zmm_t32); + // + COEX(zmm_t1, zmm_t3); + COEX(zmm_t2, zmm_t4); + COEX(zmm_t5, zmm_t7); + COEX(zmm_t6, zmm_t8); + COEX(zmm_t9, zmm_t11); + COEX(zmm_t10, zmm_t12); + COEX(zmm_t13, zmm_t15); + COEX(zmm_t14, zmm_t16); + COEX(zmm_t17, zmm_t19); + COEX(zmm_t18, zmm_t20); + COEX(zmm_t21, zmm_t23); + COEX(zmm_t22, zmm_t24); + COEX(zmm_t25, zmm_t27); + COEX(zmm_t26, zmm_t28); + COEX(zmm_t29, zmm_t31); + COEX(zmm_t30, zmm_t32); + // + COEX(zmm_t1, zmm_t2); + COEX(zmm_t3, zmm_t4); + COEX(zmm_t5, zmm_t6); + COEX(zmm_t7, zmm_t8); + COEX(zmm_t9, zmm_t10); + COEX(zmm_t11, zmm_t12); + COEX(zmm_t13, zmm_t14); + COEX(zmm_t15, zmm_t16); + COEX(zmm_t17, zmm_t18); + COEX(zmm_t19, zmm_t20); + COEX(zmm_t21, zmm_t22); + COEX(zmm_t23, zmm_t24); + COEX(zmm_t25, zmm_t26); + COEX(zmm_t27, zmm_t28); + COEX(zmm_t29, zmm_t30); + COEX(zmm_t31, zmm_t32); + // + zmm[0] = bitonic_merge_zmm_64bit(zmm_t1); + zmm[1] = bitonic_merge_zmm_64bit(zmm_t2); + zmm[2] = bitonic_merge_zmm_64bit(zmm_t3); + zmm[3] = bitonic_merge_zmm_64bit(zmm_t4); + zmm[4] = bitonic_merge_zmm_64bit(zmm_t5); + zmm[5] = bitonic_merge_zmm_64bit(zmm_t6); + zmm[6] = bitonic_merge_zmm_64bit(zmm_t7); + zmm[7] = bitonic_merge_zmm_64bit(zmm_t8); + zmm[8] = bitonic_merge_zmm_64bit(zmm_t9); + zmm[9] = bitonic_merge_zmm_64bit(zmm_t10); + zmm[10] = bitonic_merge_zmm_64bit(zmm_t11); + zmm[11] = bitonic_merge_zmm_64bit(zmm_t12); + zmm[12] = bitonic_merge_zmm_64bit(zmm_t13); + zmm[13] = bitonic_merge_zmm_64bit(zmm_t14); + zmm[14] = bitonic_merge_zmm_64bit(zmm_t15); + zmm[15] = bitonic_merge_zmm_64bit(zmm_t16); + zmm[16] = bitonic_merge_zmm_64bit(zmm_t17); + zmm[17] = bitonic_merge_zmm_64bit(zmm_t18); + zmm[18] = bitonic_merge_zmm_64bit(zmm_t19); + zmm[19] = bitonic_merge_zmm_64bit(zmm_t20); + zmm[20] = bitonic_merge_zmm_64bit(zmm_t21); + zmm[21] = bitonic_merge_zmm_64bit(zmm_t22); + zmm[22] = bitonic_merge_zmm_64bit(zmm_t23); + zmm[23] = bitonic_merge_zmm_64bit(zmm_t24); + zmm[24] = bitonic_merge_zmm_64bit(zmm_t25); + zmm[25] = bitonic_merge_zmm_64bit(zmm_t26); + zmm[26] = bitonic_merge_zmm_64bit(zmm_t27); + zmm[27] = bitonic_merge_zmm_64bit(zmm_t28); + zmm[28] = bitonic_merge_zmm_64bit(zmm_t29); + zmm[29] = bitonic_merge_zmm_64bit(zmm_t30); + zmm[30] = bitonic_merge_zmm_64bit(zmm_t31); + zmm[31] = bitonic_merge_zmm_64bit(zmm_t32); +} + template X86_SIMD_SORT_INLINE void sort_8_64bit(type_t *arr, int32_t N) { @@ -371,6 +526,200 @@ X86_SIMD_SORT_INLINE void sort_128_64bit(type_t *arr, int32_t N) vtype::mask_storeu(arr + 120, load_mask8, zmm[15]); } +template +X86_SIMD_SORT_INLINE void sort_256_64bit(type_t *arr, int32_t N) +{ + if (N <= 128) { + sort_128_64bit(arr, N); + return; + } + using zmm_t = typename vtype::zmm_t; + using opmask_t = typename vtype::opmask_t; + zmm_t zmm[32]; + zmm[0] = vtype::loadu(arr); + zmm[1] = vtype::loadu(arr + 8); + zmm[2] = vtype::loadu(arr + 16); + zmm[3] = vtype::loadu(arr + 24); + zmm[4] = vtype::loadu(arr + 32); + zmm[5] = vtype::loadu(arr + 40); + zmm[6] = vtype::loadu(arr + 48); + zmm[7] = vtype::loadu(arr + 56); + zmm[8] = vtype::loadu(arr + 64); + zmm[9] = vtype::loadu(arr + 72); + zmm[10] = vtype::loadu(arr + 80); + zmm[11] = vtype::loadu(arr + 88); + zmm[12] = vtype::loadu(arr + 96); + zmm[13] = vtype::loadu(arr + 104); + zmm[14] = vtype::loadu(arr + 112); + zmm[15] = vtype::loadu(arr + 120); + zmm[0] = sort_zmm_64bit(zmm[0]); + zmm[1] = sort_zmm_64bit(zmm[1]); + zmm[2] = sort_zmm_64bit(zmm[2]); + zmm[3] = sort_zmm_64bit(zmm[3]); + zmm[4] = sort_zmm_64bit(zmm[4]); + zmm[5] = sort_zmm_64bit(zmm[5]); + zmm[6] = sort_zmm_64bit(zmm[6]); + zmm[7] = sort_zmm_64bit(zmm[7]); + zmm[8] = sort_zmm_64bit(zmm[8]); + zmm[9] = sort_zmm_64bit(zmm[9]); + zmm[10] = sort_zmm_64bit(zmm[10]); + zmm[11] = sort_zmm_64bit(zmm[11]); + zmm[12] = sort_zmm_64bit(zmm[12]); + zmm[13] = sort_zmm_64bit(zmm[13]); + zmm[14] = sort_zmm_64bit(zmm[14]); + zmm[15] = sort_zmm_64bit(zmm[15]); + opmask_t load_mask1 = 0xFF, load_mask2 = 0xFF; + opmask_t load_mask3 = 0xFF, load_mask4 = 0xFF; + opmask_t load_mask5 = 0xFF, load_mask6 = 0xFF; + opmask_t load_mask7 = 0xFF, load_mask8 = 0xFF; + opmask_t load_mask9 = 0xFF, load_mask10 = 0xFF; + opmask_t load_mask11 = 0xFF, load_mask12 = 0xFF; + opmask_t load_mask13 = 0xFF, load_mask14 = 0xFF; + opmask_t load_mask15 = 0xFF, load_mask16 = 0xFF; + if (N != 256) { + uint64_t combined_mask; + if (N < 192) { + combined_mask = (0x1ull << (N - 128)) - 0x1ull; + load_mask1 = (combined_mask)&0xFF; + load_mask2 = (combined_mask >> 8) & 0xFF; + load_mask3 = (combined_mask >> 16) & 0xFF; + load_mask4 = (combined_mask >> 24) & 0xFF; + load_mask5 = (combined_mask >> 32) & 0xFF; + load_mask6 = (combined_mask >> 40) & 0xFF; + load_mask7 = (combined_mask >> 48) & 0xFF; + load_mask8 = (combined_mask >> 56) & 0xFF; + load_mask9 = 0x00; + load_mask10 = 0x0; + load_mask11 = 0x00; + load_mask12 = 0x00; + load_mask13 = 0x00; + load_mask14 = 0x00; + load_mask15 = 0x00; + load_mask16 = 0x00; + } + else { + combined_mask = (0x1ull << (N - 192)) - 0x1ull; + load_mask9 = (combined_mask)&0xFF; + load_mask10 = (combined_mask >> 8) & 0xFF; + load_mask11 = (combined_mask >> 16) & 0xFF; + load_mask12 = (combined_mask >> 24) & 0xFF; + load_mask13 = (combined_mask >> 32) & 0xFF; + load_mask14 = (combined_mask >> 40) & 0xFF; + load_mask15 = (combined_mask >> 48) & 0xFF; + load_mask16 = (combined_mask >> 56) & 0xFF; + } + } + zmm[16] = vtype::mask_loadu(vtype::zmm_max(), load_mask1, arr + 128); + zmm[17] = vtype::mask_loadu(vtype::zmm_max(), load_mask2, arr + 136); + zmm[18] = vtype::mask_loadu(vtype::zmm_max(), load_mask3, arr + 144); + zmm[19] = vtype::mask_loadu(vtype::zmm_max(), load_mask4, arr + 152); + zmm[20] = vtype::mask_loadu(vtype::zmm_max(), load_mask5, arr + 160); + zmm[21] = vtype::mask_loadu(vtype::zmm_max(), load_mask6, arr + 168); + zmm[22] = vtype::mask_loadu(vtype::zmm_max(), load_mask7, arr + 176); + zmm[23] = vtype::mask_loadu(vtype::zmm_max(), load_mask8, arr + 184); + if (N < 192) { + zmm[24] = vtype::zmm_max(); + zmm[25] = vtype::zmm_max(); + zmm[26] = vtype::zmm_max(); + zmm[27] = vtype::zmm_max(); + zmm[28] = vtype::zmm_max(); + zmm[29] = vtype::zmm_max(); + zmm[30] = vtype::zmm_max(); + zmm[31] = vtype::zmm_max(); + } + else { + zmm[24] = vtype::mask_loadu(vtype::zmm_max(), load_mask9, arr + 192); + zmm[25] = vtype::mask_loadu(vtype::zmm_max(), load_mask10, arr + 200); + zmm[26] = vtype::mask_loadu(vtype::zmm_max(), load_mask11, arr + 208); + zmm[27] = vtype::mask_loadu(vtype::zmm_max(), load_mask12, arr + 216); + zmm[28] = vtype::mask_loadu(vtype::zmm_max(), load_mask13, arr + 224); + zmm[29] = vtype::mask_loadu(vtype::zmm_max(), load_mask14, arr + 232); + zmm[30] = vtype::mask_loadu(vtype::zmm_max(), load_mask15, arr + 240); + zmm[31] = vtype::mask_loadu(vtype::zmm_max(), load_mask16, arr + 248); + } + zmm[16] = sort_zmm_64bit(zmm[16]); + zmm[17] = sort_zmm_64bit(zmm[17]); + zmm[18] = sort_zmm_64bit(zmm[18]); + zmm[19] = sort_zmm_64bit(zmm[19]); + zmm[20] = sort_zmm_64bit(zmm[20]); + zmm[21] = sort_zmm_64bit(zmm[21]); + zmm[22] = sort_zmm_64bit(zmm[22]); + zmm[23] = sort_zmm_64bit(zmm[23]); + zmm[24] = sort_zmm_64bit(zmm[24]); + zmm[25] = sort_zmm_64bit(zmm[25]); + zmm[26] = sort_zmm_64bit(zmm[26]); + zmm[27] = sort_zmm_64bit(zmm[27]); + zmm[28] = sort_zmm_64bit(zmm[28]); + zmm[29] = sort_zmm_64bit(zmm[29]); + zmm[30] = sort_zmm_64bit(zmm[30]); + zmm[31] = sort_zmm_64bit(zmm[31]); + bitonic_merge_two_zmm_64bit(zmm[0], zmm[1]); + bitonic_merge_two_zmm_64bit(zmm[2], zmm[3]); + bitonic_merge_two_zmm_64bit(zmm[4], zmm[5]); + bitonic_merge_two_zmm_64bit(zmm[6], zmm[7]); + bitonic_merge_two_zmm_64bit(zmm[8], zmm[9]); + bitonic_merge_two_zmm_64bit(zmm[10], zmm[11]); + bitonic_merge_two_zmm_64bit(zmm[12], zmm[13]); + bitonic_merge_two_zmm_64bit(zmm[14], zmm[15]); + bitonic_merge_two_zmm_64bit(zmm[16], zmm[17]); + bitonic_merge_two_zmm_64bit(zmm[18], zmm[19]); + bitonic_merge_two_zmm_64bit(zmm[20], zmm[21]); + bitonic_merge_two_zmm_64bit(zmm[22], zmm[23]); + bitonic_merge_two_zmm_64bit(zmm[24], zmm[25]); + bitonic_merge_two_zmm_64bit(zmm[26], zmm[27]); + bitonic_merge_two_zmm_64bit(zmm[28], zmm[29]); + bitonic_merge_two_zmm_64bit(zmm[30], zmm[31]); + bitonic_merge_four_zmm_64bit(zmm); + bitonic_merge_four_zmm_64bit(zmm + 4); + bitonic_merge_four_zmm_64bit(zmm + 8); + bitonic_merge_four_zmm_64bit(zmm + 12); + bitonic_merge_four_zmm_64bit(zmm + 16); + bitonic_merge_four_zmm_64bit(zmm + 20); + bitonic_merge_four_zmm_64bit(zmm + 24); + bitonic_merge_four_zmm_64bit(zmm + 28); + bitonic_merge_eight_zmm_64bit(zmm); + bitonic_merge_eight_zmm_64bit(zmm + 8); + bitonic_merge_eight_zmm_64bit(zmm + 16); + bitonic_merge_eight_zmm_64bit(zmm + 24); + bitonic_merge_sixteen_zmm_64bit(zmm); + bitonic_merge_sixteen_zmm_64bit(zmm + 16); + bitonic_merge_32_zmm_64bit(zmm); + vtype::storeu(arr, zmm[0]); + vtype::storeu(arr + 8, zmm[1]); + vtype::storeu(arr + 16, zmm[2]); + vtype::storeu(arr + 24, zmm[3]); + vtype::storeu(arr + 32, zmm[4]); + vtype::storeu(arr + 40, zmm[5]); + vtype::storeu(arr + 48, zmm[6]); + vtype::storeu(arr + 56, zmm[7]); + vtype::storeu(arr + 64, zmm[8]); + vtype::storeu(arr + 72, zmm[9]); + vtype::storeu(arr + 80, zmm[10]); + vtype::storeu(arr + 88, zmm[11]); + vtype::storeu(arr + 96, zmm[12]); + vtype::storeu(arr + 104, zmm[13]); + vtype::storeu(arr + 112, zmm[14]); + vtype::storeu(arr + 120, zmm[15]); + vtype::mask_storeu(arr + 128, load_mask1, zmm[16]); + vtype::mask_storeu(arr + 136, load_mask2, zmm[17]); + vtype::mask_storeu(arr + 144, load_mask3, zmm[18]); + vtype::mask_storeu(arr + 152, load_mask4, zmm[19]); + vtype::mask_storeu(arr + 160, load_mask5, zmm[20]); + vtype::mask_storeu(arr + 168, load_mask6, zmm[21]); + vtype::mask_storeu(arr + 176, load_mask7, zmm[22]); + vtype::mask_storeu(arr + 184, load_mask8, zmm[23]); + if (N > 192) { + vtype::mask_storeu(arr + 192, load_mask9, zmm[24]); + vtype::mask_storeu(arr + 200, load_mask10, zmm[25]); + vtype::mask_storeu(arr + 208, load_mask11, zmm[26]); + vtype::mask_storeu(arr + 216, load_mask12, zmm[27]); + vtype::mask_storeu(arr + 224, load_mask13, zmm[28]); + vtype::mask_storeu(arr + 232, load_mask14, zmm[29]); + vtype::mask_storeu(arr + 240, load_mask15, zmm[30]); + vtype::mask_storeu(arr + 248, load_mask16, zmm[31]); + } +} + template static void qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) @@ -385,15 +734,15 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) /* * Base case: use bitonic networks to sort arrays <= 128 */ - if (right + 1 - left <= 128) { - sort_128_64bit(arr + left, (int32_t)(right + 1 - left)); + if (right + 1 - left <= 256) { + sort_256_64bit(arr + left, (int32_t)(right + 1 - left)); return; } type_t pivot = get_pivot_64bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if (pivot != smallest) qsort_64bit_(arr, left, pivot_index - 1, max_iters - 1); @@ -402,10 +751,11 @@ qsort_64bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters) } template -static void -qselect_64bit_(type_t *arr, int64_t pos, - int64_t left, int64_t right, - int64_t max_iters) +static void qselect_64bit_(type_t *arr, + int64_t pos, + int64_t left, + int64_t right, + int64_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -425,7 +775,7 @@ qselect_64bit_(type_t *arr, int64_t pos, type_t pivot = get_pivot_64bit(arr, left, right); type_t smallest = vtype::type_max(); type_t biggest = vtype::type_min(); - int64_t pivot_index = partition_avx512( + int64_t pivot_index = partition_avx512_unrolled( arr, left, right + 1, pivot, &smallest, &biggest); if ((pivot != smallest) && (pos < pivot_index)) qselect_64bit_(arr, pos, left, pivot_index - 1, max_iters - 1); diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index 5b6591f0..0e0ad818 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -95,7 +95,8 @@ void avx512_qselect(T *arr, int64_t k, int64_t arrsize); void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize); template -inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) { +inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize) +{ avx512_qselect(arr, k - 1, arrsize); avx512_qsort(arr, k - 1); } @@ -259,4 +260,123 @@ static inline int64_t partition_avx512(type_t *arr, *biggest = vtype::reducemax(max_vec); return l_store; } + +template +static inline int64_t partition_avx512_unrolled(type_t *arr, + int64_t left, + int64_t right, + type_t pivot, + type_t *smallest, + type_t *biggest) +{ + if (right - left <= 2 * num_unroll * vtype::numlanes) { + return partition_avx512( + arr, left, right, pivot, smallest, biggest); + } + /* make array length divisible by 8*vtype::numlanes , shortening the array */ + for (int32_t i = ((right - left) % (num_unroll * vtype::numlanes)); i > 0; + --i) { + *smallest = std::min(*smallest, arr[left], comparison_func); + *biggest = std::max(*biggest, arr[left], comparison_func); + if (!comparison_func(arr[left], pivot)) { + std::swap(arr[left], arr[--right]); + } + else { + ++left; + } + } + + if (left == right) + return left; /* less than vtype::numlanes elements in the array */ + + using zmm_t = typename vtype::zmm_t; + zmm_t pivot_vec = vtype::set1(pivot); + zmm_t min_vec = vtype::set1(*smallest); + zmm_t max_vec = vtype::set1(*biggest); + + // We will now have atleast 16 registers worth of data to process: + // left and right vtype::numlanes values are partitioned at the end + zmm_t vec_left[num_unroll], vec_right[num_unroll]; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + vec_left[ii] = vtype::loadu(arr + left + vtype::numlanes * ii); + vec_right[ii] = vtype::loadu( + arr + (right - vtype::numlanes * (num_unroll - ii))); + } + // store points of the vectors + int64_t r_store = right - vtype::numlanes; + int64_t l_store = left; + // indices for loading the elements + left += num_unroll * vtype::numlanes; + right -= num_unroll * vtype::numlanes; + while (right - left != 0) { + zmm_t curr_vec[num_unroll]; + /* + * if fewer elements are stored on the right side of the array, + * then next elements are loaded from the right side, + * otherwise from the left side + */ + if ((r_store + vtype::numlanes) - right < left - l_store) { + right -= num_unroll * vtype::numlanes; +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + right + ii * vtype::numlanes); + } + } + else { +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + curr_vec[ii] = vtype::loadu(arr + left + ii * vtype::numlanes); + } + left += num_unroll * vtype::numlanes; + } +// partition the current vector and save it on both sides of the array +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + curr_vec[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } + } + +/* partition and save vec_left[8] and vec_right[8] */ +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + vec_left[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } +#pragma GCC unroll 8 + for (int ii = 0; ii < num_unroll; ++ii) { + int32_t amount_ge_pivot + = partition_vec(arr, + l_store, + r_store + vtype::numlanes, + vec_right[ii], + pivot_vec, + &min_vec, + &max_vec); + l_store += (vtype::numlanes - amount_ge_pivot); + r_store -= amount_ge_pivot; + } + *smallest = vtype::reducemin(min_vec); + *biggest = vtype::reducemax(max_vec); + return l_store; +} #endif // AVX512_QSORT_COMMON diff --git a/tests/test_keyvalue.cpp b/tests/test_keyvalue.cpp index 0cb1ca6f..b9cca554 100644 --- a/tests/test_keyvalue.cpp +++ b/tests/test_keyvalue.cpp @@ -4,8 +4,8 @@ * *******************************************/ #include "avx512-64bit-keyvaluesort.hpp" -#include "rand_array.h" #include "cpuinfo.h" +#include "rand_array.h" #include #include diff --git a/tests/test_partial_qsort.hpp b/tests/test_partial_qsort.hpp index 5c08064e..4ba5caa8 100644 --- a/tests/test_partial_qsort.hpp +++ b/tests/test_partial_qsort.hpp @@ -30,7 +30,8 @@ TYPED_TEST_P(avx512_partial_sort, test_ranges) int k = get_uniform_rand_array(1, arrsize, 1).front(); /* Sort the range and verify all the required elements match the presorted set */ - avx512_partial_qsort(psortedarr.data(), k, psortedarr.size()); + avx512_partial_qsort( + psortedarr.data(), k, psortedarr.size()); for (size_t jj = 0; jj < k; jj++) { ASSERT_EQ(sortedarr[jj], psortedarr[jj]); } diff --git a/tests/test_qselect.hpp b/tests/test_qselect.hpp index cad017bb..f0c0c242 100644 --- a/tests/test_qselect.hpp +++ b/tests/test_qselect.hpp @@ -5,7 +5,7 @@ class avx512_select : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512_select); -TYPED_TEST_P(avx512_select, test_arrsizes) +TYPED_TEST_P(avx512_select, test_random) { if (cpu_has_avx512bw()) { if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { @@ -26,7 +26,8 @@ TYPED_TEST_P(avx512_select, test_arrsizes) std::sort(sortedarr.begin(), sortedarr.end()); for (size_t k = 0; k < arr.size(); ++k) { psortedarr = arr; - avx512_qselect(psortedarr.data(), k, psortedarr.size()); + avx512_qselect( + psortedarr.data(), k, psortedarr.size()); /* index k is correct */ ASSERT_EQ(sortedarr[k], psortedarr[k]); /* Check left partition */ @@ -34,7 +35,7 @@ TYPED_TEST_P(avx512_select, test_arrsizes) ASSERT_LE(psortedarr[jj], psortedarr[k]); } /* Check right partition */ - for (size_t jj = k+1; jj < arr.size(); jj++) { + for (size_t jj = k + 1; jj < arr.size(); jj++) { ASSERT_GE(psortedarr[jj], psortedarr[k]); } psortedarr.clear(); @@ -48,4 +49,48 @@ TYPED_TEST_P(avx512_select, test_arrsizes) } } -REGISTER_TYPED_TEST_SUITE_P(avx512_select, test_arrsizes); +TYPED_TEST_P(avx512_select, test_small_range) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back(ii); + } + std::vector arr; + std::vector sortedarr; + std::vector psortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* Random array */ + arr = get_uniform_rand_array(arrsizes[ii], 20, 1); + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + for (size_t k = 0; k < arr.size(); ++k) { + psortedarr = arr; + avx512_qselect( + psortedarr.data(), k, psortedarr.size()); + /* index k is correct */ + ASSERT_EQ(sortedarr[k], psortedarr[k]); + /* Check left partition */ + for (size_t jj = 0; jj < k; jj++) { + ASSERT_LE(psortedarr[jj], psortedarr[k]); + } + /* Check right partition */ + for (size_t jj = k + 1; jj < arr.size(); jj++) { + ASSERT_GE(psortedarr[jj], psortedarr[k]); + } + psortedarr.clear(); + } + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +REGISTER_TYPED_TEST_SUITE_P(avx512_select, test_random, test_small_range); diff --git a/tests/test_qsort.hpp b/tests/test_qsort.hpp index 65a8eaf6..4dc8a773 100644 --- a/tests/test_qsort.hpp +++ b/tests/test_qsort.hpp @@ -10,7 +10,7 @@ class avx512_sort : public ::testing::Test { }; TYPED_TEST_SUITE_P(avx512_sort); -TYPED_TEST_P(avx512_sort, test_arrsizes) +TYPED_TEST_P(avx512_sort, test_random) { if (cpu_has_avx512bw()) { if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { @@ -29,7 +29,7 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) /* Sort with std::sort for comparison */ std::sort(sortedarr.begin(), sortedarr.end()); avx512_qsort(arr.data(), arr.size()); - ASSERT_EQ(sortedarr, arr); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; arr.clear(); sortedarr.clear(); } @@ -39,4 +39,97 @@ TYPED_TEST_P(avx512_sort, test_arrsizes) } } -REGISTER_TYPED_TEST_SUITE_P(avx512_sort, test_arrsizes); +TYPED_TEST_P(avx512_sort, test_reverse) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam)(ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* reverse array */ + for (int jj = 0; jj < arrsizes[ii]; ++jj) { + arr.push_back((TypeParam)(arrsizes[ii] - jj)); + } + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +TYPED_TEST_P(avx512_sort, test_constant) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam)(ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + /* constant array */ + for (int jj = 0; jj < arrsizes[ii]; ++jj) { + arr.push_back(ii); + } + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} + +TYPED_TEST_P(avx512_sort, test_small_range) +{ + if (cpu_has_avx512bw()) { + if ((sizeof(TypeParam) == 2) && (!cpu_has_avx512_vbmi2())) { + GTEST_SKIP() << "Skipping this test, it requires avx512_vbmi2"; + } + std::vector arrsizes; + for (int64_t ii = 0; ii < 1024; ++ii) { + arrsizes.push_back((TypeParam)(ii + 1)); + } + std::vector arr; + std::vector sortedarr; + for (size_t ii = 0; ii < arrsizes.size(); ++ii) { + arr = get_uniform_rand_array(arrsizes[ii], 20, 1); + sortedarr = arr; + /* Sort with std::sort for comparison */ + std::sort(sortedarr.begin(), sortedarr.end()); + avx512_qsort(arr.data(), arr.size()); + ASSERT_EQ(sortedarr, arr) << "Array size = " << arrsizes[ii]; + arr.clear(); + sortedarr.clear(); + } + } + else { + GTEST_SKIP() << "Skipping this test, it requires avx512bw"; + } +} +REGISTER_TYPED_TEST_SUITE_P(avx512_sort, + test_random, + test_reverse, + test_constant, + test_small_range); diff --git a/tests/test_qsortfp16.cpp b/tests/test_qsortfp16.cpp index f86d77df..d6a45f7b 100644 --- a/tests/test_qsortfp16.cpp +++ b/tests/test_qsortfp16.cpp @@ -95,7 +95,8 @@ TEST(avx512_qselect_float16, test_arrsizes) std::sort(sortedarr.begin(), sortedarr.end()); for (size_t k = 0; k < arr.size(); ++k) { psortedarr = arr; - avx512_qselect<_Float16>(psortedarr.data(), k, psortedarr.size()); + avx512_qselect<_Float16>( + psortedarr.data(), k, psortedarr.size()); /* index k is correct */ ASSERT_EQ(sortedarr[k], psortedarr[k]); /* Check left partition */ @@ -103,7 +104,7 @@ TEST(avx512_qselect_float16, test_arrsizes) ASSERT_LE(psortedarr[jj], psortedarr[k]); } /* Check right partition */ - for (size_t jj = k+1; jj < arr.size(); jj++) { + for (size_t jj = k + 1; jj < arr.size(); jj++) { ASSERT_GE(psortedarr[jj], psortedarr[k]); } psortedarr.clear(); @@ -142,7 +143,8 @@ TEST(avx512_partial_qsort_float16, test_ranges) int k = get_uniform_rand_array(1, arrsize, 1).front(); /* Sort the range and verify all the required elements match the presorted set */ - avx512_partial_qsort<_Float16>(psortedarr.data(), k, psortedarr.size()); + avx512_partial_qsort<_Float16>( + psortedarr.data(), k, psortedarr.size()); for (size_t jj = 0; jj < k; jj++) { ASSERT_EQ(sortedarr[jj], psortedarr[jj]); } diff --git a/tests/test_sort.cpp b/tests/test_sort.cpp index 85a6bd8d..92ffbc35 100644 --- a/tests/test_sort.cpp +++ b/tests/test_sort.cpp @@ -1,6 +1,6 @@ -#include "test_qsort.hpp" -#include "test_qselect.hpp" #include "test_partial_qsort.hpp" +#include "test_qselect.hpp" +#include "test_qsort.hpp" using QuickSortTestTypes = testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_sort, QuickSortTestTypes); INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_select, QuickSortTestTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, avx512_partial_sort, QuickSortTestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(TestPrefix, + avx512_partial_sort, + QuickSortTestTypes);