88#define AVX512_16BIT_COMMON
99
1010#include " avx512-common-qsort.h"
11+ #include " xss-network-qsort.hpp"
1112
1213/*
1314 * Constants used in sorting 32 elements in a ZMM register. Based on Bitonic
@@ -33,8 +34,8 @@ static const uint16_t network[6][32]
3334 * Assumes zmm is random and performs a full sorting network defined in
3435 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
3536 */
36- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
37- X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit (zmm_t zmm)
37+ template <typename vtype, typename reg_t = typename vtype::reg_t >
38+ X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit (reg_t zmm)
3839{
3940 // Level 1
4041 zmm = cmp_merge<vtype>(
@@ -93,8 +94,8 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_16bit(zmm_t zmm)
9394}
9495
9596// Assumes zmm is bitonic and performs a recursive half cleaner
96- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
97- X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit (zmm_t zmm)
97+ template <typename vtype, typename reg_t = typename vtype::reg_t >
98+ X86_SIMD_SORT_INLINE reg_t bitonic_merge_zmm_16bit (reg_t zmm)
9899{
99100 // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
100101 zmm = cmp_merge<vtype>(
@@ -118,208 +119,4 @@ X86_SIMD_SORT_INLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
118119 return zmm;
119120}
120121
121- // Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
122- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
123- X86_SIMD_SORT_INLINE void bitonic_merge_two_zmm_16bit (zmm_t &zmm1, zmm_t &zmm2)
124- {
125- // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
126- zmm2 = vtype::permutexvar (vtype::get_network (4 ), zmm2);
127- zmm_t zmm3 = vtype::min (zmm1, zmm2);
128- zmm_t zmm4 = vtype::max (zmm1, zmm2);
129- // 2) Recursive half cleaner for each
130- zmm1 = bitonic_merge_zmm_16bit<vtype>(zmm3);
131- zmm2 = bitonic_merge_zmm_16bit<vtype>(zmm4);
132- }
133-
134- // Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
135- // half cleaner
136- template <typename vtype, typename zmm_t = typename vtype::zmm_t >
137- X86_SIMD_SORT_INLINE void bitonic_merge_four_zmm_16bit (zmm_t *zmm)
138- {
139- zmm_t zmm2r = vtype::permutexvar (vtype::get_network (4 ), zmm[2 ]);
140- zmm_t zmm3r = vtype::permutexvar (vtype::get_network (4 ), zmm[3 ]);
141- zmm_t zmm_t1 = vtype::min (zmm[0 ], zmm3r);
142- zmm_t zmm_t2 = vtype::min (zmm[1 ], zmm2r);
143- zmm_t zmm_t3 = vtype::permutexvar (vtype::get_network (4 ),
144- vtype::max (zmm[1 ], zmm2r));
145- zmm_t zmm_t4 = vtype::permutexvar (vtype::get_network (4 ),
146- vtype::max (zmm[0 ], zmm3r));
147- zmm_t zmm0 = vtype::min (zmm_t1, zmm_t2);
148- zmm_t zmm1 = vtype::max (zmm_t1, zmm_t2);
149- zmm_t zmm2 = vtype::min (zmm_t3, zmm_t4);
150- zmm_t zmm3 = vtype::max (zmm_t3, zmm_t4);
151- zmm[0 ] = bitonic_merge_zmm_16bit<vtype>(zmm0);
152- zmm[1 ] = bitonic_merge_zmm_16bit<vtype>(zmm1);
153- zmm[2 ] = bitonic_merge_zmm_16bit<vtype>(zmm2);
154- zmm[3 ] = bitonic_merge_zmm_16bit<vtype>(zmm3);
155- }
156-
157- template <typename vtype, typename type_t >
158- X86_SIMD_SORT_INLINE void sort_32_16bit (type_t *arr, int32_t N)
159- {
160- typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull ) & 0xFFFFFFFF ;
161- typename vtype::zmm_t zmm
162- = vtype::mask_loadu (vtype::zmm_max (), load_mask, arr);
163- vtype::mask_storeu (arr, load_mask, sort_zmm_16bit<vtype>(zmm));
164- }
165-
166- template <typename vtype, typename type_t >
167- X86_SIMD_SORT_INLINE void sort_64_16bit (type_t *arr, int32_t N)
168- {
169- if (N <= 32 ) {
170- sort_32_16bit<vtype>(arr, N);
171- return ;
172- }
173- using zmm_t = typename vtype::zmm_t ;
174- typename vtype::opmask_t load_mask
175- = ((0x1ull << (N - 32 )) - 0x1ull ) & 0xFFFFFFFF ;
176- zmm_t zmm1 = vtype::loadu (arr);
177- zmm_t zmm2 = vtype::mask_loadu (vtype::zmm_max (), load_mask, arr + 32 );
178- zmm1 = sort_zmm_16bit<vtype>(zmm1);
179- zmm2 = sort_zmm_16bit<vtype>(zmm2);
180- bitonic_merge_two_zmm_16bit<vtype>(zmm1, zmm2);
181- vtype::storeu (arr, zmm1);
182- vtype::mask_storeu (arr + 32 , load_mask, zmm2);
183- }
184-
185- template <typename vtype, typename type_t >
186- X86_SIMD_SORT_INLINE void sort_128_16bit (type_t *arr, int32_t N)
187- {
188- if (N <= 64 ) {
189- sort_64_16bit<vtype>(arr, N);
190- return ;
191- }
192- using zmm_t = typename vtype::zmm_t ;
193- using opmask_t = typename vtype::opmask_t ;
194- zmm_t zmm[4 ];
195- zmm[0 ] = vtype::loadu (arr);
196- zmm[1 ] = vtype::loadu (arr + 32 );
197- opmask_t load_mask1 = 0xFFFFFFFF , load_mask2 = 0xFFFFFFFF ;
198- if (N != 128 ) {
199- uint64_t combined_mask = (0x1ull << (N - 64 )) - 0x1ull ;
200- load_mask1 = combined_mask & 0xFFFFFFFF ;
201- load_mask2 = (combined_mask >> 32 ) & 0xFFFFFFFF ;
202- }
203- zmm[2 ] = vtype::mask_loadu (vtype::zmm_max (), load_mask1, arr + 64 );
204- zmm[3 ] = vtype::mask_loadu (vtype::zmm_max (), load_mask2, arr + 96 );
205- zmm[0 ] = sort_zmm_16bit<vtype>(zmm[0 ]);
206- zmm[1 ] = sort_zmm_16bit<vtype>(zmm[1 ]);
207- zmm[2 ] = sort_zmm_16bit<vtype>(zmm[2 ]);
208- zmm[3 ] = sort_zmm_16bit<vtype>(zmm[3 ]);
209- bitonic_merge_two_zmm_16bit<vtype>(zmm[0 ], zmm[1 ]);
210- bitonic_merge_two_zmm_16bit<vtype>(zmm[2 ], zmm[3 ]);
211- bitonic_merge_four_zmm_16bit<vtype>(zmm);
212- vtype::storeu (arr, zmm[0 ]);
213- vtype::storeu (arr + 32 , zmm[1 ]);
214- vtype::mask_storeu (arr + 64 , load_mask1, zmm[2 ]);
215- vtype::mask_storeu (arr + 96 , load_mask2, zmm[3 ]);
216- }
217-
218- template <typename vtype, typename type_t >
219- X86_SIMD_SORT_INLINE type_t get_pivot_16bit (type_t *arr,
220- const int64_t left,
221- const int64_t right)
222- {
223- // median of 32
224- int64_t size = (right - left) / 32 ;
225- type_t vec_arr[32 ] = {arr[left],
226- arr[left + size],
227- arr[left + 2 * size],
228- arr[left + 3 * size],
229- arr[left + 4 * size],
230- arr[left + 5 * size],
231- arr[left + 6 * size],
232- arr[left + 7 * size],
233- arr[left + 8 * size],
234- arr[left + 9 * size],
235- arr[left + 10 * size],
236- arr[left + 11 * size],
237- arr[left + 12 * size],
238- arr[left + 13 * size],
239- arr[left + 14 * size],
240- arr[left + 15 * size],
241- arr[left + 16 * size],
242- arr[left + 17 * size],
243- arr[left + 18 * size],
244- arr[left + 19 * size],
245- arr[left + 20 * size],
246- arr[left + 21 * size],
247- arr[left + 22 * size],
248- arr[left + 23 * size],
249- arr[left + 24 * size],
250- arr[left + 25 * size],
251- arr[left + 26 * size],
252- arr[left + 27 * size],
253- arr[left + 28 * size],
254- arr[left + 29 * size],
255- arr[left + 30 * size],
256- arr[left + 31 * size]};
257- typename vtype::zmm_t rand_vec = vtype::loadu (vec_arr);
258- typename vtype::zmm_t sort = sort_zmm_16bit<vtype>(rand_vec);
259- return ((type_t *)&sort)[16 ];
260- }
261-
262- template <typename vtype, typename type_t >
263- static void
264- qsort_16bit_ (type_t *arr, int64_t left, int64_t right, int64_t max_iters)
265- {
266- /*
267- * Resort to std::sort if quicksort isnt making any progress
268- */
269- if (max_iters <= 0 ) {
270- std::sort (arr + left, arr + right + 1 , comparison_func<vtype>);
271- return ;
272- }
273- /*
274- * Base case: use bitonic networks to sort arrays <= 128
275- */
276- if (right + 1 - left <= 128 ) {
277- sort_128_16bit<vtype>(arr + left, (int32_t )(right + 1 - left));
278- return ;
279- }
280-
281- type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
282- type_t smallest = vtype::type_max ();
283- type_t biggest = vtype::type_min ();
284- int64_t pivot_index = partition_avx512<vtype>(
285- arr, left, right + 1 , pivot, &smallest, &biggest);
286- if (pivot != smallest)
287- qsort_16bit_<vtype>(arr, left, pivot_index - 1 , max_iters - 1 );
288- if (pivot != biggest)
289- qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1 );
290- }
291-
292- template <typename vtype, typename type_t >
293- static void qselect_16bit_ (type_t *arr,
294- int64_t pos,
295- int64_t left,
296- int64_t right,
297- int64_t max_iters)
298- {
299- /*
300- * Resort to std::sort if quicksort isnt making any progress
301- */
302- if (max_iters <= 0 ) {
303- std::sort (arr + left, arr + right + 1 , comparison_func<vtype>);
304- return ;
305- }
306- /*
307- * Base case: use bitonic networks to sort arrays <= 128
308- */
309- if (right + 1 - left <= 128 ) {
310- sort_128_16bit<vtype>(arr + left, (int32_t )(right + 1 - left));
311- return ;
312- }
313-
314- type_t pivot = get_pivot_16bit<vtype>(arr, left, right);
315- type_t smallest = vtype::type_max ();
316- type_t biggest = vtype::type_min ();
317- int64_t pivot_index = partition_avx512<vtype>(
318- arr, left, right + 1 , pivot, &smallest, &biggest);
319- if ((pivot != smallest) && (pos < pivot_index))
320- qselect_16bit_<vtype>(arr, pos, left, pivot_index - 1 , max_iters - 1 );
321- else if ((pivot != biggest) && (pos >= pivot_index))
322- qselect_16bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1 );
323- }
324-
325122#endif // AVX512_16BIT_COMMON
0 commit comments