Skip to content

Commit 1704ece

Browse files
authored
[ML] Fix progress on resume after final training has completed for classification and regression (#1453)
Backport #1443.
1 parent ffc3189 commit 1704ece

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828

2929
//=== Regressions
3030

31+
== {es} version 7.10.0
32+
33+
=== Bug Fixes
34+
35+
* Fix progress on resume after final training has completed for classification and regression.
36+
Previously, progress for the final training phase was shown stuck at zero. (See {ml-pull}1443[#1443].)
37+
3138
== {es} version 7.9.0
3239

3340
=== New Features

include/maths/CBoostedTreeImpl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ class MATHS_EXPORT CBoostedTreeImpl final {
315315
//! Start monitoring the final model training.
316316
void startProgressMonitoringFinalTrain();
317317

318+
//! Skip monitoring the final model training.
319+
void skipProgressMonitoringFinalTrain();
320+
318321
//! Record the training state using the \p recordTrainState callback function
319322
void recordState(const TTrainingStateCallback& recordTrainState) const;
320323

lib/maths/CBoostedTreeImpl.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
272272

273273
core::CProgramCounters::counter(counter_t::E_DFTPMTrainedForestNumberTrees) =
274274
m_BestForest.size();
275+
} else {
276+
this->skipProgressMonitoringFinalTrain();
275277
}
276278

277279
this->computeClassificationWeights(frame);
@@ -346,9 +348,9 @@ std::size_t CBoostedTreeImpl::estimateMemoryUsage(std::size_t numberRows,
346348
std::size_t dataTypeMemoryUsage{maximumNumberFeatures * sizeof(CDataFrameUtils::SDataType)};
347349
std::size_t featureSampleProbabilities{maximumNumberFeatures * sizeof(double)};
348350
// Assuming either many or few missing rows, we get good compression of the bit
349-
// vector. Specifically, we'll assume the average run length is 256 for which
350-
// we get a constant 4 * 8 / 256.
351-
std::size_t missingFeatureMaskMemoryUsage{32 * numberColumns * numberRows / 256};
351+
// vector. Specifically, we'll assume the average run length is 64 for which
352+
// we get a constant 8 / 64.
353+
std::size_t missingFeatureMaskMemoryUsage{8 * numberColumns * numberRows / 64};
352354
std::size_t trainTestMaskMemoryUsage{
353355
2 * static_cast<std::size_t>(std::ceil(std::log2(static_cast<double>(m_NumberFolds)))) *
354356
numberRows};
@@ -1347,13 +1349,18 @@ void CBoostedTreeImpl::startProgressMonitoringFineTuneHyperparameters() {
13471349
}
13481350

13491351
// Start monitoring the final model training: starts the FINAL_TRAINING task on
// the instrumentation and replaces m_TrainingProgress with a fresh loop progress
// tracker which reports through the instrumentation's progress callback.
void CBoostedTreeImpl::startProgressMonitoringFinalTrain() {
1352+
13501353
// The final model training uses more data so it's monitored separately.
13511354

13521355
m_Instrumentation->startNewProgressMonitoredTask(CBoostedTreeFactory::FINAL_TRAINING);
13531356
// The loop is sized by m_MaximumNumberTrees.
// NOTE(review): 1.0 and 1024 are presumably the progress scale and the number
// of reporting steps - confirm against core::CLoopProgress.
m_TrainingProgress = core::CLoopProgress{
13541357
m_MaximumNumberTrees, m_Instrumentation->progressCallback(), 1.0, 1024};
13551358
}
13561359

1360+
// Skip monitoring the final model training: the FINAL_TRAINING task is still
// started on the instrumentation but, unlike startProgressMonitoringFinalTrain,
// m_TrainingProgress is left untouched.
// NOTE(review): presumably used when final training already completed before a
// resume, so the task isn't reported as stuck at zero - confirm against train().
void CBoostedTreeImpl::skipProgressMonitoringFinalTrain() {
1361+
m_Instrumentation->startNewProgressMonitoredTask(CBoostedTreeFactory::FINAL_TRAINING);
1362+
}
1363+
13571364
namespace {
13581365
const std::string VERSION_7_8_TAG{"7.8"};
13591366
const TStrVec SUPPORTED_VERSIONS{VERSION_7_8_TAG};

0 commit comments

Comments
 (0)