Skip to content

Commit cb4358f

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #105 from r-devulap/key-value
Add key-value sort to runtime dispatch
2 parents 15c3379 + aba8371 commit cb4358f

File tree

12 files changed

+244
-19
lines changed

12 files changed

+244
-19
lines changed

benchmarks/bench-all.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@
4444
#include "bench-partial-qsort.hpp"
4545
#include "bench-qselect.hpp"
4646
#include "bench-qsort.hpp"
47+
#include "bench-keyvalue.hpp"

benchmarks/bench-keyvalue.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "x86simdsort-scalar.h"
2+
3+
// Benchmark body for the scalar key-value sort (xss::scalar::keyvalue_qsort).
// Expects two extra arguments: the array size (std::get<0>) and an array-type
// string (std::get<1>) that is forwarded to get_array to build the keys.
// Values are always generated with the "random" array type.
template <typename T, class... Args>
4+
static void scalarkvsort(benchmark::State &state, Args &&...args)
5+
{
6+
// Unpack the benchmark arguments: (array size, key array type)
7+
auto args_tuple = std::make_tuple(std::move(args)...);
8+
size_t arrsize = std::get<0>(args_tuple);
9+
std::string arrtype = std::get<1>(args_tuple);
10+
// Build the key and value arrays, and keep a pristine copy of the keys
11+
std::vector<T> key = get_array<T>(arrtype, arrsize);
12+
std::vector<T> val = get_array<T>("random", arrsize);
13+
std::vector<T> key_bkp = key;
14+
// Timed region: only the sort call itself is measured
15+
for (auto _ : state) {
16+
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false);
17+
state.PauseTiming();
18+
// Restore the unsorted keys with the timer paused, so the next iteration
// does not sort already-sorted input. Note: val is not restored.
key = key_bkp;
19+
state.ResumeTiming();
20+
}
21+
}
22+
23+
// Benchmark body for the runtime-dispatched key-value sort
// (x86simdsort::keyvalue_qsort). Takes the same extra arguments as
// scalarkvsort: the array size (std::get<0>) and an array-type string
// (std::get<1>) forwarded to get_array to build the keys.
template <typename T, class... Args>
24+
static void simdkvsort(benchmark::State &state, Args &&...args)
25+
{
26+
auto args_tuple = std::make_tuple(std::move(args)...);
27+
size_t arrsize = std::get<0>(args_tuple);
28+
std::string arrtype = std::get<1>(args_tuple);
29+
// Build the key and value arrays, and keep a pristine copy of the keys
30+
std::vector<T> key = get_array<T>(arrtype, arrsize);
31+
std::vector<T> val = get_array<T>("random", arrsize);
32+
std::vector<T> key_bkp = key;
33+
// Timed region: only the sort call itself is measured
34+
for (auto _ : state) {
35+
x86simdsort::keyvalue_qsort(key.data(), val.data(), arrsize);
36+
state.PauseTiming();
37+
// Restore the unsorted keys with the timer paused, so the next iteration
// does not sort already-sorted input. Note: val is not restored.
key = key_bkp;
38+
state.ResumeTiming();
39+
}
40+
}
41+
42+
// Expands BENCH_SORT for both the SIMD (simdkvsort) and the scalar
// (scalarkvsort) key-value benchmark for the given element type.
// BENCH_SORT is defined elsewhere in the benchmark sources.
#define BENCH_BOTH_KVSORT(type) \
43+
BENCH_SORT(simdkvsort, type) \
44+
BENCH_SORT(scalarkvsort, type)
45+
46+
BENCH_BOTH_KVSORT(uint64_t)
47+
BENCH_BOTH_KVSORT(int64_t)
48+
BENCH_BOTH_KVSORT(double)

