Skip to content

Commit cbd5a3c

Browse files
committed
Move prediction cache into Learner.
* Clean-ups
  - Remove duplicated cache in Learner and GBM.
  - Remove ad-hoc fix of invalid cache.
  - Remove `PredictFromCache` in predictors.
  - Remove prediction cache for linear altogether, as it only moves the prediction into the training process but doesn't provide any actual overall speed gain.
  - The cache is now unique to Learner, which means the ownership is no longer shared by any other components.
* Changes
  - Add version to prediction cache.
  - Use weak ptr to check expired DMatrix.
  - Pass shared pointer instead of raw pointer.
1 parent 29eeea7 commit cbd5a3c

25 files changed

+481
-391
lines changed

include/xgboost/gbm.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*!
2-
* Copyright by Contributors
2+
* Copyright 2014-2020 by Contributors
33
* \file gbm.h
44
* \brief Interface of gradient booster,
55
* that learns through gradient statistics.
@@ -18,6 +18,7 @@
1818
#include <utility>
1919
#include <string>
2020
#include <functional>
21+
#include <unordered_map>
2122
#include <memory>
2223

2324
namespace xgboost {
@@ -28,6 +29,8 @@ class ObjFunction;
2829

2930
struct GenericParameter;
3031
struct LearnerModelParam;
32+
struct PredictionCacheEntry;
33+
class PredictionContainer;
3134

3235
/*!
3336
* \brief interface of gradient boosting model.
@@ -38,7 +41,7 @@ class GradientBooster : public Model, public Configurable {
3841

3942
public:
4043
/*! \brief virtual destructor */
41-
virtual ~GradientBooster() = default;
44+
~GradientBooster() override = default;
4245
/*!
4346
* \brief Set the configuration of gradient boosting.
4447
* User must call configure once before InitModel and Training.
@@ -71,19 +74,22 @@ class GradientBooster : public Model, public Configurable {
7174
* \param obj The objective function, optional, can be nullptr when use customized version
7275
* the booster may change content of gpair
7376
*/
74-
virtual void DoBoost(DMatrix* p_fmat,
75-
HostDeviceVector<GradientPair>* in_gpair,
76-
ObjFunction* obj = nullptr) = 0;
77+
virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
78+
PredictionCacheEntry *prediction) = 0;
7779

