Skip to content

Commit 01a8865

Browse files
committed
resolve cr comments
1 parent a65d6bf commit 01a8865

File tree

1 file changed

+42
-46
lines changed

1 file changed

+42
-46
lines changed

include/xgboost/common/quantile.h

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*!
22
* Copyright 2014 by Contributors
3+
* Modifications Copyright 2020 by Secure XGBoost Contributors
34
* \file quantile.h
45
* \brief util to compute quantiles
56
* \author Tianqi Chen
@@ -9,13 +10,11 @@
910

1011
#include <dmlc/base.h>
1112
#include <xgboost/logging.h>
12-
1313
#include <algorithm>
1414
#include <cmath>
1515
#include <cstring>
1616
#include <iostream>
1717
#include <vector>
18-
1918
#include "obl_primitives.h"
2019

2120
namespace xgboost {
@@ -231,9 +230,9 @@ template <typename DType, typename RType>
231230
void CheckEqualSummary(const WQSummary<DType, RType> &lhs,
232231
const WQSummary<DType, RType> &rhs) {
233232
auto trace = [&]() {
234-
LOG(CONSOLE) << "---------- lhs: ";
233+
LOG(INFO) << "---------- lhs: ";
235234
lhs.Print();
236-
LOG(CONSOLE) << "---------- rhs: ";
235+
LOG(INFO) << "---------- rhs: ";
237236
rhs.Print();
238237
};
239238
// DEBUG CHECK
@@ -319,7 +318,6 @@ struct WQSummary {
319318
helper_entry.entry.weight = 0;
320319
}
321320

322-
size_t unique_count = 0;
323321
for (size_t idx = 0; idx < qhelper.size(); ++idx) {
324322
// sum weight for same value
325323
qhelper[idx].entry.weight += queue[idx].weight;
@@ -329,7 +327,6 @@ struct WQSummary {
329327
: !ObliviousEqual(qhelper[idx + 1].entry.value,
330328
qhelper[idx].entry.value);
331329
qhelper[idx].is_new = is_new;
332-
unique_count += is_new;
333330
if (idx != qhelper.size() - 1) {
334331
// Accumulate when next is same with me, otherwise reset to zero.
335332
qhelper[idx + 1].entry.weight =
@@ -441,7 +438,7 @@ struct WQSummary {
441438
}
442439
/*!
443440
* \brief debug function, validate whether the summary
444-
* run consistency check to check if it is a valid summary
441+
* run consistency check to check if it is a valid summary
445442
* \param eps the tolerate error level, used when RType is floating point and
446443
* some inconsistency could occur due to rounding error
447444
*/
@@ -461,8 +458,9 @@ struct WQSummary {
461458
* \brief set current summary to be obliviously pruned summary of src
462459
* assume data field is already allocated to be at least maxsize
463460
* dummy item will rank last of return and will involved in following
464-
* computation \param src source summary \param maxsize size we can afford in
465-
* the pruned sketch
461+
* computation
462+
* \param src source summary \param maxsize size we can afford in
463+
* the pruned sketch
466464
*/
467465
inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) {
468466
if (src.size <= maxsize) {
@@ -501,15 +499,15 @@ struct WQSummary {
501499
//
502500
for (size_t i = 1; i < src.size; ++i) {
503501
Item obliviousItem = ObliviousChoose(
504-
i < max_index - 1,
502+
ObliviousLess(i , max_index - 1),
505503
Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true},
506504
Item{kDummyEntryWithMaxValue, std::numeric_limits<RType>::max(),
507505
true});
508506
items.push_back(obliviousItem);
509507
}
510508
for (size_t i = 1; i < src.size - 1; ++i) {
511509
Item obliviousItem = ObliviousChoose(
512-
i < max_index - 1,
510+
ObliviousLess(i , max_index - 1),
513511
Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(),
514512
true},
515513
Item{kDummyEntryWithMaxValue, std::numeric_limits<RType>::max(),
@@ -529,11 +527,11 @@ struct WQSummary {
529527
// CASE max_index<=maxsize:All unique item will be select
530528
// CASE other : select unique after dx2 index
531529
bool do_select = ObliviousChoose(
532-
max_index <= maxsize,
533-
items[i].entry.value != last_selected_entry_value &&
534-
items[i].entry.value != std::numeric_limits<RType>::max(),
535-
!items[i - 1].has_entry && items[i].has_entry &&
536-
items[i].entry.value != last_selected_entry_value);
530+
ObliviousLess(max_index , maxsize),
531+
!ObliviousEqual(items[i].entry.value , last_selected_entry_value) &
532+
!ObliviousEqual(items[i].entry.value , std::numeric_limits<RType>::max()),
533+
!items[i - 1].has_entry & items[i].has_entry &
534+
!ObliviousEqual(items[i].entry.value , last_selected_entry_value));
537535
ObliviousAssign(do_select, items[i].entry.value,
538536
last_selected_entry_value, &last_selected_entry_value);
539537
ObliviousAssign(do_select, std::numeric_limits<RType>::min(),
@@ -551,12 +549,12 @@ struct WQSummary {
551549

552550
// Assign actual last entry to lastEntry
553551
for (size_t idx = 0; idx < src.size; ++idx) {
554-
ObliviousAssign(idx == max_index - 1, src.data[idx], lastEntry,
552+
ObliviousAssign(ObliviousEqual(idx , max_index - 1), src.data[idx], lastEntry,
555553
&lastEntry);
556554
}
557555
// Append actual last item to items vector
558556
for (size_t i = 0; i < src.size; i++) {
559-
ObliviousAssign(i == select_count, lastEntry, items[i].entry,
557+
ObliviousAssign(ObliviousEqual(i , select_count), lastEntry, items[i].entry,
560558
&items[i].entry);
561559
}
562560

@@ -658,17 +656,16 @@ struct WQSummary {
658656
std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(),
659657
[](const Entry &entry) {
660658
bool is_dummy = ObliviousChoose(
661-
entry.value == std::numeric_limits<DType>::max(), true,
659+
ObliviousEqual(entry.value , std::numeric_limits<DType>::max()), true,
662660
false);
663661
return EntryWithPartyInfo{entry, true, is_dummy};
664662
});
665-
std::transform(
666-
sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size,
667-
[](const Entry &entry) {
668-
bool is_dummy = ObliviousChoose(
669-
entry.value == std::numeric_limits<DType>::max(), true, false);
670-
return EntryWithPartyInfo{entry, false, is_dummy};
671-
});
663+
std::transform(sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size,
664+
[](const Entry &entry) {
665+
bool is_dummy = ObliviousChoose(
666+
ObliviousEqual(entry.value , std::numeric_limits<DType>::max()), true, false);
667+
return EntryWithPartyInfo{entry, false, is_dummy};
668+
});
672669
// Build bitonic sequence.
673670
std::reverse(merged_party_entrys.begin(),
674671
merged_party_entrys.begin() + sa.size);
@@ -692,11 +689,11 @@ struct WQSummary {
692689

693690
// Save first.
694691
RType next_aprev_rmin = ObliviousChoose(
695-
merged_party_entrys[idx].is_party_a &&
692+
merged_party_entrys[idx].is_party_a &
696693
!merged_party_entrys[idx].is_dummy,
697694
merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin);
698695
RType next_bprev_rmin = ObliviousChoose(
699-
!merged_party_entrys[idx].is_party_a &&
696+
!merged_party_entrys[idx].is_party_a &
700697
!merged_party_entrys[idx].is_dummy,
701698
merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin);
702699

@@ -721,11 +718,11 @@ struct WQSummary {
721718
RType a_prev_rmax = sa.data[sa.size - 1].rmax;
722719
RType b_prev_rmax = sb.data[sb.size - 1].rmax;
723720
for (int idx = 0; idx < sa.size; idx++) {
724-
a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax,
721+
a_prev_rmax = ObliviousChoose(ObliviousGreater(sa.data[idx].rmax , a_prev_rmax),
725722
sa.data[idx].rmax, a_prev_rmax);
726723
}
727724
for (int idx = 0; idx < sb.size; idx++) {
728-
b_prev_rmax = ObliviousChoose(sb.data[idx].rmax > b_prev_rmax,
725+
b_prev_rmax = ObliviousChoose(ObliviousGreater(sb.data[idx].rmax , b_prev_rmax),
729726
sb.data[idx].rmax, b_prev_rmax);
730727
}
731728
size_t duplicate_count = 0;
@@ -753,11 +750,11 @@ struct WQSummary {
753750

754751
// Need to save first since the rmax will be overwritten.
755752
RType next_aprev_rmax = ObliviousChoose(
756-
merged_party_entrys[idx].is_party_a &&
753+
merged_party_entrys[idx].is_party_a &
757754
!merged_party_entrys[idx].is_dummy,
758755
merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax);
759756
RType next_bprev_rmax = ObliviousChoose(
760-
!merged_party_entrys[idx].is_party_a &&
757+
!merged_party_entrys[idx].is_party_a &
761758
!merged_party_entrys[idx].is_dummy,
762759
merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax);
763760
// Add peer RMaxPrev.
@@ -766,10 +763,10 @@ struct WQSummary {
766763
// Handle equals.
767764
// Handle dummys
768765
RType rmin_to_add = ObliviousChoose(
769-
equal_prev && !dummy_item, prevEntry.rmin, static_cast<RType>(0));
766+
equal_prev & !dummy_item, prevEntry.rmin, static_cast<RType>(0));
770767
RType wmin_to_add = ObliviousChoose(
771-
equal_prev && !dummy_item, prevEntry.wmin, static_cast<RType>(0));
772-
rmax_to_add = ObliviousChoose(equal_prev && !dummy_item, prevEntry.rmax,
768+
equal_prev & !dummy_item, prevEntry.wmin, static_cast<RType>(0));
769+
rmax_to_add = ObliviousChoose(equal_prev & !dummy_item, prevEntry.rmax,
773770
rmax_to_add);
774771
// Update.
775772
merged_party_entrys[idx].entry.rmax += rmax_to_add;
@@ -779,10 +776,10 @@ struct WQSummary {
779776
// Copy rmin, rmax, wmin from previous if values are equal.
780777
// Value is ok to be infinite now since this is two party merge, at most
781778
// two items are the same given a specific value.
782-
ObliviousAssign(equal_next && !dummy_item, nextEntry,
779+
ObliviousAssign(equal_next & !dummy_item, nextEntry,
783780
merged_party_entrys[idx].entry,
784781
&merged_party_entrys[idx].entry);
785-
ObliviousAssign(equal_next && !dummy_item,
782+
ObliviousAssign(equal_next & !dummy_item,
786783
std::numeric_limits<DType>::max(),
787784
merged_party_entrys[idx].entry.value,
788785
&merged_party_entrys[idx].entry.value);
@@ -799,7 +796,6 @@ struct WQSummary {
799796
ObliviousSort(this->data, this->data + this->size);
800797
// std::sort(this->data, this->data + this->size);
801798
LOG(DEBUG) << __func__ << " PASSED 3" << std::endl;
802-
// exit(1);
803799
// Need to confirm shrink.
804800
if (ObliviousDebugCheckEnabled()) {
805801
std::vector<Entry> oblivious_results(this->data, this->data + this->size);
@@ -893,9 +889,9 @@ struct WQSummary {
893889
// helper function to print the current content of sketch
894890
inline void Print() const {
895891
for (size_t i = 0; i < this->size; ++i) {
896-
LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin
897-
<< ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin
898-
<< ", v=" << data[i].value;
892+
LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin
893+
<< ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin
894+
<< ", v=" << data[i].value;
899895
}
900896
}
901897

@@ -966,9 +962,9 @@ struct WXQSummary : public WQSummary<DType, RType> {
966962
size_t max_index = 0;
967963
// find actually max item
968964
for (size_t idx = 0; idx < src.size; idx++) {
969-
max_index = ObliviousChoose(
970-
src.data[idx].value != std::numeric_limits<DType>::max(), idx,
971-
max_index);
965+
bool is_valid = !ObliviousEqual(src.data[idx].value,
966+
std::numeric_limits<DType>::max());
967+
max_index = ObliviousChoose(is_valid, idx, max_index);
972968
}
973969
max_index += 1;
974970

@@ -1137,8 +1133,8 @@ struct GKSummary {
11371133
/*! \brief used for debug purpose, print the summary */
11381134
inline void Print() const {
11391135
for (size_t i = 0; i < size; ++i) {
1140-
LOG(CONSOLE) << "x=" << data[i].value << "\t"
1141-
<< "[" << data[i].rmin << "," << data[i].rmax << "]";
1136+
LOG(INFO) << "x=" << data[i].value << "\t"
1137+
<< "[" << data[i].rmin << "," << data[i].rmax << "]";
11421138
}
11431139
}
11441140
/*!

0 commit comments

Comments
 (0)