lib/x86simdsort-internal.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ namespace avx512 {
99
// quicksort
1010
template <typename T>
1111
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
12+
// key-value quicksort
13+
template <typename T1, typename T2>
14+
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
1215
// quickselect
1316
template <typename T>
1417
XSS_HIDE_SYMBOL void
@@ -30,6 +33,9 @@ namespace avx2 {
3033
// quicksort
3134
template <typename T>
3235
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
36+
// key-value quicksort
37+
template <typename T1, typename T2>
38+
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
3339
// quickselect
3440
template <typename T>
3541
XSS_HIDE_SYMBOL void
@@ -51,6 +57,9 @@ namespace scalar {
5157
// quicksort
5258
template <typename T>
5359
XSS_HIDE_SYMBOL void qsort(T *arr, size_t arrsize, bool hasnan = false);
60+
// key-value quicksort
61+
template <typename T1, typename T2>
62+
XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
5463
// quickselect
5564
template <typename T>
5665
XSS_HIDE_SYMBOL void

lib/x86simdsort-scalar.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@
33
#include <numeric>
44

55
namespace xss {
6+
namespace utils {
7+
/* Permute an array in place (no scratch copy of the array; the index
8+
 * vector is taken by value because it is overwritten). Adapted from
 * http://www.davidespataro.it/apply-a-permutation-to-a-vector */
9+
// Rearranges arr so that, on return, arr[i] holds the element that was
// originally at arr[arg[i]], for every i in [0, arg.size()).
//
// arr: array to permute; must have at least arg.size() elements.
// arg: index vector; assumed to be a permutation of 0..arg.size()-1
//      (NOTE(review): not validated here -- confirm callers guarantee it).
//      It is passed by value: the loop overwrites this local copy, so the
//      caller's vector is left untouched.
template<typename T>
10+
void apply_permutation_in_place(T* arr, std::vector<size_t> arg)
11+
{
12+
// Walk each cycle of the permutation starting at index i.
for(size_t i = 0 ; i < arg.size() ; i++) {
13+
size_t curr = i;
14+
size_t next = arg[curr];
15+
// Follow the cycle until it returns to its starting index i.
while(next != i)
16+
{
17+
std::swap(arr[curr], arr[next]);
18+
// Mark curr as done (arg[curr] == curr) so later outer iterations skip it.
arg[curr] = curr;
19+
curr = next;
20+
next = arg[next];
21+
}
22+
// The last index of the cycle is now in place as well.
arg[curr] = curr;
23+
}
24+
}
25+
} // utils
26+
627
namespace scalar {
728
template <typename T>
829
void qsort(T *arr, size_t arrsize, bool hasnan)
@@ -57,6 +78,13 @@ namespace scalar {
5778
compare_arg<T, std::less<T>>(arr));
5879
return arg;
5980
}
81+
// Scalar key-value sort: reorders key and val together using the index
// vector returned by argsort(key, arrsize, hasnan).
//
// key:     array of keys; permuted in place.
// val:     array of values; permuted in place with the same index vector.
// arrsize: element count passed to argsort.
// hasnan:  forwarded unchanged to argsort.
template <typename T1, typename T2>
82+
void keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan)
83+
{
84+
std::vector<size_t> arg = argsort(key, arrsize, hasnan);
85+
// apply_permutation_in_place takes arg by value, so the same index
// vector can be applied to both arrays.
utils::apply_permutation_in_place(key, arg);
86+
utils::apply_permutation_in_place(val, arg);
87+
}
6088

6189
} // namespace scalar
6290
} // namespace xss

lib/x86simdsort-skx.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// SKX specific routines:
22
#include "avx512-32bit-qsort.hpp"
3+
#include "avx512-64bit-keyvaluesort.hpp"
34
#include "avx512-64bit-argsort.hpp"
45
#include "avx512-64bit-qsort.hpp"
56
#include "x86simdsort-internal.h"
@@ -32,6 +33,14 @@
3233
return avx512_argselect(arr, k, arrsize, hasnan); \
3334
}
3435