7880
/*!
7981
* \brief generate predictions for given feature matrix
8082
* \param dmat feature matrix
8183
* \param out_preds output vector to hold the predictions
82-
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
83-
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
84+
* \param training Whether the prediction value is used for training. For dart booster
85+
* drop out is performed during training.
86+
* \param ntree_limit limit the number of trees used in prediction,
87+
* when it equals 0, this means we do not limit
88+
* number of trees, this parameter is only valid
89+
* for gbtree, but not for gblinear
8490
*/
8591
virtual void PredictBatch(DMatrix* dmat,
86-
HostDeviceVector<bst_float>* out_preds,
92+
PredictionCacheEntry* out_preds,
8793
bool training,
8894
unsigned ntree_limit = 0) = 0;
8995
/*!
@@ -158,8 +164,7 @@ class GradientBooster : public Model, public Configurable {
158164
static GradientBooster* Create(
159165
const std::string& name,
160166
GenericParameter const* generic_param,
161-
LearnerModelParam const* learner_model_param,
162-
const std::vector<std::shared_ptr<DMatrix> >& cache_mats);
167+
LearnerModelParam const* learner_model_param);
163168

164169
static void AssertGPUSupport() {
165170
#ifndef XGBOOST_USE_CUDA
@@ -174,8 +179,7 @@ class GradientBooster : public Model, public Configurable {
174179
struct GradientBoosterReg
175180
: public dmlc::FunctionRegEntryBase<
176181
GradientBoosterReg,
177-
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats,
178-
LearnerModelParam const* learner_model_param)> > {
182+
std::function<GradientBooster* (LearnerModelParam const* learner_model_param)> > {
179183
};
180184

181185
/*!

include/xgboost/predictor.h

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*!
2-
* Copyright by Contributors
2+
* Copyright 2017-2020 by Contributors
33
* \file predictor.h
44
* \brief Interface of predictor,
55
* performs predictions for a gradient booster.
@@ -32,47 +32,83 @@ namespace xgboost {
3232
* \brief Contains pointer to input matrix and associated cached predictions.
3333
*/
3434
struct PredictionCacheEntry {
35-
std::shared_ptr<DMatrix> data;
35+
// A storage for caching prediction values
3636
HostDeviceVector<bst_float> predictions;
37+
// The version of current cache, corresponding number of layers of trees
38+
uint32_t version;
39+
// A weak pointer for checking whether the DMatrix object has expired.
40+
std::weak_ptr< DMatrix > ref;
41+
42+
PredictionCacheEntry() : version { 0 } {}
43+
/* \brief Update the cache entry by number of versions.
44+
*
45+
* \param v Added versions.
46+
*/
47+
void Update(uint32_t v) {
48+
version += v;
49+
}
50+
};
51+
52+
/* \brief A container for managed prediction caches.
53+
*/
54+
class PredictionContainer {
55+
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
56+
void ClearExpiredEntries();
57+
58+
public:
59+
PredictionContainer() = default;
60+
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
61+
* all expired caches by checking the `std::weak_ptr`. Caching an existing
62+
* DMatrix won't renew it.
63+
*
64+
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
65+
* entry this shared pointer is necessary. More importantly, the life time of this
66+
* cache is tied to the shared pointer.
67+
*
68+
* Another way to make a safe cache is to create a proxy to this entry, with another shared
69+
* pointer defined inside, and pass this proxy around instead of the real entry. But
70+
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
71+
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
72+
*
73+
* \param m shared pointer to the DMatrix that needs to be cached.
74+
* \param device Which device should the cache be allocated on. Pass
75+
* GenericParameter::kCpuId for CPU or positive integer for GPU id.
76+
*
77+
* \return the cache entry for passed in DMatrix, either an existing cache or newly
78+
* created.
79+
*/
80+
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
81+
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
82+
* method. Otherwise a dmlc::Error is thrown.
83+
*
84+
* \param m pointer the DMatrix.
85+
* \return The prediction cache for passed in DMatrix.
86+
*/
87+
PredictionCacheEntry& Entry(DMatrix* m);
88+
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
89+
* returning.
90+
*/
91+
decltype(container_) const& Container();
3792
};
3893

3994
/**
4095
* \class Predictor
4196
*
42-
* \brief Performs prediction on individual training instances or batches of
43-
* instances for GBTree. The predictor also manages a prediction cache
44-
* associated with input matrices. If possible, it will use previously
45-
* calculated predictions instead of calculating new predictions.
46-
* Prediction functions all take a GBTreeModel and a DMatrix as input and
47-
* output a vector of predictions. The predictor does not modify any state of
48-
* the model itself.
97+
* \brief Performs prediction on individual training instances or batches of instances for
98+
* GBTree. Prediction functions all take a GBTreeModel and a DMatrix as input and
99+
* output a vector of predictions. The predictor does not modify any state of the
100+
* model itself.
49101
*/
50-
51102
class Predictor {
52103
protected:
53104
/*
54105
* \brief Runtime parameters.
55106
*/
56107
GenericParameter const* generic_param_;
57-
/**
58-
* \brief Map of matrices and associated cached predictions to facilitate
59-
* storing and looking up predictions.
60-
*/
61-
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache_;
62-
63-
std::unordered_map<DMatrix*, PredictionCacheEntry>::iterator FindCache(DMatrix const* dmat) {
64-
auto cache_emtry = std::find_if(
65-
cache_->begin(), cache_->end(),
66-
[dmat](std::pair<DMatrix *, PredictionCacheEntry const &> const &kv) {
67-
return kv.second.data.get() == dmat;
68-
});
69-
return cache_emtry;
70-
}
71108

72109
public:
73-
Predictor(GenericParameter const* generic_param,
74-
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) :
75-
generic_param_{generic_param}, cache_{cache} {}
110+
explicit Predictor(GenericParameter const* generic_param) :
111+
generic_param_{generic_param} {}
76112
virtual ~Predictor() = default;
77113

78114
/**
@@ -91,12 +127,11 @@ class Predictor {
91127
* \param model The model to predict from.
92128
* \param tree_begin The tree begin index.
93129
* \param ntree_limit (Optional) The ntree limit. 0 means do not
94-
* limit trees.
130+
* limit trees.
95131
*/
96-
97-
virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
132+
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
98133
const gbm::GBTreeModel& model, int tree_begin,
99-
unsigned ntree_limit = 0) = 0;
134+
uint32_t const ntree_limit = 0) = 0;
100135

