11/* !
2- * Copyright by Contributors
2+ * Copyright 2017-2020 by Contributors
33 * \file predictor.h
44 * \brief Interface of predictor,
55 * performs predictions for a gradient booster.
@@ -32,47 +32,83 @@ namespace xgboost {
3232 * \brief Contains pointer to input matrix and associated cached predictions.
3333 */
3434struct PredictionCacheEntry {
35- std::shared_ptr<DMatrix> data;
35+ // A storage for caching prediction values
3636 HostDeviceVector<bst_float> predictions;
37+ // The version of current cache, corresponding number of layers of trees
38+ uint32_t version;
39+ // A weak pointer for checking whether the DMatrix object has expired.
40+ std::weak_ptr< DMatrix > ref;
41+
42+ PredictionCacheEntry () : version { 0 } {}
43+ /* \brief Update the cache entry by number of versions.
44+ *
45+ * \param v Added versions.
46+ */
47+ void Update (uint32_t v) {
48+ version += v;
49+ }
50+ };
51+
52+ /* \brief A container for managed prediction caches.
53+ */
54+ class PredictionContainer {
55+ std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
56+ void ClearExpiredEntries ();
57+
58+ public:
59+ PredictionContainer () = default ;
60+ /* \brief Add a new DMatrix to the cache, at the same time this function will clear out
61+ * all expired caches by checking the `std::weak_ptr`. Caching an existing
62+ * DMatrix won't renew it.
63+ *
64+ * Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
65+ * entry this shared pointer is necessary. More importantly, the life time of this
66+ * cache is tied to the shared pointer.
67+ *
68+ * Another way to make a safe cache is create a proxy to this entry, with anther shared
69+ * pointer defined inside, and pass this proxy around instead of the real entry. But
70+ * seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
71+ * (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
72+ *
73+ * \param m shared pointer to the DMatrix that needs to be cached.
74+ * \param device Which device should the cache be allocated on. Pass
75+ * GenericParameter::kCpuId for CPU or positive integer for GPU id.
76+ *
77+ * \return the cache entry for passed in DMatrix, either an existing cache or newly
78+ * created.
79+ */
80+ PredictionCacheEntry& Cache (std::shared_ptr<DMatrix> m, int32_t device);
81+ /* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
82+ * method. Otherwise a dmlc::Error is thrown.
83+ *
84+ * \param m pointer the DMatrix.
85+ * \return The prediction cache for passed in DMatrix.
86+ */
87+ PredictionCacheEntry& Entry (DMatrix* m);
88+ /* \brief Get a const reference to the underlying hash map. Clear expired caches before
89+ * returning.
90+ */
91+ decltype (container_) const & Container ();
3792};
3893
3994/* *
4095 * \class Predictor
4196 *
42- * \brief Performs prediction on individual training instances or batches of
43- * instances for GBTree. The predictor also manages a prediction cache
44- * associated with input matrices. If possible, it will use previously
45- * calculated predictions instead of calculating new predictions.
46- * Prediction functions all take a GBTreeModel and a DMatrix as input and
47- * output a vector of predictions. The predictor does not modify any state of
48- * the model itself.
97+ * \brief Performs prediction on individual training instances or batches of instances for
98+ * GBTree. Prediction functions all take a GBTreeModel and a DMatrix as input and
99+ * output a vector of predictions. The predictor does not modify any state of the
100+ * model itself.
49101 */
50-
51102class Predictor {
52103 protected:
53104 /*
54105 * \brief Runtime parameters.
55106 */
56107 GenericParameter const * generic_param_;
57- /* *
58- * \brief Map of matrices and associated cached predictions to facilitate
59- * storing and looking up predictions.
60- */
61- std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache_;
62-
63- std::unordered_map<DMatrix*, PredictionCacheEntry>::iterator FindCache (DMatrix const * dmat) {
64- auto cache_emtry = std::find_if (
65- cache_->begin (), cache_->end (),
66- [dmat](std::pair<DMatrix *, PredictionCacheEntry const &> const &kv) {
67- return kv.second .data .get () == dmat;
68- });
69- return cache_emtry;
70- }
71108
72109 public:
73- Predictor (GenericParameter const * generic_param,
74- std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) :
75- generic_param_{generic_param}, cache_{cache} {}
110+ explicit Predictor (GenericParameter const * generic_param) :
111+ generic_param_{generic_param} {}
76112 virtual ~Predictor () = default ;
77113
78114 /* *
@@ -91,12 +127,11 @@ class Predictor {
91127 * \param model The model to predict from.
92128 * \param tree_begin The tree begin index.
93129 * \param ntree_limit (Optional) The ntree limit. 0 means do not
94- * limit trees.
130+ * limit trees.
95131 */
96-
97- virtual void PredictBatch (DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
132+ virtual void PredictBatch (DMatrix* dmat, PredictionCacheEntry* out_preds,
98133 const gbm::GBTreeModel& model, int tree_begin,
99- unsigned ntree_limit = 0 ) = 0;
134+ uint32_t const ntree_limit = 0 ) = 0;
100135
101136 /* *
102137 * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
@@ -116,7 +151,9 @@ class Predictor {
116151 virtual void UpdatePredictionCache (
117152 const gbm::GBTreeModel& model,
118153 std::vector<std::unique_ptr<TreeUpdater>>* updaters,
119- int num_new_trees) = 0;
154+ int num_new_trees,
155+ DMatrix* m,
156+ PredictionCacheEntry* predts) = 0;
120157
121158 /* *
122159 * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
@@ -200,18 +237,15 @@ class Predictor {
200237 * \param cache Pointer to prediction cache.
201238 */
202239 static Predictor* Create (
203- std::string const & name, GenericParameter const * generic_param,
204- std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache);
240+ std::string const & name, GenericParameter const * generic_param);
205241};
206242
207243/* !
208244 * \brief Registry entry for predictor.
209245 */
210246struct PredictorReg
211247 : public dmlc::FunctionRegEntryBase<
212- PredictorReg, std::function<Predictor*(
213- GenericParameter const *,
214- std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>>)>> {};
248+ PredictorReg, std::function<Predictor*(GenericParameter const *)>> {};
215249
216250#define XGBOOST_REGISTER_PREDICTOR (UniqueId, Name ) \
217251 static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
0 commit comments