36+
// Defines the explicit specialization of keyvalue_qsort for the given
// (key type, value type) pair; the specialization forwards all of its
// arguments to avx512_qsort_kv.
#define DEFINE_KEYVALUE_METHODS(type1, type2) \
37+
template <> \
38+
void keyvalue_qsort(type1 *key, type2* val, size_t arrsize, bool hasnan) \
39+
{ \
40+
avx512_qsort_kv(key, val, arrsize, hasnan); \
41+
} \
42+
43+
3544
namespace xss {
3645
namespace avx512 {
3746
DEFINE_ALL_METHODS(uint32_t)
@@ -40,5 +49,14 @@ namespace avx512 {
4049
DEFINE_ALL_METHODS(uint64_t)
4150
DEFINE_ALL_METHODS(int64_t)
4251
DEFINE_ALL_METHODS(double)
52+
DEFINE_KEYVALUE_METHODS(double, uint64_t)
53+
DEFINE_KEYVALUE_METHODS(double, int64_t)
54+
DEFINE_KEYVALUE_METHODS(double, double)
55+
DEFINE_KEYVALUE_METHODS(uint64_t, uint64_t)
56+
DEFINE_KEYVALUE_METHODS(uint64_t, int64_t)
57+
DEFINE_KEYVALUE_METHODS(uint64_t, double)
58+
DEFINE_KEYVALUE_METHODS(int64_t, uint64_t)
59+
DEFINE_KEYVALUE_METHODS(int64_t, int64_t)
60+
DEFINE_KEYVALUE_METHODS(int64_t, double)
4361
} // namespace avx512
4462
} // namespace xss

lib/x86simdsort.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ dispatch_requested(std::string_view cpurequested,
5151
return false;
5252
}
5353

54+
namespace x86simdsort {
55+
5456
#define CAT_(a, b) a##b
5557
#define CAT(a, b) CAT_(a, b)
5658

@@ -120,6 +122,33 @@ dispatch_requested(std::string_view cpurequested,
120122
return; \
121123
} \
122124
} \
125+
} \
126+
127+
// Runtime dispatch for keyvalue_qsort<TYPE1, TYPE2>. The expansion emits:
//  1. a static function pointer internal_kv_qsort_<TYPE1><TYPE2>,
//     initialized to NULL;
//  2. the explicit specialization of keyvalue_qsort, which calls through
//     that pointer;
//  3. a resolver function marked __attribute__((constructor)) that first
//     points the pointer at xss::scalar::keyvalue_qsort, then calls
//     __builtin_cpu_init() and find_preferred_cpu(ISA). If "avx512" dispatch
//     was requested in ISA and the preferred CPU string contains "avx512",
//     it selects xss::avx512::keyvalue_qsort and returns; otherwise the same
//     check is made for "avx2" and xss::avx2::keyvalue_qsort. If neither
//     matches, the scalar implementation stays selected.
// NOTE(review): the pointer is NULL until the resolver has run, so calls
// made before that point would dereference a null function pointer.
#define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \
128+
static void (CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))(TYPE1*, TYPE2*, size_t, bool) = NULL; \
129+
template <> \
130+
void keyvalue_qsort(TYPE1 *key, TYPE2* val, size_t arrsize, bool hasnan) \
131+
{ \
132+
(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))(key, val, arrsize, hasnan); \
133+
} \
134+
static __attribute__((constructor)) void \
135+
CAT(CAT(resolve_keyvalue_qsort_, TYPE1), TYPE2)(void) \
136+
{ \
137+
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::scalar::keyvalue_qsort<TYPE1, TYPE2>; \
138+
__builtin_cpu_init(); \
139+
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
140+
if constexpr (dispatch_requested("avx512", ISA)) { \
141+
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
142+
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::avx512::keyvalue_qsort<TYPE1, TYPE2>; \
143+
return; \
144+
} \
145+
} \
146+
if constexpr (dispatch_requested("avx2", ISA)) { \
147+
if (preferred_cpu.find("avx2") != std::string_view::npos) { \
148+
CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) = &xss::avx2::keyvalue_qsort<TYPE1, TYPE2>; \
149+
return; \
150+
} \
151+
} \
123152
}
124153

125154
#define ISA_LIST(...) \
@@ -128,7 +157,6 @@ dispatch_requested(std::string_view cpurequested,
128157
__VA_ARGS__ \
129158
}
130159

