From 37c6cdf644f345223231be55d7ef242f1a39bce5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Fri, 10 Jul 2020 18:02:23 +0800 Subject: [PATCH 1/5] Use Oblivious method to fix side channel privacy leak --- include/xgboost/common/quantile.h | 269 ++++++++++++++++++++---------- 1 file changed, 180 insertions(+), 89 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index c4a4b501a..7bce9afea 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -1,6 +1,5 @@ /*! * Copyright 2014 by Contributors - * Modifications Copyright 2020 by Secure XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -10,11 +9,13 @@ #include #include + #include #include #include #include #include + #include "obl_primitives.h" namespace xgboost { @@ -24,6 +25,7 @@ bool ObliviousSetCombineEnabled(); bool ObliviousSetPruneEnabled(); bool ObliviousDebugCheckEnabled(); bool ObliviousEnabled(); +void SetObliviousMode(bool); template struct WQSummaryEntry { @@ -154,6 +156,7 @@ struct EntryWithPartyInfo { using Entry = WQSummaryEntry; Entry entry; bool is_party_a; + bool is_dummy; inline bool operator<(const EntryWithPartyInfo &b) const { return entry < b.entry; @@ -228,9 +231,9 @@ template void CheckEqualSummary(const WQSummary &lhs, const WQSummary &rhs) { auto trace = [&]() { - LOG(INFO) << "---------- lhs: "; + LOG(CONSOLE) << "---------- lhs: "; lhs.Print(); - LOG(INFO) << "---------- rhs: "; + LOG(CONSOLE) << "---------- rhs: "; rhs.Print(); }; // DEBUG CHECK @@ -301,7 +304,11 @@ struct WQSummary { i = j; } } - + /* MakeSummaryOblivious protect the unique_count variable. + * in->size == qhelper.size + * out->size == qhelper.size + * out->data == || normal unique data | dummy data || + * */ inline void MakeSummaryOblivious(WQSummary *out) { ObliviousSort(queue.begin(), queue.begin() + qtail); @@ -330,34 +337,33 @@ struct WQSummary { } } - struct IsNewDescendingSorter { - bool operator()(const QEntryHelper &a, const QEntryHelper &b) { - return ObliviousGreater(a.is_new, b.is_new); - } - }; - struct ValueSorter { bool operator()(const QEntryHelper &a, const QEntryHelper &b) { return ObliviousLess(a.entry.value, b.entry.value); } }; - // Remove duplicates. - ObliviousSort(qhelper.begin(), qhelper.end(), IsNewDescendingSorter()); + for (size_t idx = 0; idx < qhelper.size(); ++idx) { + qhelper[idx].entry.value = + ObliviousChoose(qhelper[idx].is_new, qhelper[idx].entry.value, + std::numeric_limits::max()); + } // Resort by value. - ObliviousSort(qhelper.begin(), qhelper.begin() + unique_count, - ValueSorter()); + ObliviousSort(qhelper.begin(), qhelper.end(), ValueSorter()); out->size = 0; RType wsum = 0; - for (size_t idx = 0; idx < unique_count; ++idx) { + // is_new represent first sight + for (size_t idx = 0; idx < qhelper.size(); ++idx) { const RType w = qhelper[idx].entry.weight; - out->data[out->size++] = - Entry(wsum, wsum + w, w, qhelper[idx].entry.value); + bool is_new = qhelper[idx].is_new; + ObliviousAssign(is_new, + Entry(wsum, wsum + w, w, qhelper[idx].entry.value), + Entry(-1, -1, 0, std::numeric_limits::max()), + &out->data[out->size++]); wsum += w; } - if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(out->data, out->data + out->size); this->MakeSummaryRaw(out); @@ -417,6 +423,10 @@ struct WQSummary { size = src.size; std::memcpy(data, src.data, sizeof(Entry) * size); } + inline void CopyFromSize(const WQSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } inline void MakeFromSorted(const Entry *entries, size_t n) { size = 0; for (size_t i = 0; i < n;) { @@ -448,8 +458,9 @@ struct WQSummary { } /*! - * \brief set current summary to be pruned summary of src + * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize + * dummy item will rank last of return and will involved in following computation * \param src source summary * \param maxsize size we can afford in the pruned sketch */ @@ -458,14 +469,23 @@ struct WQSummary { this->CopyFrom(src); return; } - - // Make sure dx2 items are last one when `d == (rmax + rmin) / 2`. - const Entry kDummyEntryWithMaxValue{0, 0, 1, + const Entry kDummyEntryWithMaxValue{-1, -1, 0, std::numeric_limits::max()}; + // Make sure dx2 items are last one when `d == (rmax + rmin) / 2`. const RType begin = src.data[0].rmax; - const RType range = src.data[src.size - 1].rmin - src.data[0].rmax; - const size_t n = maxsize - 1; + const RType n = maxsize - 1; + // max_index is equal to previous src.size + size_t max_index = 0; + RType range = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + max_index = ObliviousChoose( + src.data[idx].value != std::numeric_limits::max(), idx, + max_index); + range = src.data[max_index].rmin - src.data[0].rmax; + } + max_index += 1; // Construct sort vector. using Item = PruneItem; @@ -475,16 +495,27 @@ struct WQSummary { RType dx2 = 2 * ((k * range) / n + begin); items.push_back(Item{kDummyEntryWithMaxValue, dx2, false}); } - std::transform(src.data + 1, src.data + src.size, std::back_inserter(items), - [](const Entry &entry) { - return Item{entry, entry.rmax + entry.rmin, true}; - }); + // ObliviousPrune contains Dummy item,So here we doing this on 2 cases + // CASE i < max_index: handle normal data + // CASE other: handle dummy data + // + for (size_t i = 1; i < src.size; ++i) { + Item obliviousItem = ObliviousChoose( + i < max_index - 1, + Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); + } for (size_t i = 1; i < src.size - 1; ++i) { - items.push_back(Item{src.data[i], - src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), - true}); + Item obliviousItem = ObliviousChoose( + i < max_index - 1, + Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), + true}, + Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), + true}); + items.push_back(obliviousItem); } - // Bitonic Sort. LOG(DEBUG) << __func__ << " BEGIN 1" << std::endl; ObliviousSort(items.begin(), items.end()); @@ -493,34 +524,42 @@ struct WQSummary { // Choose entrys. RType last_selected_entry_value = std::numeric_limits::min(); size_t select_count = 0; + Entry lastEntry = items[0].entry; for (size_t i = 1; i < items.size(); ++i) { - bool do_select = !items[i - 1].has_entry && items[i].has_entry && - items[i].entry.value != last_selected_entry_value; + // CASE max_index<=maxsize:All unique item will be select + // CASE other : select unique after dx2 index + bool do_select = ObliviousChoose( + max_index <= maxsize, + items[i].entry.value != last_selected_entry_value && + items[i].entry.value != std::numeric_limits::max(), + !items[i - 1].has_entry && items[i].has_entry && + items[i].entry.value != last_selected_entry_value); ObliviousAssign(do_select, items[i].entry.value, last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), items[i].rank, &items[i].rank); + ObliviousAssign(i == max_index - 1, src.data[i], lastEntry, &lastEntry); + ObliviousAssign(do_select, items[i].entry, kDummyEntryWithMaxValue, + &items[i].entry); select_count += ObliviousChoose(do_select, 1, 0); } + // Bitonic Sort. LOG(DEBUG) << __func__ << " BEGIN 2" << std::endl; ObliviousSort(items.begin(), items.end()); LOG(DEBUG) << __func__ << " PASSED 2" << std::endl; + // Append actual last item to items vector + for (size_t i = 0; i < src.size; i++) { + ObliviousAssign(i == select_count, lastEntry, items[i].entry, + &items[i].entry); + } + this->data[0] = src.data[0]; - this->size = 1 + select_count; - std::transform(items.begin(), items.begin() + select_count, this->data + 1, - [](const Item &item) { - CHECK(item.has_entry && - item.rank == std::numeric_limits::min()); - return item.entry; - }); + this->size = maxsize; - // First and last ones are always kept in prune. - if (data[size - 1].value != src.data[src.size - 1].value) { - CHECK(size < maxsize); - data[size++] = src.data[src.size - 1]; - } + std::transform(items.begin(), items.begin() + maxsize - 1, this->data + 1, + [](const Item &item) { return item.entry; }); if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(data, data + size); @@ -605,28 +644,33 @@ struct WQSummary { this->CopyFrom(sa); return; } - using EntryWithPartyInfo = EntryWithPartyInfo; std::vector merged_party_entrys(this->size); // Fill party info and build bitonic sequence. + // std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(), [](const Entry &entry) { - return EntryWithPartyInfo{entry, true}; - }); - std::transform(sb.data, sb.data + sb.size, - merged_party_entrys.begin() + sa.size, - [](const Entry &entry) { - return EntryWithPartyInfo{entry, false}; + bool is_dummy = ObliviousChoose( + entry.value == std::numeric_limits::max(), true, + false); + return EntryWithPartyInfo{entry, true, is_dummy}; }); + std::transform( + sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + entry.value == std::numeric_limits::max(), true, false); + return EntryWithPartyInfo{entry, false, is_dummy}; + }); // Build bitonic sequence. std::reverse(merged_party_entrys.begin(), merged_party_entrys.begin() + sa.size); // Bitonic merge. // ObliviousSort(merged_party_entrys.begin(), merged_party_entrys.end()); ObliviousMerge(merged_party_entrys.begin(), merged_party_entrys.end()); - // Forward pass to compute rmin. + // Forward pass don`t need Oblivious RType a_prev_rmin = 0; RType b_prev_rmin = 0; for (size_t idx = 0; idx < merged_party_entrys.size(); ++idx) { @@ -642,10 +686,12 @@ struct WQSummary { // Save first. RType next_aprev_rmin = ObliviousChoose( - merged_party_entrys[idx].is_party_a, + merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin); RType next_bprev_rmin = ObliviousChoose( - !merged_party_entrys[idx].is_party_a, + !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin); // This is a. Need to add previous b->RMinNext(). @@ -663,9 +709,22 @@ struct WQSummary { } // Backward pass to compute rmax. + // Backward Algo: + // 1、 find really data[last].rmax for sa and sb assign to prev_rmax + // 2、 use is_dummy to contral backward computation dataflow RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; + + for (int idx = 0; idx < sa.size; idx++) { + a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, + sa.data[idx].rmax, a_prev_rmax); + } + for (int idx = 0; idx < sb.size; idx++) { + b_prev_rmax = ObliviousChoose(sb.data[idx].rmax > b_prev_rmax, + sb.data[idx].rmax, b_prev_rmax); + } size_t duplicate_count = 0; + size_t dummy_count = 0; for (ssize_t idx = merged_party_entrys.size() - 1; idx >= 0; --idx) { bool equal_prev = idx == 0 ? false @@ -676,28 +735,33 @@ struct WQSummary { ? false : ObliviousEqual(merged_party_entrys[idx].entry.value, merged_party_entrys[idx + 1].entry.value); + bool dummy_item = merged_party_entrys[idx].is_dummy; duplicate_count += ObliviousChoose(equal_next, 1, 0); + dummy_count += ObliviousChoose(merged_party_entrys[idx].is_dummy, 1, 0); // Need to save first since the rmax will be overwritten. RType next_aprev_rmax = ObliviousChoose( - merged_party_entrys[idx].is_party_a, + merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax); RType next_bprev_rmax = ObliviousChoose( - !merged_party_entrys[idx].is_party_a, + !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax); - // Add peer RMaxPrev. RType rmax_to_add = ObliviousChoose(merged_party_entrys[idx].is_party_a, b_prev_rmax, a_prev_rmax); // Handle equals. - RType rmin_to_add = - ObliviousChoose(equal_prev, merged_party_entrys[idx - 1].entry.rmin, - static_cast(0)); - RType wmin_to_add = - ObliviousChoose(equal_prev, merged_party_entrys[idx - 1].entry.wmin, - static_cast(0)); - rmax_to_add = ObliviousChoose( - equal_prev, merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); + // Handle dummys + RType rmin_to_add = ObliviousChoose( + equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.rmin, + static_cast(0)); + RType wmin_to_add = ObliviousChoose( + equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.wmin, + static_cast(0)); + rmax_to_add = + ObliviousChoose(equal_prev && !dummy_item, + merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; merged_party_entrys[idx].entry.rmin += rmin_to_add; @@ -706,17 +770,17 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign(equal_next, merged_party_entrys[idx + 1].entry, - merged_party_entrys[idx].entry, - &merged_party_entrys[idx].entry); - ObliviousAssign(equal_next, std::numeric_limits::max(), + ObliviousAssign( + equal_next && !dummy_item, merged_party_entrys[idx + 1].entry, + merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); + ObliviousAssign(equal_next && !dummy_item, + std::numeric_limits::max(), merged_party_entrys[idx].entry.value, &merged_party_entrys[idx].entry.value); a_prev_rmax = next_aprev_rmax; b_prev_rmax = next_bprev_rmax; } - // Bitonic sort to push duplicates to end of list. std::transform(merged_party_entrys.begin(), merged_party_entrys.end(), this->data, [](const EntryWithPartyInfo &party_entry) { @@ -726,10 +790,8 @@ struct WQSummary { ObliviousSort(this->data, this->data + this->size); // std::sort(this->data, this->data + this->size); LOG(DEBUG) << __func__ << " PASSED 3" << std::endl; - + // exit(1); // Need to confirm shrink. - this->size -= duplicate_count; - if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(this->data, this->data + this->size); RawSetCombine(sa, sb); @@ -822,7 +884,7 @@ struct WQSummary { // helper function to print the current content of sketch inline void Print() const { for (size_t i = 0; i < this->size; ++i) { - LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin + LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin << ", v=" << data[i].value; } @@ -891,20 +953,31 @@ struct WXQSummary : public WQSummary { if (ObliviousSetPruneEnabled()) { return WQSummary::ObliviousSetPrune(src, maxsize); } - if (src.size <= maxsize) { - this->CopyFrom(src); + + size_t max_index = 0; + // find actually max item + for (size_t idx = 0; idx < src.size; idx++) { + max_index = ObliviousChoose( + src.data[idx].value != std::numeric_limits::max(), idx, + max_index); + } + max_index += 1; + + if (max_index <= maxsize) { + this->CopyFromSize(src, max_index); return; } RType begin = src.data[0].rmax; // n is number of points exclude the min/max points size_t n = maxsize - 2, nbig = 0; // these is the range of data exclude the min/max point - RType range = src.data[src.size - 1].rmin - begin; + RType range = src.data[max_index - 1].rmin - begin; + // RType range = src.data[src.size - 1].rmin - begin; // prune off zero weights if (range == 0.0f || maxsize <= 2) { // special case, contain only two effective data pts this->data[0] = src.data[0]; - this->data[1] = src.data[src.size - 1]; + this->data[1] = src.data[max_index - 1]; this->size = 2; return; } else { @@ -919,7 +992,7 @@ struct WXQSummary : public WQSummary { // first scan, grab all the big chunk // moving block index, exclude the two ends. size_t bid = 0; - for (size_t i = 1; i < src.size - 1; ++i) { + for (size_t i = 1; i < max_index - 1; ++i) { // detect big chunk data point in the middle // always save these data points. if (CheckLarge(src.data[i], chunk)) { @@ -931,8 +1004,8 @@ struct WXQSummary : public WQSummary { ++nbig; } } - if (bid != src.size - 2) { - mrange += src.data[src.size - 1].RMaxPrev() - src.data[bid].RMinNext(); + if (bid != max_index - 2) { + mrange += src.data[max_index - 1].RMaxPrev() - src.data[bid].RMinNext(); } } // assert: there cannot be more than n big data points @@ -951,8 +1024,8 @@ struct WXQSummary : public WQSummary { n = n - nbig; // find the rest of point size_t bid = 0, k = 1, lastidx = 0; - for (size_t end = 1; end < src.size; ++end) { - if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) { + for (size_t end = 1; end < max_index; ++end) { + if (end == max_index - 1 || CheckLarge(src.data[end], chunk)) { if (bid != end - 1) { size_t i = bid; RType maxdx2 = src.data[end].RMaxPrev() * 2; @@ -1045,13 +1118,17 @@ struct GKSummary { size = src.size; std::memcpy(data, src.data, sizeof(Entry) * size); } + inline void CopyFromSize(const GKSummary &src, const size_t insize) { + size = insize; + std::memcpy(data, src.data, sizeof(Entry) * size); + } inline void CheckValid(RType eps) const { // assume always valid } /*! \brief used for debug purpose, print the summary */ inline void Print() const { for (size_t i = 0; i < size; ++i) { - LOG(INFO) << "x=" << data[i].value << "\t" + LOG(CONSOLE) << "x=" << data[i].value << "\t" << "[" << data[i].rmin << "," << data[i].rmax << "]"; } } @@ -1324,12 +1401,26 @@ class QuantileSketchTemplate { level[0].SetPrune(*out, limit_size); } } - out->CopyFrom(level[0]); + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < level[0].size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(level[0], final_size); } else { if (out->size > limit_size) { temp.Reserve(limit_size); temp.SetPrune(*out, limit_size); - out->CopyFrom(temp); + // filter out all the dummy item + size_t final_size = 0; + for (size_t idx = 0; idx < out->size; idx++) { + bool is_valid = !ObliviousEqual(out->data[idx].value, + std::numeric_limits::max()); + final_size += is_valid; + } + out->CopyFromSize(temp, final_size); } } } @@ -1389,4 +1480,4 @@ class GKQuantileSketch : public QuantileSketchTemplate > {}; } // namespace common } // namespace xgboost -#endif // XGBOOST_COMMON_QUANTILE_H_ +#endif // XGBOOST_COMMON_QUANTILE_H_ \ No newline at end of file From a65d6bf23d8050f2fa0c7edd58054fdbe25c8cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 00:28:35 +0800 Subject: [PATCH 2/5] occlum g++ env bug fix --- include/xgboost/common/quantile.h | 41 +++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 7bce9afea..f3970a377 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -460,9 +460,9 @@ struct WQSummary { /*! * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize - * dummy item will rank last of return and will involved in following computation - * \param src source summary - * \param maxsize size we can afford in the pruned sketch + * dummy item will rank last of return and will involved in following + * computation \param src source summary \param maxsize size we can afford in + * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { if (src.size <= maxsize) { @@ -538,7 +538,7 @@ struct WQSummary { last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), items[i].rank, &items[i].rank); - ObliviousAssign(i == max_index - 1, src.data[i], lastEntry, &lastEntry); + ObliviousAssign(do_select, items[i].entry, kDummyEntryWithMaxValue, &items[i].entry); select_count += ObliviousChoose(do_select, 1, 0); @@ -549,6 +549,11 @@ struct WQSummary { ObliviousSort(items.begin(), items.end()); LOG(DEBUG) << __func__ << " PASSED 2" << std::endl; + // Assign actual last entry to lastEntry + for (size_t idx = 0; idx < src.size; ++idx) { + ObliviousAssign(idx == max_index - 1, src.data[idx], lastEntry, + &lastEntry); + } // Append actual last item to items vector for (size_t i = 0; i < src.size; i++) { ObliviousAssign(i == select_count, lastEntry, items[i].entry, @@ -645,7 +650,8 @@ struct WQSummary { return; } using EntryWithPartyInfo = EntryWithPartyInfo; - + const Entry kDummyEntryWithMaxValue{-1, -1, 0, + std::numeric_limits::max()}; std::vector merged_party_entrys(this->size); // Fill party info and build bitonic sequence. // @@ -714,7 +720,6 @@ struct WQSummary { // 2、 use is_dummy to contral backward computation dataflow RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; - for (int idx = 0; idx < sa.size; idx++) { a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, sa.data[idx].rmax, a_prev_rmax); @@ -725,6 +730,8 @@ struct WQSummary { } size_t duplicate_count = 0; size_t dummy_count = 0; + Entry prevEntry = merged_party_entrys[merged_party_entrys.size() - 1].entry; + Entry nextEntry = kDummyEntryWithMaxValue; for (ssize_t idx = merged_party_entrys.size() - 1; idx >= 0; --idx) { bool equal_prev = idx == 0 ? false @@ -736,6 +743,11 @@ struct WQSummary { : ObliviousEqual(merged_party_entrys[idx].entry.value, merged_party_entrys[idx + 1].entry.value); bool dummy_item = merged_party_entrys[idx].is_dummy; + prevEntry = idx == 0 ? kDummyEntryWithMaxValue + : merged_party_entrys[idx - 1].entry; + nextEntry = idx == merged_party_entrys.size() - 1 + ? kDummyEntryWithMaxValue + : merged_party_entrys[idx + 1].entry; duplicate_count += ObliviousChoose(equal_next, 1, 0); dummy_count += ObliviousChoose(merged_party_entrys[idx].is_dummy, 1, 0); @@ -754,14 +766,11 @@ struct WQSummary { // Handle equals. // Handle dummys RType rmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.rmin, - static_cast(0)); + equal_prev && !dummy_item, prevEntry.rmin, static_cast(0)); RType wmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, merged_party_entrys[idx - 1].entry.wmin, - static_cast(0)); - rmax_to_add = - ObliviousChoose(equal_prev && !dummy_item, - merged_party_entrys[idx - 1].entry.rmax, rmax_to_add); + equal_prev && !dummy_item, prevEntry.wmin, static_cast(0)); + rmax_to_add = ObliviousChoose(equal_prev && !dummy_item, prevEntry.rmax, + rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; merged_party_entrys[idx].entry.rmin += rmin_to_add; @@ -770,9 +779,9 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign( - equal_next && !dummy_item, merged_party_entrys[idx + 1].entry, - merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); + ObliviousAssign(equal_next && !dummy_item, nextEntry, + merged_party_entrys[idx].entry, + &merged_party_entrys[idx].entry); ObliviousAssign(equal_next && !dummy_item, std::numeric_limits::max(), merged_party_entrys[idx].entry.value, From 538a312eb559918804253297db56b65de97d9eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Tue, 14 Jul 2020 11:45:23 +0800 Subject: [PATCH 3/5] resolve cr comments --- include/xgboost/common/quantile.h | 88 +++++++++++++++---------------- 1 file changed, 42 insertions(+), 46 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index f3970a377..90e9100fb 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -1,5 +1,6 @@ /*! * Copyright 2014 by Contributors + * Modifications Copyright 2020 by Secure XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -9,13 +10,11 @@ #include #include - #include #include #include #include #include - #include "obl_primitives.h" namespace xgboost { @@ -231,9 +230,9 @@ template void CheckEqualSummary(const WQSummary &lhs, const WQSummary &rhs) { auto trace = [&]() { - LOG(CONSOLE) << "---------- lhs: "; + LOG(INFO) << "---------- lhs: "; lhs.Print(); - LOG(CONSOLE) << "---------- rhs: "; + LOG(INFO) << "---------- rhs: "; rhs.Print(); }; // DEBUG CHECK @@ -319,7 +318,6 @@ struct WQSummary { helper_entry.entry.weight = 0; } - size_t unique_count = 0; for (size_t idx = 0; idx < qhelper.size(); ++idx) { // sum weight for same value qhelper[idx].entry.weight += queue[idx].weight; @@ -329,7 +327,6 @@ struct WQSummary { : !ObliviousEqual(qhelper[idx + 1].entry.value, qhelper[idx].entry.value); qhelper[idx].is_new = is_new; - unique_count += is_new; if (idx != qhelper.size() - 1) { // Accumulate when next is same with me, otherwise reset to zero. qhelper[idx + 1].entry.weight = @@ -441,7 +438,7 @@ struct WQSummary { } /*! * \brief debug function, validate whether the summary - * run consistency check to check if it is a valid summary + * run consistency check to check if it is a valid summary * \param eps the tolerate error level, used when RType is floating point and * some inconsistency could occur due to rounding error */ @@ -461,8 +458,9 @@ struct WQSummary { * \brief set current summary to be obliviously pruned summary of src * assume data field is already allocated to be at least maxsize * dummy item will rank last of return and will involved in following - * computation \param src source summary \param maxsize size we can afford in - * the pruned sketch + * computation + * \param src source summary \param maxsize size we can afford in + * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { if (src.size <= maxsize) { @@ -501,7 +499,7 @@ struct WQSummary { // for (size_t i = 1; i < src.size; ++i) { Item obliviousItem = ObliviousChoose( - i < max_index - 1, + ObliviousLess(i , max_index - 1), Item{src.data[i], src.data[i].rmax + src.data[i].rmin, true}, Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), true}); @@ -509,7 +507,7 @@ struct WQSummary { } for (size_t i = 1; i < src.size - 1; ++i) { Item obliviousItem = ObliviousChoose( - i < max_index - 1, + ObliviousLess(i , max_index - 1), Item{src.data[i], src.data[i].RMinNext() + src.data[i + 1].RMaxPrev(), true}, Item{kDummyEntryWithMaxValue, std::numeric_limits::max(), @@ -529,11 +527,11 @@ struct WQSummary { // CASE max_index<=maxsize:All unique item will be select // CASE other : select unique after dx2 index bool do_select = ObliviousChoose( - max_index <= maxsize, - items[i].entry.value != last_selected_entry_value && - items[i].entry.value != std::numeric_limits::max(), - !items[i - 1].has_entry && items[i].has_entry && - items[i].entry.value != last_selected_entry_value); + ObliviousLess(max_index , maxsize), + !ObliviousEqual(items[i].entry.value , last_selected_entry_value) & + !ObliviousEqual(items[i].entry.value , std::numeric_limits::max()), + !items[i - 1].has_entry & items[i].has_entry & + !ObliviousEqual(items[i].entry.value , last_selected_entry_value)); ObliviousAssign(do_select, items[i].entry.value, last_selected_entry_value, &last_selected_entry_value); ObliviousAssign(do_select, std::numeric_limits::min(), @@ -551,12 +549,12 @@ struct WQSummary { // Assign actual last entry to lastEntry for (size_t idx = 0; idx < src.size; ++idx) { - ObliviousAssign(idx == max_index - 1, src.data[idx], lastEntry, + ObliviousAssign(ObliviousEqual(idx , max_index - 1), src.data[idx], lastEntry, &lastEntry); } // Append actual last item to items vector for (size_t i = 0; i < src.size; i++) { - ObliviousAssign(i == select_count, lastEntry, items[i].entry, + ObliviousAssign(ObliviousEqual(i , select_count), lastEntry, items[i].entry, &items[i].entry); } @@ -658,17 +656,16 @@ struct WQSummary { std::transform(sa.data, sa.data + sa.size, merged_party_entrys.begin(), [](const Entry &entry) { bool is_dummy = ObliviousChoose( - entry.value == std::numeric_limits::max(), true, + ObliviousEqual(entry.value , std::numeric_limits::max()), true, false); return EntryWithPartyInfo{entry, true, is_dummy}; }); - std::transform( - sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, - [](const Entry &entry) { - bool is_dummy = ObliviousChoose( - entry.value == std::numeric_limits::max(), true, false); - return EntryWithPartyInfo{entry, false, is_dummy}; - }); + std::transform(sb.data, sb.data + sb.size, merged_party_entrys.begin() + sa.size, + [](const Entry &entry) { + bool is_dummy = ObliviousChoose( + ObliviousEqual(entry.value , std::numeric_limits::max()), true, false); + return EntryWithPartyInfo{entry, false, is_dummy}; + }); // Build bitonic sequence. std::reverse(merged_party_entrys.begin(), merged_party_entrys.begin() + sa.size); @@ -692,11 +689,11 @@ struct WQSummary { // Save first. RType next_aprev_rmin = ObliviousChoose( - merged_party_entrys[idx].is_party_a && + merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), a_prev_rmin); RType next_bprev_rmin = ObliviousChoose( - !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMinNext(), b_prev_rmin); @@ -721,11 +718,11 @@ struct WQSummary { RType a_prev_rmax = sa.data[sa.size - 1].rmax; RType b_prev_rmax = sb.data[sb.size - 1].rmax; for (int idx = 0; idx < sa.size; idx++) { - a_prev_rmax = ObliviousChoose(sa.data[idx].rmax > a_prev_rmax, + a_prev_rmax = ObliviousChoose(ObliviousGreater(sa.data[idx].rmax , a_prev_rmax), sa.data[idx].rmax, a_prev_rmax); } for (int idx = 0; idx < sb.size; idx++) { - b_prev_rmax = ObliviousChoose(sb.data[idx].rmax > b_prev_rmax, + b_prev_rmax = ObliviousChoose(ObliviousGreater(sb.data[idx].rmax , b_prev_rmax), sb.data[idx].rmax, b_prev_rmax); } size_t duplicate_count = 0; @@ -753,11 +750,11 @@ struct WQSummary { // Need to save first since the rmax will be overwritten. RType next_aprev_rmax = ObliviousChoose( - merged_party_entrys[idx].is_party_a && + merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), a_prev_rmax); RType next_bprev_rmax = ObliviousChoose( - !merged_party_entrys[idx].is_party_a && + !merged_party_entrys[idx].is_party_a & !merged_party_entrys[idx].is_dummy, merged_party_entrys[idx].entry.RMaxPrev(), b_prev_rmax); // Add peer RMaxPrev. @@ -766,10 +763,10 @@ struct WQSummary { // Handle equals. // Handle dummys RType rmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, prevEntry.rmin, static_cast(0)); + equal_prev & !dummy_item, prevEntry.rmin, static_cast(0)); RType wmin_to_add = ObliviousChoose( - equal_prev && !dummy_item, prevEntry.wmin, static_cast(0)); - rmax_to_add = ObliviousChoose(equal_prev && !dummy_item, prevEntry.rmax, + equal_prev & !dummy_item, prevEntry.wmin, static_cast(0)); + rmax_to_add = ObliviousChoose(equal_prev & !dummy_item, prevEntry.rmax, rmax_to_add); // Update. merged_party_entrys[idx].entry.rmax += rmax_to_add; @@ -779,10 +776,10 @@ struct WQSummary { // Copy rmin, rmax, wmin from previous if values are equal. // Value is ok to be infinite now since this is two party merge, at most // two items are the same given a specific value. - ObliviousAssign(equal_next && !dummy_item, nextEntry, + ObliviousAssign(equal_next & !dummy_item, nextEntry, merged_party_entrys[idx].entry, &merged_party_entrys[idx].entry); - ObliviousAssign(equal_next && !dummy_item, + ObliviousAssign(equal_next & !dummy_item, std::numeric_limits::max(), merged_party_entrys[idx].entry.value, &merged_party_entrys[idx].entry.value); @@ -799,7 +796,6 @@ struct WQSummary { ObliviousSort(this->data, this->data + this->size); // std::sort(this->data, this->data + this->size); LOG(DEBUG) << __func__ << " PASSED 3" << std::endl; - // exit(1); // Need to confirm shrink. if (ObliviousDebugCheckEnabled()) { std::vector oblivious_results(this->data, this->data + this->size); @@ -893,9 +889,9 @@ struct WQSummary { // helper function to print the current content of sketch inline void Print() const { for (size_t i = 0; i < this->size; ++i) { - LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin - << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin - << ", v=" << data[i].value; + LOG(INFO) << "[" << i << "] rmin=" << data[i].rmin + << ", rmax=" << data[i].rmax << ", wmin=" << data[i].wmin + << ", v=" << data[i].value; } } @@ -966,9 +962,9 @@ struct WXQSummary : public WQSummary { size_t max_index = 0; // find actually max item for (size_t idx = 0; idx < src.size; idx++) { - max_index = ObliviousChoose( - src.data[idx].value != std::numeric_limits::max(), idx, - max_index); + bool is_valid = !ObliviousEqual(src.data[idx].value, + std::numeric_limits::max()); + max_index = ObliviousChoose(is_valid, idx, max_index); } max_index += 1; @@ -1137,8 +1133,8 @@ struct GKSummary { /*! \brief used for debug purpose, print the summary */ inline void Print() const { for (size_t i = 0; i < size; ++i) { - LOG(CONSOLE) << "x=" << data[i].value << "\t" - << "[" << data[i].rmin << "," << data[i].rmax << "]"; + LOG(INFO) << "x=" << data[i].value << "\t" + << "[" << data[i].rmin << "," << data[i].rmax << "]"; } } /*! From e73a3ee8ad0fe129a4ce0b4975ca1cbf4f88aa20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Wed, 15 Jul 2020 16:09:21 +0800 Subject: [PATCH 4/5] resolve cr comments --- include/xgboost/common/quantile.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 90e9100fb..2e40cbfb6 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -459,7 +459,8 @@ struct WQSummary { * assume data field is already allocated to be at least maxsize * dummy item will rank last of return and will involved in following * computation - * \param src source summary \param maxsize size we can afford in + * \param src source summary + * \param maxsize size we can afford in * the pruned sketch */ inline void ObliviousSetPrune(const WQSummary &src, size_t maxsize) { @@ -479,7 +480,7 @@ struct WQSummary { // find actually max item for (size_t idx = 0; idx < src.size; idx++) { max_index = ObliviousChoose( - src.data[idx].value != std::numeric_limits::max(), idx, + !ObliviousEqual(src.data[idx].value , std::numeric_limits::max()), idx, max_index); range = src.data[max_index].rmin - src.data[0].rmax; } @@ -527,7 +528,7 @@ struct WQSummary { // CASE max_index<=maxsize:All unique item will be select // CASE other : select unique after dx2 index bool do_select = ObliviousChoose( - ObliviousLess(max_index , maxsize), + ObliviousLessOrEqual(max_index , maxsize), !ObliviousEqual(items[i].entry.value , last_selected_entry_value) & !ObliviousEqual(items[i].entry.value , std::numeric_limits::max()), !items[i - 1].has_entry & items[i].has_entry & From fff26359cec5410222560a34cc55907126cceeb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=B7=E8=A1=8C?= Date: Thu, 16 Jul 2020 10:35:40 +0800 Subject: [PATCH 5/5] resolve cr comments --- include/xgboost/common/quantile.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/xgboost/common/quantile.h b/include/xgboost/common/quantile.h index 2e40cbfb6..5f7c35dcb 100644 --- a/include/xgboost/common/quantile.h +++ b/include/xgboost/common/quantile.h @@ -24,7 +24,6 @@ bool ObliviousSetCombineEnabled(); bool ObliviousSetPruneEnabled(); bool ObliviousDebugCheckEnabled(); bool ObliviousEnabled(); -void SetObliviousMode(bool); template struct WQSummaryEntry {