/// Outcome of pivot selection; tells the caller how the range should be
/// treated after a pivot has been chosen.
enum class pivot_result_t : int {
    Normal, ///< Ordinary pivot; no special property of the range is known.
    Sorted, ///< The range was found to be constant, hence already sorted.
    Only2Values ///< The range was found to hold exactly two distinct values.
};

/// A chosen pivot value together with the pivot_result_t describing the range
/// it was chosen from.
template <typename type_t>
struct pivot_results {

    pivot_result_t result;
    type_t pivot;

    // Deliberately not `explicit`: call sites that rely on an implicit
    // conversion from type_t must keep compiling.
    pivot_results(type_t _pivot,
                  pivot_result_t _result = pivot_result_t::Normal)
        : result(_result), pivot(_pivot)
    {
    }
};
1921
/// Returns the next value above `value`, saturating at the top of the range.
///
/// - Floating-point types: the next representable value toward +infinity,
///   as computed by std::nextafter.
/// - All other types: `value + 1`, or `value` unchanged when it is already
///   std::numeric_limits<type_t>::max().
template <typename type_t>
type_t next_value(type_t value)
{
    // TODO this probably handles non-native float16 wrong
    if constexpr (std::is_floating_point<type_t>::value) {
        return std::nextafter(value, std::numeric_limits<type_t>::infinity());
    }
    else {
        if (value < std::numeric_limits<type_t>::max()) {
            // `value + 1` is promoted to int for narrow integer types; the
            // guard above guarantees the result fits back into type_t.
            return static_cast<type_t>(value + 1);
        }
        return value;
    }
}
@@ -96,23 +99,23 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
9699}
97100
98101template  <typename  vtype, typename  type_t >
99- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_near_constant (type_t  *arr,
100-                                              type_t  commonValue,
101-                                              const  arrsize_t  left,
102-                                              const  arrsize_t  right);
102+ X86_SIMD_SORT_INLINE pivot_results<type_t >
103+ get_pivot_near_constant (type_t  *arr,
104+                         type_t  commonValue,
105+                         const  arrsize_t  left,
106+                         const  arrsize_t  right);
103107
104108template  <typename  vtype, typename  type_t >
105- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_smart (type_t  *arr,
106-                                              const  arrsize_t  left,
107-                                              const  arrsize_t  right)
109+ X86_SIMD_SORT_INLINE pivot_results<type_t >
110+ get_pivot_smart (type_t  *arr, const  arrsize_t  left, const  arrsize_t  right)
108111{
109112    using  reg_t  = typename  vtype::reg_t ;
110113    constexpr  int  numVecs = 4 ;
111-      
112-     if  (right - left + 1  <= 4  * numVecs * vtype::numlanes){
113-         return  pivot_results<type_t >(get_pivot<vtype>(arr, left, right));  
114+ 
115+     if  (right - left + 1  <= 4  * numVecs * vtype::numlanes)  {
116+         return  pivot_results<type_t >(get_pivot<vtype>(arr, left, right));
114117    }
115-      
118+ 
116119    constexpr  int  N = numVecs * vtype::numlanes;
117120
118121    arrsize_t  width = (right - vtype::numlanes) - left;
@@ -122,111 +125,123 @@ X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_smart(type_t *arr,
122125    for  (int  i = 0 ; i < numVecs; i++) {
123126        vecs[i] = vtype::loadu (arr + left + delta * i);
124127    }
125-      
128+ 
126129    //  Sort the samples
127130    sort_vectors<vtype, numVecs>(vecs);
128-      
131+ 
129132    type_t  samples[N];
130-     for  (int  i = 0 ; i < numVecs; i++){
133+     for  (int  i = 0 ; i < numVecs; i++)  {
131134        vtype::storeu (samples + vtype::numlanes * i, vecs[i]);
132135    }
133-      
136+ 
134137    type_t  smallest = samples[0 ];
135138    type_t  largest = samples[N - 1 ];
136139    type_t  median = samples[N / 2 ];
137-      
138-     if  (smallest == largest){
140+ 
141+     if  (smallest == largest)  {
139142        //  We have a very unlucky sample, or the array is constant / near constant
140143        //  Run a special function meant to deal with this situation
141144        return  get_pivot_near_constant<vtype, type_t >(arr, median, left, right);
142-     }else  if  (median != smallest && median != largest){
145+     }
146+     else  if  (median != smallest && median != largest) {
143147        //  We have a normal sample; use its median
144148        return  pivot_results<type_t >(median);
145-     }else  if  (median == smallest){
149+     }
150+     else  if  (median == smallest) {
146151        //  If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
147152        //  Try just doing the next largest value greater than this seemingly very common value to separate them out
148153        return  pivot_results<type_t >(next_value<type_t >(median));
149-     }else  if  (median == largest){
154+     }
155+     else  if  (median == largest) {
150156        //  If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
151157        //  Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
152158        return  pivot_results<type_t >(median);
153-     }else {
159+     }
160+     else  {
154161        //  Should be unreachable
155162        return  pivot_results<type_t >(median);
156163    }
157-      
164+ 
158165    //  Should be unreachable
159166    return  pivot_results<type_t >(median);
160167}
161168
162169//  Handles the case where we seem to have a near-constant array, since our sample of the array was constant
163170template  <typename  vtype, typename  type_t >
164- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_near_constant (type_t  *arr,
165-                                              type_t  commonValue,
166-                                              const  arrsize_t  left,
167-                                              const  arrsize_t  right)
171+ X86_SIMD_SORT_INLINE pivot_results<type_t >
172+ get_pivot_near_constant (type_t  *arr,
173+                         type_t  commonValue,
174+                         const  arrsize_t  left,
175+                         const  arrsize_t  right)
168176{
169177    using  reg_t  = typename  vtype::reg_t ;
170-      
178+ 
171179    arrsize_t  index = left;
172-      
180+ 
173181    type_t  value1 = 0 ;
174182    type_t  value2 = 0 ;
175-      
183+ 
176184    //  First, search for any value not equal to the common value
177185    //  First vectorized
178186    reg_t  commonVec = vtype::set1 (commonValue);
179-     for  (; index <= right - vtype::numlanes; index += vtype::numlanes){
187+     for  (; index <= right - vtype::numlanes; index += vtype::numlanes)  {
180188        reg_t  data = vtype::loadu (arr + index);
181-         if  (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec)))){
189+         if  (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec))))  {
182190            break ;
183191        }
184192    }
185-      
193+ 
186194    //  Than scalar at the end
187-     for  (; index <= right; index++){
188-         if  (arr[index] != commonValue){
195+     for  (; index <= right; index++)  {
196+         if  (arr[index] != commonValue)  {
189197            value1 = arr[index];
190198            break ;
191-         }  
199+         }
192200    }
193-      
194-     if  (index == right + 1 ){
201+ 
202+     if  (index == right + 1 )  {
195203        //  The array is completely constant
196204        //  Setting the second flag to true skips partitioning, as the array is constant and thus sorted
197205        return  pivot_results<type_t >(commonValue, pivot_result_t ::Sorted);
198206    }
199-      
207+ 
200208    //  Secondly, search for a second value not equal to either of the previous two
201209    //  First vectorized
202210    reg_t  value1Vec = vtype::set1 (value1);
203-     for  (; index <= right - vtype::numlanes; index += vtype::numlanes){
211+     for  (; index <= right - vtype::numlanes; index += vtype::numlanes)  {
204212        reg_t  data = vtype::loadu (arr + index);
205-         if  (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec))) && !vtype::all_false (vtype::knot_opmask (vtype::eq (data, value1Vec)))){
213+         if  (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec)))
214+             && !vtype::all_false (
215+                     vtype::knot_opmask (vtype::eq (data, value1Vec)))) {
206216            break ;
207217        }
208218    }
209-      
219+ 
210220    //  Then scalar
211-     for  (; index <= right; index++){
212-         if  (arr[index] != commonValue && arr[index] != value1){
221+     for  (; index <= right; index++)  {
222+         if  (arr[index] != commonValue && arr[index] != value1)  {
213223            value2 = arr[index];
214224            break ;
215-         }  
225+         }
216226    }
217-      
218-     if  (index == right + 1 ){
227+ 
228+     if  (index == right + 1 )  {
219229        //  The array contains only 2 values
220230        //  We must pick the larger one, else the right partition is empty
221231        //  We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the larger value
222232        //  TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere
223233        type_t  pivot = std::max (value1, commonValue, comparison_func<vtype>);
224234        return  pivot_results<type_t >(pivot, pivot_result_t ::Only2Values);
225235    }
226-      
236+ 
227237    //  The array has at least 3 distinct values. Use the middle one as the pivot
228-     type_t  median = std::max (std::min (value1,value2, comparison_func<vtype>), std::min (std::max (value1,value2, comparison_func<vtype>),commonValue, comparison_func<vtype>), comparison_func<vtype>);
238+     type_t  median = std::max (
239+             std::min (value1, value2, comparison_func<vtype>),
240+             std::min (std::max (value1, value2, comparison_func<vtype>),
241+                      commonValue,
242+                      comparison_func<vtype>),
243+             comparison_func<vtype>);
229244    return  pivot_results<type_t >(median);
230245}
231246
232- #endif 
247+ #endif 
0 commit comments