Skip to content

Commit dcf7370

Browse files
authored
[7.6][ML] Assorted runtime optimisations for classification and regression (#873)
Backport #863.
1 parent bc51132 commit dcf7370

16 files changed

+369
-81
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ estimating maximum memory usage. (See {ml-pull}781[#781].)
3939
* Stratified fractional cross validation for regression. (See {ml-pull}784[#784].)
4040
* Added `geo_point` supported output for `lat_long` function records. (See {ml-pull}809[#809]
4141
and {pull}47050[#47050].)
42-
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
4342
* Use a random bag of the data to compute the loss function derivatives for each new
4443
tree which is trained for both regression and classification. (See {ml-pull}811[#811].)
45-
* Emit `prediction_probability` field alongside prediction field in ml results. (See {ml-pull}818[#818].)
44+
* Emit `prediction_probability` field alongside prediction field in ml results.
45+
(See {ml-pull}818[#818].)
46+
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
47+
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
4648

4749
=== Bug Fixes
4850
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)

include/core/CDataFrame.h

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define INCLUDED_ml_core_CDataFrame_h
99

1010
#include <core/CFloatStorage.h>
11+
#include <core/CPackedBitVector.h>
1112
#include <core/CVectorRange.h>
1213
#include <core/Concurrency.h>
1314
#include <core/ImportExport.h>
@@ -26,7 +27,6 @@ namespace ml {
2627
namespace core {
2728
class CDataFrameRowSlice;
2829
class CDataFrameRowSliceHandle;
29-
class CPackedBitVector;
3030
class CTemporaryDirectory;
3131

3232
namespace data_frame_detail {
@@ -35,7 +35,29 @@ using TFloatVec = std::vector<CFloatStorage>;
3535
using TFloatVecItr = TFloatVec::iterator;
3636
using TInt32Vec = std::vector<std::int32_t>;
3737
using TInt32VecCItr = TInt32Vec::const_iterator;
38-
using TPopMaskedRowFunc = std::function<std::size_t()>;
38+
39+
//! \brief A callback used to iterate over only the masked rows.
40+
class CORE_EXPORT CPopMaskedRow {
41+
public:
42+
CPopMaskedRow(std::size_t endSliceRows,
43+
CPackedBitVector::COneBitIndexConstIterator& maskedRow,
44+
const CPackedBitVector::COneBitIndexConstIterator& endMaskedRows)
45+
: m_EndSliceRows{endSliceRows}, m_MaskedRow{&maskedRow}, m_EndMaskedRows{&endMaskedRows} {
46+
}
47+
48+
std::size_t operator()() const {
49+
return ++(*m_MaskedRow) == *m_EndMaskedRows
50+
? m_EndSliceRows
51+
: std::min(**m_MaskedRow, m_EndSliceRows);
52+
}
53+
54+
private:
55+
std::size_t m_EndSliceRows;
56+
CPackedBitVector::COneBitIndexConstIterator* m_MaskedRow;
57+
const CPackedBitVector::COneBitIndexConstIterator* m_EndMaskedRows;
58+
};
59+
60+
using TOptionalPopMaskedRow = boost::optional<CPopMaskedRow>;
3961

4062
//! \brief A lightweight wrapper around a single row of the data frame.
4163
//!
@@ -123,7 +145,7 @@ class CORE_EXPORT CRowIterator final
123145
std::size_t index,
124146
TFloatVecItr rowItr,
125147
TInt32VecCItr docHashItr,
126-
TPopMaskedRowFunc popMaskedRow = nullptr);
148+
const TOptionalPopMaskedRow& popMaskedRow);
127149

128150
//! \name Forward Iterator Contract
129151
//@{
@@ -141,7 +163,7 @@ class CORE_EXPORT CRowIterator final
141163
std::size_t m_Index = 0;
142164
TFloatVecItr m_RowItr;
143165
TInt32VecCItr m_DocHashItr;
144-
TPopMaskedRowFunc m_PopMaskedRow;
166+
TOptionalPopMaskedRow m_PopMaskedRow;
145167
};
146168
}
147169

@@ -469,7 +491,7 @@ class CORE_EXPORT CDataFrame final {
469491
using TStrSizeUMapVec = std::vector<TStrSizeUMap>;
470492
using TSizeSizePr = std::pair<std::size_t, std::size_t>;
471493
using TSizeDataFrameRowSlicePtrVecPr = std::pair<std::size_t, TRowSlicePtrVec>;
472-
using TPopMaskedRowFunc = data_frame_detail::TPopMaskedRowFunc;
494+
using TOptionalPopMaskedRow = data_frame_detail::TOptionalPopMaskedRow;
473495

474496
//! \brief Writes rows to the data frame.
475497
class CDataFrameRowSliceWriter final {
@@ -504,19 +526,19 @@ class CORE_EXPORT CDataFrame final {
504526
TRowFuncVecBoolPr parallelApplyToAllRows(std::size_t numberThreads,
505527
std::size_t beginRows,
506528
std::size_t endRows,
507-
TRowFunc func,
529+
TRowFunc&& func,
508530
const CPackedBitVector* rowMask,
509531
bool commitResult) const;
510532
TRowFuncVecBoolPr sequentialApplyToAllRows(std::size_t beginRows,
511533
std::size_t endRows,
512-
TRowFunc func,
534+
TRowFunc& func,
513535
const CPackedBitVector* rowMask,
514536
bool commitResult) const;
515537

516538
void applyToRowsOfOneSlice(TRowFunc& func,
517539
std::size_t firstRowToRead,
518540
std::size_t endRowsToRead,
519-
TPopMaskedRowFunc popMaskedRow,
541+
const TOptionalPopMaskedRow& popMaskedRow,
520542
const CDataFrameRowSliceHandle& slice) const;
521543

522544
TRowSlicePtrVecCItr beginSlices(std::size_t beginRows) const;

include/core/CFloatStorage.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,46 @@ class CORE_EXPORT CFloatStorage {
8383
//! Implicit construction from a double.
8484
CFloatStorage(double value) : m_Value() { this->set(value); }
8585

86+
//! \name Operators
87+
//@{
88+
bool operator==(const CFloatStorage& rhs) const {
89+
return m_Value == rhs.m_Value;
90+
}
91+
bool operator==(const double& rhs) const {
92+
return static_cast<double>(m_Value) == rhs;
93+
}
94+
bool operator!=(const CFloatStorage& rhs) const {
95+
return m_Value != rhs.m_Value;
96+
}
97+
bool operator!=(const double& rhs) const {
98+
return static_cast<double>(m_Value) != rhs;
99+
}
100+
bool operator<(const CFloatStorage& rhs) const {
101+
return m_Value < rhs.m_Value;
102+
}
103+
bool operator<(const double& rhs) const {
104+
return static_cast<double>(m_Value) < rhs;
105+
}
106+
bool operator<=(const CFloatStorage& rhs) const {
107+
return m_Value <= rhs.m_Value;
108+
}
109+
bool operator<=(const double& rhs) const {
110+
return static_cast<double>(m_Value) <= rhs;
111+
}
112+
bool operator>(const CFloatStorage& rhs) const {
113+
return m_Value > rhs.m_Value;
114+
}
115+
bool operator>(const double& rhs) const {
116+
return static_cast<double>(m_Value) > rhs;
117+
}
118+
bool operator>=(const CFloatStorage& rhs) const {
119+
return m_Value >= rhs.m_Value;
120+
}
121+
bool operator>=(const double& rhs) const {
122+
return static_cast<double>(m_Value) >= rhs;
123+
}
124+
//@}
125+
86126
//! Set from a string.
87127
bool fromString(const std::string& string) {
88128
double value;

include/core/CImmutableRadixSet.h

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
#ifndef INCLUDED_ml_core_CImmutableRadixSet_h
8+
#define INCLUDED_ml_core_CImmutableRadixSet_h
9+
10+
#include <core/CContainerPrinter.h>
11+
12+
#include <algorithm>
13+
#include <limits>
14+
#include <numeric>
15+
#include <vector>
16+
17+
namespace ml {
18+
namespace core {
19+
20+
//! \brief An immutable sorted set which provides very fast lookup.
21+
//!
22+
//! DESCRIPTION:\n
23+
//! This supports upper bound and look up by index as well as a subset of the non
24+
//! modifying interface of std::set. Its main purpose is to provide much faster
25+
//! lookup. To this end it subdivides the range of sorted values into buckets.
26+
//! In the case that the values are uniformly distributed upperBound will be O(1)
27+
//! with low constant. Otherwise, it is worst case O(log(n)).
28+
template<typename T>
29+
class CImmutableRadixSet {
30+
public:
31+
using TVec = std::vector<T>;
32+
using TCItr = typename std::vector<T>::const_iterator;
33+
34+
public:
35+
// We only need to support floating point types at present (although it
36+
// could easily be extended to support any numeric type).
37+
static_assert(std::is_floating_point<T>::value, "Only supports floating point types");
38+
39+
public:
40+
CImmutableRadixSet() = default;
41+
explicit CImmutableRadixSet(std::initializer_list<T> values)
42+
: m_Values{std::move(values)} {
43+
this->initialize();
44+
}
45+
explicit CImmutableRadixSet(TVec values) : m_Values{std::move(values)} {
46+
this->initialize();
47+
}
48+
49+
// This is move-only because the buckets hold iterators into the underlying container.
50+
CImmutableRadixSet(const CImmutableRadixSet&) = delete;
51+
CImmutableRadixSet& operator=(const CImmutableRadixSet&) = delete;
52+
CImmutableRadixSet(CImmutableRadixSet&&) = default;
53+
CImmutableRadixSet& operator=(CImmutableRadixSet&&) = default;
54+
55+
//! \name Capacity
56+
//@{
57+
//! Check if the set contains no values.
//!
//! \return True if there are no values and false otherwise.
bool empty() const { return m_Values.empty(); }
58+
std::size_t size() const { return m_Values.size(); }
59+
//@}
60+
61+
//! \name Iterators
62+
//@{
63+
//! Get a constant iterator to the first (smallest) value in the set.
TCItr begin() const { return m_Values.begin(); }
64+
//! Get a constant iterator to one past the last (largest) value in the set.
TCItr end() const { return m_Values.end(); }
65+
//@}
66+
67+
//! \name Lookup
68+
//@{
69+
const T& operator[](std::size_t i) const { return m_Values[i]; }
70+
std::ptrdiff_t upperBound(const T& value) const {
71+
// This branch is predictable so essentially free.
72+
if (m_Values.size() < 2) {
73+
return std::distance(m_Values.begin(),
74+
std::upper_bound(m_Values.begin(), m_Values.end(), value));
75+
}
76+
77+
std::ptrdiff_t bucket{static_cast<std::ptrdiff_t>(m_Scale * (value - m_Min))};
78+
if (bucket < 0) {
79+
return 0;
80+
}
81+
if (bucket >= static_cast<std::ptrdiff_t>(m_Buckets.size())) {
82+
return static_cast<std::ptrdiff_t>(m_Values.size());
83+
}
84+
TCItr beginBucket;
85+
TCItr endBucket;
86+
std::tie(beginBucket, endBucket) = m_Buckets[bucket];
87+
return std::distance(m_Values.begin(),
88+
std::upper_bound(beginBucket, endBucket, value));
89+
}
90+
//@}
91+
92+
std::string print() const {
93+
return core::CContainerPrinter::print(m_Values);
94+
}
95+
96+
private:
97+
using TCItrCItrPr = std::pair<TCItr, TCItr>;
98+
using TCItrCItrPrVec = std::vector<TCItrCItrPr>;
99+
using TPtrdiffVec = std::vector<std::ptrdiff_t>;
100+
101+
private:
102+
void initialize() {
103+
std::sort(m_Values.begin(), m_Values.end());
104+
m_Values.erase(std::unique(m_Values.begin(), m_Values.end()), m_Values.end());
105+
if (m_Values.size() > 1) {
106+
std::size_t numberBuckets{m_Values.size()};
107+
m_Min = m_Values[0];
108+
m_Scale = static_cast<T>(numberBuckets) / (m_Values.back() - m_Min);
109+
m_Buckets.reserve(numberBuckets);
110+
T bucket{1};
111+
T bucketClose{m_Min + bucket / m_Scale};
112+
auto start = m_Values.begin();
113+
for (auto i = m_Values.begin(); i != m_Values.end(); ++i) {
114+
if (*i > bucketClose) {
115+
m_Buckets.emplace_back(start, i);
116+
bucket += T{1};
117+
bucketClose = m_Min + bucket / m_Scale;
118+
start = i;
119+
while (*i > bucketClose) {
120+
m_Buckets.emplace_back(start, i + 1);
121+
bucket += T{1};
122+
bucketClose = m_Min + bucket / m_Scale;
123+
}
124+
}
125+
}
126+
if (m_Buckets.size() < numberBuckets) {
127+
m_Buckets.emplace_back(start, m_Values.end());
128+
}
129+
}
130+
}
131+
132+
private:
133+
T m_Min = T{0};
134+
T m_Scale = T{0};
135+
TCItrCItrPrVec m_Buckets;
136+
TVec m_Values;
137+
};
138+
}
139+
}
140+
141+
#endif // INCLUDED_ml_core_CImmutableRadixSet_h

include/maths/CBoostedTreeHyperparameters.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class CBoostedTreeRegularization final {
108108
m_SoftTreeDepthLimit, inserter);
109109
core::CPersistUtils::persist(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE_TAG,
110110
m_SoftTreeDepthTolerance, inserter);
111-
};
111+
}
112112

113113
//! Populate the object from serialized data.
114114
bool acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
@@ -131,7 +131,7 @@ class CBoostedTreeRegularization final {
131131
m_SoftTreeDepthTolerance, traverser))
132132
} while (traverser.next());
133133
return true;
134-
};
134+
}
135135

136136
public:
137137
static const std::string REGULARIZATION_DEPTH_PENALTY_MULTIPLIER_TAG;

0 commit comments

Comments
 (0)