101136
/**
102137
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
@@ -116,7 +151,9 @@ class Predictor {
116151
virtual void UpdatePredictionCache(
117152
const gbm::GBTreeModel& model,
118153
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
119-
int num_new_trees) = 0;
154+
int num_new_trees,
155+
DMatrix* m,
156+
PredictionCacheEntry* predts) = 0;
120157

121158
/**
122159
* \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
@@ -200,18 +237,15 @@ class Predictor {
200237
* \param cache Pointer to prediction cache.
201238
*/
202239
static Predictor* Create(
203-
std::string const& name, GenericParameter const* generic_param,
204-
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache);
240+
std::string const& name, GenericParameter const* generic_param);
205241
};
206242

207243
/*!
208244
* \brief Registry entry for predictor.
209245
*/
210246
struct PredictorReg
211247
: public dmlc::FunctionRegEntryBase<
212-
PredictorReg, std::function<Predictor*(
213-
GenericParameter const*,
214-
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>>)>> {};
248+
PredictorReg, std::function<Predictor*(GenericParameter const*)>> {};
215249

216250
#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
217251
static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \

include/xgboost/tree_model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class RegTree : public Model {
158158
}
159159
/*! \brief whether this node is deleted */
160160
XGBOOST_DEVICE bool IsDeleted() const {
161-
return sindex_ == std::numeric_limits<unsigned>::max();
161+
return sindex_ == std::numeric_limits<uint32_t>::max();
162162
}
163163
/*! \brief whether current node is root */
164164
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }

src/c_api/c_api.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
#include "xgboost/data.h"
16+
#include "xgboost/host_device_vector.h"
1617
#include "xgboost/learner.h"
1718
#include "xgboost/c_api.h"
1819
#include "xgboost/logging.h"
@@ -146,7 +147,7 @@ struct XGBAPIThreadLocalEntry {
146147
/*! \brief result holder for returning string pointers */
147148
std::vector<const char *> ret_vec_charp;
148149
/*! \brief returning float vector. */
149-
std::vector<bst_float> ret_vec_float;
150+
HostDeviceVector<bst_float> ret_vec_float;
150151
/*! \brief temp variable of gradient pairs. */
151152
std::vector<GradientPair> tmp_gpair;
152153
};
@@ -553,24 +554,22 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
553554
int32_t training,
554555
xgboost::bst_ulong *len,
555556
const bst_float **out_result) {
556-
std::vector<bst_float>& preds =
557+
HostDeviceVector<bst_float>& preds =
557558
XGBAPIThreadLocalStore::Get()->ret_vec_float;
558559
API_BEGIN();
559560
CHECK_HANDLE();
560561
auto *bst = static_cast<Learner*>(handle);
561-
HostDeviceVector<bst_float> tmp_preds;
562562
bst->Predict(
563563
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
564564
(option_mask & 1) != 0,
565-
&tmp_preds, ntree_limit,
565+
&preds, ntree_limit,
566566
static_cast<bool>(training),
567567
(option_mask & 2) != 0,
568568
(option_mask & 4) != 0,
569569
(option_mask & 8) != 0,
570570
(option_mask & 16) != 0);
571-
preds = tmp_preds.HostVector();
572-
*out_result = dmlc::BeginPtr(preds);
573-
*len = static_cast<xgboost::bst_ulong>(preds.size());
571+
*out_result = dmlc::BeginPtr(preds.HostVector());
572+
*len = static_cast<xgboost::bst_ulong>(preds.Size());
574573
API_END();
575574
}
576575

0 commit comments

Comments
 (0)