Skip to content

Commit da8ce10

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #108 from r-devulap/kvsort-32bit
Support key-value sort for 32-bit dtypes
2 parents 8187e9a + cb46165 commit da8ce10

File tree

8 files changed

+132
-57
lines changed

8 files changed

+132
-57
lines changed

README.md

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,33 @@ AVX2 specific implementations, please see
88
[README](https://github.com/intel/x86-simd-sort/blob/main/src/README.md) file under
99
`src/` directory. The following routines are currently supported:
1010

11+
12+
### Sort routines on arrays
1113
```cpp
1214
x86simdsort::qsort(T* arr, size_t size, bool hasnan);
1315
x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
1416
x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
17+
```
18+
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
19+
int32_t, double, uint64_t, int64_t]`
20+
21+
### Key-value sort routines on pairs of arrays
22+
```cpp
23+
x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan);
24+
```
25+
Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double,
26+
uint64_t, int64_t]` Note that keyvalue sort is not yet supported for 16-bit
27+
data types.
28+
29+
### Arg sort routines on arrays
30+
```cpp
1531
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
1632
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
1733
```
34+
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
35+
int32_t, double, uint64_t, int64_t]`
1836

19-
### Build/Install
37+
## Build/Install
2038

2139
[meson](https://github.com/mesonbuild/meson) is the used build system. Command
2240
to build and install the library:
@@ -35,7 +53,7 @@ benchmark](https://github.com/google/benchmark) frameworks respectively. You
3553
can configure meson to build them both by using `-Dbuild_tests=true` and
3654
`-Dbuild_benchmarks=true`.
3755

38-
### Example usage
56+
## Example usage
3957

4058
```cpp
4159
#include "x86simdsort.h"
@@ -48,7 +66,7 @@ int main() {
4866
```
4967

5068

51-
### Details
69+
## Details
5270

5371
- `x86simdsort::qsort` is equivalent to `qsort` in
5472
[C](https://www.tutorialspoint.com/c_standard_library/c_function_qsort.htm)
@@ -77,7 +95,7 @@ argselect) will not use the SIMD based algorithms if they detect NAN's in the
7795
array. You can read details of all the implementations
7896
[here](https://github.com/intel/x86-simd-sort/src/README.md).
7997

80-
### Downstream projects using x86-simd-sort
98+
## Downstream projects using x86-simd-sort
8199

82100
- NumPy uses this as a [submodule](https://github.com/numpy/numpy/pull/22315) to accelerate `np.sort, np.argsort, np.partition and np.argpartition`.
83101
- A slightly modifed version this library has been integrated into [openJDK](https://github.com/openjdk/jdk/pull/14227).

benchmarks/bench-keyvalue.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ static void simdkvsort(benchmark::State &state, Args &&...args)
4646
BENCH_BOTH_KVSORT(uint64_t)
4747
BENCH_BOTH_KVSORT(int64_t)
4848
BENCH_BOTH_KVSORT(double)
49+
BENCH_BOTH_KVSORT(uint32_t)
50+
BENCH_BOTH_KVSORT(int32_t)
51+
BENCH_BOTH_KVSORT(float)

examples/avx512-kv.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ int main() {
55
int64_t arr1[size];
66
uint64_t arr2[size];
77
double arr3[size];
8+
float arr4[size];
89
avx512_qsort_kv(arr1, arr1, size);
910
avx512_qsort_kv(arr1, arr2, size);
1011
avx512_qsort_kv(arr1, arr3, size);
@@ -13,6 +14,9 @@ int main() {
1314
avx512_qsort_kv(arr2, arr3, size);
1415
avx512_qsort_kv(arr3, arr1, size);
1516
avx512_qsort_kv(arr3, arr2, size);
16-
avx512_qsort_kv(arr3, arr3, size);
17+
avx512_qsort_kv(arr1, arr4, size);
18+
avx512_qsort_kv(arr2, arr4, size);
19+
avx512_qsort_kv(arr3, arr4, size);
20+
return 0;
1721
return 0;
1822
}

lib/x86simdsort-skx.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,34 @@
3333
return avx512_argselect(arr, k, arrsize, hasnan); \
3434
}
3535

36-
#define DEFINE_KEYVALUE_METHODS(type1, type2) \
36+
#define DEFINE_KEYVALUE_METHODS(type) \
3737
template <> \
38-
void keyvalue_qsort(type1 *key, type2* val, size_t arrsize, bool hasnan) \
38+
void keyvalue_qsort(type *key, uint64_t* val, size_t arrsize, bool hasnan) \
39+
{ \
40+
avx512_qsort_kv(key, val, arrsize, hasnan); \
41+
} \
42+
template <> \
43+
void keyvalue_qsort(type *key, int64_t* val, size_t arrsize, bool hasnan) \
44+
{ \
45+
avx512_qsort_kv(key, val, arrsize, hasnan); \
46+
} \
47+
template <> \
48+
void keyvalue_qsort(type *key, double* val, size_t arrsize, bool hasnan) \
49+
{ \
50+
avx512_qsort_kv(key, val, arrsize, hasnan); \
51+
} \
52+
template <> \
53+
void keyvalue_qsort(type *key, uint32_t* val, size_t arrsize, bool hasnan) \
54+
{ \
55+
avx512_qsort_kv(key, val, arrsize, hasnan); \
56+
} \
57+
template <> \
58+
void keyvalue_qsort(type *key, int32_t* val, size_t arrsize, bool hasnan) \
59+
{ \
60+
avx512_qsort_kv(key, val, arrsize, hasnan); \
61+
} \
62+
template <> \
63+
void keyvalue_qsort(type *key, float* val, size_t arrsize, bool hasnan) \
3964
{ \
4065
avx512_qsort_kv(key, val, arrsize, hasnan); \
4166
} \
@@ -49,14 +74,11 @@ namespace avx512 {
4974
DEFINE_ALL_METHODS(uint64_t)
5075
DEFINE_ALL_METHODS(int64_t)
5176
DEFINE_ALL_METHODS(double)
52-
DEFINE_KEYVALUE_METHODS(double, uint64_t)
53-
DEFINE_KEYVALUE_METHODS(double, int64_t)
54-
DEFINE_KEYVALUE_METHODS(double, double)
55-
DEFINE_KEYVALUE_METHODS(uint64_t, uint64_t)
56-
DEFINE_KEYVALUE_METHODS(uint64_t, int64_t)
57-
DEFINE_KEYVALUE_METHODS(uint64_t, double)
58-
DEFINE_KEYVALUE_METHODS(int64_t, uint64_t)
59-
DEFINE_KEYVALUE_METHODS(int64_t, int64_t)
60-
DEFINE_KEYVALUE_METHODS(int64_t, double)
77+
DEFINE_KEYVALUE_METHODS(uint64_t)
78+
DEFINE_KEYVALUE_METHODS(int64_t)
79+
DEFINE_KEYVALUE_METHODS(double)
80+
DEFINE_KEYVALUE_METHODS(uint32_t)
81+
DEFINE_KEYVALUE_METHODS(int32_t)
82+
DEFINE_KEYVALUE_METHODS(float)
6183
} // namespace avx512
6284
} // namespace xss

lib/x86simdsort.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,19 @@ DISPATCH_ALL(argselect,
196196
(ISA_LIST("avx512_skx")),
197197
(ISA_LIST("avx512_skx")))
198198

199-
DISPATCH_KEYVALUE_SORT(uint64_t, int64_t, (ISA_LIST("avx512_skx")))
200-
DISPATCH_KEYVALUE_SORT(uint64_t, uint64_t, (ISA_LIST("avx512_skx")))
201-
DISPATCH_KEYVALUE_SORT(uint64_t, double, (ISA_LIST("avx512_skx")))
202-
DISPATCH_KEYVALUE_SORT(int64_t, int64_t, (ISA_LIST("avx512_skx")))
203-
DISPATCH_KEYVALUE_SORT(int64_t, uint64_t, (ISA_LIST("avx512_skx")))
204-
DISPATCH_KEYVALUE_SORT(int64_t, double, (ISA_LIST("avx512_skx")))
205-
DISPATCH_KEYVALUE_SORT(double, int64_t, (ISA_LIST("avx512_skx")))
206-
DISPATCH_KEYVALUE_SORT(double, double, (ISA_LIST("avx512_skx")))
207-
DISPATCH_KEYVALUE_SORT(double, uint64_t, (ISA_LIST("avx512_skx")))
199+
#define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \
200+
DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx")))\
201+
DISPATCH_KEYVALUE_SORT(type, int64_t, (ISA_LIST("avx512_skx")))\
202+
DISPATCH_KEYVALUE_SORT(type, double, (ISA_LIST("avx512_skx")))\
203+
DISPATCH_KEYVALUE_SORT(type, uint32_t, (ISA_LIST("avx512_skx")))\
204+
DISPATCH_KEYVALUE_SORT(type, int32_t, (ISA_LIST("avx512_skx")))\
205+
DISPATCH_KEYVALUE_SORT(type, float, (ISA_LIST("avx512_skx")))\
206+
207+
DISPATCH_KEYVALUE_SORT_FORTYPE(uint64_t)
208+
DISPATCH_KEYVALUE_SORT_FORTYPE(int64_t)
209+
DISPATCH_KEYVALUE_SORT_FORTYPE(double)
210+
DISPATCH_KEYVALUE_SORT_FORTYPE(uint32_t)
211+
DISPATCH_KEYVALUE_SORT_FORTYPE(int32_t)
212+
DISPATCH_KEYVALUE_SORT_FORTYPE(float)
208213

209214
} // namespace x86simdsort

src/avx512-64bit-common.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ struct ymm_vector<float> {
186186
// return _mm256_shuffle_ps(zmm, zmm, mask);
187187
//}
188188
}
189+
static reg_t sort_vec(reg_t x)
190+
{
191+
return sort_zmm_64bit<ymm_vector<type_t>>(x);
192+
}
189193
static void storeu(void *mem, reg_t x)
190194
{
191195
_mm256_storeu_ps((float *)mem, x);
@@ -342,6 +346,10 @@ struct ymm_vector<uint32_t> {
342346
* 32-bit and 64-bit */
343347
return _mm256_shuffle_epi32(zmm, 0b10110001);
344348
}
349+
static reg_t sort_vec(reg_t x)
350+
{
351+
return sort_zmm_64bit<ymm_vector<type_t>>(x);
352+
}
345353
static void storeu(void *mem, reg_t x)
346354
{
347355
_mm256_storeu_si256((__m256i *)mem, x);
@@ -498,6 +506,10 @@ struct ymm_vector<int32_t> {
498506
* 32-bit and 64-bit */
499507
return _mm256_shuffle_epi32(zmm, 0b10110001);
500508
}
509+
static reg_t sort_vec(reg_t x)
510+
{
511+
return sort_zmm_64bit<ymm_vector<type_t>>(x);
512+
}
501513
static void storeu(void *mem, reg_t x)
502514
{
503515
_mm256_storeu_si256((__m256i *)mem, x);

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ template <typename vtype1,
558558
X86_SIMD_SORT_INLINE void
559559
heap_sort(type1_t *keys, type2_t *indexes, arrsize_t size)
560560
{
561-
for (arrsize_t i = size / 2 - 1; ; i--) {
561+
for (arrsize_t i = size / 2 - 1;; i--) {
562562
heapify<vtype1, vtype2>(keys, indexes, i, size);
563563
if (i == 0) { break; }
564564
}
@@ -617,26 +617,33 @@ template <typename T1, typename T2>
617617
X86_SIMD_SORT_INLINE void
618618
avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false)
619619
{
620-
UNUSED(hasnan);
620+
using keytype = typename std::conditional<sizeof(T1) == sizeof(int32_t),
621+
ymm_vector<T1>,
622+
zmm_vector<T1>>::type;
623+
using valtype = typename std::conditional<sizeof(T2) == sizeof(int32_t),
624+
ymm_vector<T2>,
625+
zmm_vector<T2>>::type;
621626
if (arrsize > 1) {
622627
if constexpr (std::is_floating_point_v<T1>) {
623-
arrsize_t nan_count
624-
= replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
625-
qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
626-
keys,
627-
indexes,
628-
0,
629-
arrsize - 1,
630-
2 * (arrsize_t)log2(arrsize));
628+
arrsize_t nan_count = 0;
629+
if (UNLIKELY(hasnan)) {
630+
nan_count = replace_nan_with_inf<zmm_vector<double>>(keys,
631+
arrsize);
632+
}
633+
qsort_64bit_<keytype, valtype>(keys,
634+
indexes,
635+
0,
636+
arrsize - 1,
637+
2 * (arrsize_t)log2(arrsize));
631638
replace_inf_with_nan(keys, arrsize, nan_count);
632639
}
633640
else {
634-
qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
635-
keys,
636-
indexes,
637-
0,
638-
arrsize - 1,
639-
2 * (arrsize_t)log2(arrsize));
641+
UNUSED(hasnan);
642+
qsort_64bit_<keytype, valtype>(keys,
643+
indexes,
644+
0,
645+
arrsize - 1,
646+
2 * (arrsize_t)log2(arrsize));
640647
}
641648
}
642649
}

tests/test-keyvalue.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,28 +40,32 @@ TYPED_TEST_P(simdkvsort, test_kvsort)
4040
std::vector<T1> key_bckp = key;
4141
std::vector<T2> val_bckp = val;
4242
x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan);
43-
xss::scalar::keyvalue_qsort(key_bckp.data(), val_bckp.data(), size, hasnan);
43+
xss::scalar::keyvalue_qsort(
44+
key_bckp.data(), val_bckp.data(), size, hasnan);
4445
ASSERT_EQ(key, key_bckp);
45-
const bool hasDuplicates = std::adjacent_find(key.begin(), key.end()) != key.end();
46-
if (!hasDuplicates) {
47-
ASSERT_EQ(val, val_bckp);
48-
}
49-
key.clear(); val.clear();
50-
key_bckp.clear(); val_bckp.clear();
46+
const bool hasDuplicates
47+
= std::adjacent_find(key.begin(), key.end()) != key.end();
48+
if (!hasDuplicates) { ASSERT_EQ(val, val_bckp); }
49+
key.clear();
50+
val.clear();
51+
key_bckp.clear();
52+
val_bckp.clear();
5153
}
5254
}
5355
}
5456

5557
REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort);
5658

57-
using QKVSortTestTypes = testing::Types<std::tuple<double, double>,
58-
std::tuple<double, uint64_t>,
59-
std::tuple<double, int64_t>,
60-
std::tuple<uint64_t, double>,
61-
std::tuple<uint64_t, uint64_t>,
62-
std::tuple<uint64_t, int64_t>,
63-
std::tuple<int64_t, double>,
64-
std::tuple<int64_t, uint64_t>,
65-
std::tuple<int64_t, int64_t>>;
59+
#define CREATE_TUPLES(type) \
60+
std::tuple<double, type>, std::tuple<uint64_t, type>, \
61+
std::tuple<int64_t, type>, std::tuple<float, type>, \
62+
std::tuple<uint32_t, type>, std::tuple<int32_t, type>
63+
64+
using QKVSortTestTypes = testing::Types<CREATE_TUPLES(double),
65+
CREATE_TUPLES(uint64_t),
66+
CREATE_TUPLES(int64_t),
67+
CREATE_TUPLES(uint32_t),
68+
CREATE_TUPLES(int32_t),
69+
CREATE_TUPLES(float)>;
6670

6771
INSTANTIATE_TYPED_TEST_SUITE_P(xss, simdkvsort, QKVSortTestTypes);

0 commit comments

Comments
 (0)