8585#define X86_SIMD_SORT_FINLINE static
8686#endif
8787
88+ #define LIKELY (x ) __builtin_expect((x),1 )
89+ #define UNLIKELY (x ) __builtin_expect((x),0 )
90+
8891template <typename type>
8992struct zmm_vector ;
9093
@@ -97,25 +100,54 @@ void avx512_qsort(T *arr, int64_t arrsize);
97100void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize);
98101
99102template <typename T>
100- void avx512_qselect (T *arr, int64_t k, int64_t arrsize);
101- void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize);
103+ void avx512_qselect (T *arr, int64_t k, int64_t arrsize, bool hasnan = false );
104+ void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false );
102105
103106template <typename T>
104- inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize)
107+ inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize, bool hasnan = false )
105108{
106- avx512_qselect<T>(arr, k - 1 , arrsize);
109+ avx512_qselect<T>(arr, k - 1 , arrsize, hasnan );
107110 avx512_qsort<T>(arr, k - 1 );
108111}
109- inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize)
112+ inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false )
110113{
111- avx512_qselect_fp16 (arr, k - 1 , arrsize);
114+ avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan );
112115 avx512_qsort_fp16 (arr, k - 1 );
113116}
114117
115118// key-value sort routines
116119template <typename T>
117120void avx512_qsort_kv (T *keys, uint64_t *indexes, int64_t arrsize);
118121
122+ template <typename T>
123+ bool is_a_nan (T elem)
124+ {
125+ return std::isnan (elem);
126+ }
127+
128+ /*
129+ * Sort all the NAN's to end of the array and return the index of the last elem
130+ * in the array which is not a nan
131+ */
132+ template <typename T>
133+ int64_t move_nans_to_end_of_array (T* arr, int64_t arrsize)
134+ {
135+ int64_t jj = arrsize - 1 ;
136+ int64_t ii = 0 ;
137+ int64_t count = 0 ;
138+ while (ii <= jj) {
139+ if (is_a_nan (arr[ii])) {
140+ std::swap (arr[ii], arr[jj]);
141+ jj -= 1 ;
142+ count++;
143+ }
144+ else {
145+ ii += 1 ;
146+ }
147+ }
148+ return arrsize-count-1 ;
149+ }
150+
119151template <typename vtype, typename T = typename vtype::type_t >
120152bool comparison_func (const T &a, const T &b)
121153{
0 commit comments