1
1
/* !
2
2
* Copyright 2014 by Contributors
3
+ * Modifications Copyright 2020 by Secure XGBoost Contributors
3
4
* \file quantile.h
4
5
* \brief util to compute quantiles
5
6
* \author Tianqi Chen
9
10
10
11
#include < dmlc/base.h>
11
12
#include < xgboost/logging.h>
12
-
13
13
#include < algorithm>
14
14
#include < cmath>
15
15
#include < cstring>
16
16
#include < iostream>
17
17
#include < vector>
18
-
19
18
#include " obl_primitives.h"
20
19
21
20
namespace xgboost {
@@ -231,9 +230,9 @@ template <typename DType, typename RType>
231
230
void CheckEqualSummary (const WQSummary<DType, RType> &lhs,
232
231
const WQSummary<DType, RType> &rhs) {
233
232
auto trace = [&]() {
234
- LOG (CONSOLE ) << " ---------- lhs: " ;
233
+ LOG (INFO ) << " ---------- lhs: " ;
235
234
lhs.Print ();
236
- LOG (CONSOLE ) << " ---------- rhs: " ;
235
+ LOG (INFO ) << " ---------- rhs: " ;
237
236
rhs.Print ();
238
237
};
239
238
// DEBUG CHECK
@@ -319,7 +318,6 @@ struct WQSummary {
319
318
helper_entry.entry .weight = 0 ;
320
319
}
321
320
322
- size_t unique_count = 0 ;
323
321
for (size_t idx = 0 ; idx < qhelper.size (); ++idx) {
324
322
// sum weight for same value
325
323
qhelper[idx].entry .weight += queue[idx].weight ;
@@ -329,7 +327,6 @@ struct WQSummary {
329
327
: !ObliviousEqual (qhelper[idx + 1 ].entry .value ,
330
328
qhelper[idx].entry .value );
331
329
qhelper[idx].is_new = is_new;
332
- unique_count += is_new;
333
330
if (idx != qhelper.size () - 1 ) {
334
331
// Accumulate when the next value is the same as this one, otherwise reset to zero.
335
332
qhelper[idx + 1 ].entry .weight =
@@ -441,7 +438,7 @@ struct WQSummary {
441
438
}
442
439
/* !
443
440
* \brief debug function, validate whether the summary
444
- * run consistency check to check if it is a valid summary
441
+ * run consistency check to check if it is a valid summary
445
442
* \param eps the tolerated error level, used when RType is floating point and
446
443
* some inconsistency could occur due to rounding error
447
444
*/
@@ -461,8 +458,9 @@ struct WQSummary {
461
458
* \brief set current summary to be obliviously pruned summary of src
462
459
* assume data field is already allocated to be at least maxsize
463
460
* dummy items will rank last in the result and will be involved in the following
464
- * computation \param src source summary \param maxsize size we can afford in
465
- * the pruned sketch
461
+ * computation
462
+ * \param src source summary \param maxsize size we can afford in
463
+ * the pruned sketch
466
464
*/
467
465
inline void ObliviousSetPrune (const WQSummary &src, size_t maxsize) {
468
466
if (src.size <= maxsize) {
@@ -501,15 +499,15 @@ struct WQSummary {
501
499
//
502
500
for (size_t i = 1 ; i < src.size ; ++i) {
503
501
Item obliviousItem = ObliviousChoose (
504
- i < max_index - 1 ,
502
+ ObliviousLess (i , max_index - 1 ) ,
505
503
Item{src.data [i], src.data [i].rmax + src.data [i].rmin , true },
506
504
Item{kDummyEntryWithMaxValue , std::numeric_limits<RType>::max (),
507
505
true });
508
506
items.push_back (obliviousItem);
509
507
}
510
508
for (size_t i = 1 ; i < src.size - 1 ; ++i) {
511
509
Item obliviousItem = ObliviousChoose (
512
- i < max_index - 1 ,
510
+ ObliviousLess (i , max_index - 1 ) ,
513
511
Item{src.data [i], src.data [i].RMinNext () + src.data [i + 1 ].RMaxPrev (),
514
512
true },
515
513
Item{kDummyEntryWithMaxValue , std::numeric_limits<RType>::max (),
@@ -529,11 +527,11 @@ struct WQSummary {
529
527
// CASE max_index<=maxsize: all unique items will be selected
530
528
// CASE other : select unique after dx2 index
531
529
bool do_select = ObliviousChoose (
532
- max_index <= maxsize,
533
- items[i].entry .value != last_selected_entry_value & &
534
- items[i].entry .value != std::numeric_limits<RType >::max (),
535
- !items[i - 1 ].has_entry && items[i].has_entry & &
536
- items[i].entry .value != last_selected_entry_value);
530
+ ObliviousLess ( max_index , maxsize) ,
531
+ ! ObliviousEqual ( items[i].entry .value , last_selected_entry_value) &
532
+ ! ObliviousEqual ( items[i].entry .value , std::numeric_limits<DType >::max () ),
533
+ !items[i - 1 ].has_entry & items[i].has_entry &
534
+ ! ObliviousEqual ( items[i].entry .value , last_selected_entry_value) );
537
535
ObliviousAssign (do_select, items[i].entry .value ,
538
536
last_selected_entry_value, &last_selected_entry_value);
539
537
ObliviousAssign (do_select, std::numeric_limits<RType>::min (),
@@ -551,12 +549,12 @@ struct WQSummary {
551
549
552
550
// Assign actual last entry to lastEntry
553
551
for (size_t idx = 0 ; idx < src.size ; ++idx) {
554
- ObliviousAssign (idx == max_index - 1 , src.data [idx], lastEntry,
552
+ ObliviousAssign (ObliviousEqual ( idx , max_index - 1 ) , src.data [idx], lastEntry,
555
553
&lastEntry);
556
554
}
557
555
// Append actual last item to items vector
558
556
for (size_t i = 0 ; i < src.size ; i++) {
559
- ObliviousAssign (i == select_count, lastEntry, items[i].entry ,
557
+ ObliviousAssign (ObliviousEqual (i , select_count) , lastEntry, items[i].entry ,
560
558
&items[i].entry );
561
559
}
562
560
@@ -658,17 +656,16 @@ struct WQSummary {
658
656
std::transform (sa.data , sa.data + sa.size , merged_party_entrys.begin (),
659
657
[](const Entry &entry) {
660
658
bool is_dummy = ObliviousChoose (
661
- entry.value == std::numeric_limits<DType>::max (), true ,
659
+ ObliviousEqual ( entry.value , std::numeric_limits<DType>::max () ), true ,
662
660
false );
663
661
return EntryWithPartyInfo{entry, true , is_dummy};
664
662
});
665
- std::transform (
666
- sb.data , sb.data + sb.size , merged_party_entrys.begin () + sa.size ,
667
- [](const Entry &entry) {
668
- bool is_dummy = ObliviousChoose (
669
- entry.value == std::numeric_limits<DType>::max (), true , false );
670
- return EntryWithPartyInfo{entry, false , is_dummy};
671
- });
663
+ std::transform (sb.data , sb.data + sb.size , merged_party_entrys.begin () + sa.size ,
664
+ [](const Entry &entry) {
665
+ bool is_dummy = ObliviousChoose (
666
+ ObliviousEqual (entry.value , std::numeric_limits<DType>::max ()), true , false );
667
+ return EntryWithPartyInfo{entry, false , is_dummy};
668
+ });
672
669
// Build bitonic sequence.
673
670
std::reverse (merged_party_entrys.begin (),
674
671
merged_party_entrys.begin () + sa.size );
@@ -692,11 +689,11 @@ struct WQSummary {
692
689
693
690
// Save first.
694
691
RType next_aprev_rmin = ObliviousChoose (
695
- merged_party_entrys[idx].is_party_a &&
692
+ merged_party_entrys[idx].is_party_a &
696
693
!merged_party_entrys[idx].is_dummy ,
697
694
merged_party_entrys[idx].entry .RMinNext (), a_prev_rmin);
698
695
RType next_bprev_rmin = ObliviousChoose (
699
- !merged_party_entrys[idx].is_party_a &&
696
+ !merged_party_entrys[idx].is_party_a &
700
697
!merged_party_entrys[idx].is_dummy ,
701
698
merged_party_entrys[idx].entry .RMinNext (), b_prev_rmin);
702
699
@@ -721,11 +718,11 @@ struct WQSummary {
721
718
RType a_prev_rmax = sa.data [sa.size - 1 ].rmax ;
722
719
RType b_prev_rmax = sb.data [sb.size - 1 ].rmax ;
723
720
for (int idx = 0 ; idx < sa.size ; idx++) {
724
- a_prev_rmax = ObliviousChoose (sa.data [idx].rmax > a_prev_rmax,
721
+ a_prev_rmax = ObliviousChoose (ObliviousGreater ( sa.data [idx].rmax , a_prev_rmax) ,
725
722
sa.data [idx].rmax , a_prev_rmax);
726
723
}
727
724
for (int idx = 0 ; idx < sb.size ; idx++) {
728
- b_prev_rmax = ObliviousChoose (sb.data [idx].rmax > b_prev_rmax,
725
+ b_prev_rmax = ObliviousChoose (ObliviousGreater ( sb.data [idx].rmax , b_prev_rmax) ,
729
726
sb.data [idx].rmax , b_prev_rmax);
730
727
}
731
728
size_t duplicate_count = 0 ;
@@ -753,11 +750,11 @@ struct WQSummary {
753
750
754
751
// Need to save first since the rmax will be overwritten.
755
752
RType next_aprev_rmax = ObliviousChoose (
756
- merged_party_entrys[idx].is_party_a &&
753
+ merged_party_entrys[idx].is_party_a &
757
754
!merged_party_entrys[idx].is_dummy ,
758
755
merged_party_entrys[idx].entry .RMaxPrev (), a_prev_rmax);
759
756
RType next_bprev_rmax = ObliviousChoose (
760
- !merged_party_entrys[idx].is_party_a &&
757
+ !merged_party_entrys[idx].is_party_a &
761
758
!merged_party_entrys[idx].is_dummy ,
762
759
merged_party_entrys[idx].entry .RMaxPrev (), b_prev_rmax);
763
760
// Add peer RMaxPrev.
@@ -766,10 +763,10 @@ struct WQSummary {
766
763
// Handle equals.
767
764
// Handle dummies
768
765
RType rmin_to_add = ObliviousChoose (
769
- equal_prev && !dummy_item, prevEntry.rmin , static_cast <RType>(0 ));
766
+ equal_prev & !dummy_item, prevEntry.rmin , static_cast <RType>(0 ));
770
767
RType wmin_to_add = ObliviousChoose (
771
- equal_prev && !dummy_item, prevEntry.wmin , static_cast <RType>(0 ));
772
- rmax_to_add = ObliviousChoose (equal_prev && !dummy_item, prevEntry.rmax ,
768
+ equal_prev & !dummy_item, prevEntry.wmin , static_cast <RType>(0 ));
769
+ rmax_to_add = ObliviousChoose (equal_prev & !dummy_item, prevEntry.rmax ,
773
770
rmax_to_add);
774
771
// Update.
775
772
merged_party_entrys[idx].entry .rmax += rmax_to_add;
@@ -779,10 +776,10 @@ struct WQSummary {
779
776
// Copy rmin, rmax, wmin from previous if values are equal.
780
777
// Value is ok to be infinite now since this is two party merge, at most
781
778
// two items are the same given a specific value.
782
- ObliviousAssign (equal_next && !dummy_item, nextEntry,
779
+ ObliviousAssign (equal_next & !dummy_item, nextEntry,
783
780
merged_party_entrys[idx].entry ,
784
781
&merged_party_entrys[idx].entry );
785
- ObliviousAssign (equal_next && !dummy_item,
782
+ ObliviousAssign (equal_next & !dummy_item,
786
783
std::numeric_limits<DType>::max (),
787
784
merged_party_entrys[idx].entry .value ,
788
785
&merged_party_entrys[idx].entry .value );
@@ -799,7 +796,6 @@ struct WQSummary {
799
796
ObliviousSort (this ->data , this ->data + this ->size );
800
797
// std::sort(this->data, this->data + this->size);
801
798
LOG (DEBUG) << __func__ << " PASSED 3" << std::endl;
802
- // exit(1);
803
799
// Need to confirm shrink.
804
800
if (ObliviousDebugCheckEnabled ()) {
805
801
std::vector<Entry> oblivious_results (this ->data , this ->data + this ->size );
@@ -893,9 +889,9 @@ struct WQSummary {
893
889
// helper function to print the current content of sketch
894
890
inline void Print () const {
895
891
for (size_t i = 0 ; i < this ->size ; ++i) {
896
- LOG (CONSOLE ) << " [" << i << " ] rmin=" << data[i].rmin
897
- << " , rmax=" << data[i].rmax << " , wmin=" << data[i].wmin
898
- << " , v=" << data[i].value ;
892
+ LOG (INFO ) << " [" << i << " ] rmin=" << data[i].rmin
893
+ << " , rmax=" << data[i].rmax << " , wmin=" << data[i].wmin
894
+ << " , v=" << data[i].value ;
899
895
}
900
896
}
901
897
@@ -966,9 +962,9 @@ struct WXQSummary : public WQSummary<DType, RType> {
966
962
size_t max_index = 0 ;
967
963
// find the actual max item
968
964
for (size_t idx = 0 ; idx < src.size ; idx++) {
969
- max_index = ObliviousChoose (
970
- src. data [idx]. value != std::numeric_limits<DType>::max (), idx,
971
- max_index);
965
+ bool is_valid = ! ObliviousEqual (src. data [idx]. value ,
966
+ std::numeric_limits<DType>::max ());
967
+ max_index = ObliviousChoose (is_valid, idx, max_index);
972
968
}
973
969
max_index += 1 ;
974
970
@@ -1137,8 +1133,8 @@ struct GKSummary {
1137
1133
/* ! \brief used for debug purpose, print the summary */
1138
1134
inline void Print () const {
1139
1135
for (size_t i = 0 ; i < size; ++i) {
1140
- LOG (CONSOLE ) << " x=" << data[i].value << " \t "
1141
- << " [" << data[i].rmin << " ," << data[i].rmax << " ]" ;
1136
+ LOG (INFO ) << " x=" << data[i].value << " \t "
1137
+ << " [" << data[i].rmin << " ," << data[i].rmax << " ]" ;
1142
1138
}
1143
1139
}
1144
1140
/* !
0 commit comments