131-
namespace x86simdsort {
132160
#ifdef __FLT16_MAX__
133161
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr"))
134162
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr"))
@@ -168,4 +196,14 @@ DISPATCH_ALL(argselect,
168196
(ISA_LIST("avx512_skx")),
169197
(ISA_LIST("avx512_skx")))
170198

199+
DISPATCH_KEYVALUE_SORT(uint64_t, int64_t, (ISA_LIST("avx512_skx")))
200+
DISPATCH_KEYVALUE_SORT(uint64_t, uint64_t, (ISA_LIST("avx512_skx")))
201+
DISPATCH_KEYVALUE_SORT(uint64_t, double, (ISA_LIST("avx512_skx")))
202+
DISPATCH_KEYVALUE_SORT(int64_t, int64_t, (ISA_LIST("avx512_skx")))
203+
DISPATCH_KEYVALUE_SORT(int64_t, uint64_t, (ISA_LIST("avx512_skx")))
204+
DISPATCH_KEYVALUE_SORT(int64_t, double, (ISA_LIST("avx512_skx")))
205+
DISPATCH_KEYVALUE_SORT(double, int64_t, (ISA_LIST("avx512_skx")))
206+
DISPATCH_KEYVALUE_SORT(double, double, (ISA_LIST("avx512_skx")))
207+
DISPATCH_KEYVALUE_SORT(double, uint64_t, (ISA_LIST("avx512_skx")))
208+
171209
} // namespace x86simdsort

lib/x86simdsort.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,10 @@ template <typename T>
3434
XSS_EXPORT_SYMBOL std::vector<size_t>
3535
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
3636

37+
// key-value quicksort
38+
template <typename T1, typename T2>
39+
XSS_EXPORT_SYMBOL void
40+
keyvalue_qsort(T1 *key, T2* val, size_t arrsize, bool hasnan = false);
41+
3742
} // namespace x86simdsort
3843
#endif

run-bench.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
elif "argsort" in args.benchcompare:
3232
baseline = "scalarargsort.*" + filterb
3333
contender = "simdargsort.*" + filterb
34+
elif "keyvalue" in args.benchcompare:
35+
baseline = "scalarkvsort.*" + filterb
36+
contender = "simdkvsort.*" + filterb
3437
else:
3538
parser.print_help(sys.stderr)
3639
parser.error("ERROR: Unknown argument '%s'" % args.benchcompare)

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ heapify(type1_t *keys, type2_t *indexes, arrsize_t idx, arrsize_t size)
542542
arrsize_t i = idx;
543543
while (true) {
544544
arrsize_t j = 2 * i + 1;
545-
if (j >= size || j < 0) { break; }
546-
int k = j + 1;
545+
if (j >= size) { break; }
546+
arrsize_t k = j + 1;
547547
if (k < size && keys[j] < keys[k]) { j = k; }
548548
if (keys[j] < keys[i]) { break; }
549549
std::swap(keys[i], keys[j]);
@@ -558,8 +558,9 @@ template <typename vtype1,
558558
X86_SIMD_SORT_INLINE void
559559
heap_sort(type1_t *keys, type2_t *indexes, arrsize_t size)
560560
{
561-
for (arrsize_t i = size / 2 - 1; i >= 0; i--) {
561+
for (arrsize_t i = size / 2 - 1; ; i--) {
562562
heapify<vtype1, vtype2>(keys, indexes, i, size);
563+
if (i == 0) { break; }
563564
}
564565
for (arrsize_t i = size - 1; i > 0; i--) {
565566
std::swap(keys[0], keys[i]);
@@ -614,8 +615,9 @@ X86_SIMD_SORT_INLINE void qsort_64bit_(type1_t *keys,
614615

615616
template <typename T1, typename T2>
616617
X86_SIMD_SORT_INLINE void
617-
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize)
618+
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
618619
{
620+
UNUSED(hasnan);
619621
if (arrsize > 1) {
620622
if constexpr (std::is_floating_point_v<T1>) {
621623
arrsize_t nan_count

tests/meson.build

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ libtests += static_library('tests_qsort',
66
include_directories : [lib, utils],
77
)
88

9+
libtests += static_library('tests_kvsort',
10+
files('test-keyvalue.cpp', ),
11+
dependencies: gtest_dep,
12+
include_directories : [lib, utils],
13+
)
14+
915
#if cancompilefp16
1016
# libtests += static_library('tests_qsortfp16',
1117
# files('test-qsortfp16.cpp', ),

0 commit comments

Comments
 (0)