Skip to content

Commit 06a0f99

Browse files
authored
[ML] Fix issues upgrading state leading to possible abort of the autodetect process (#136)
Closes #135.
1 parent e16816e commit 06a0f99

File tree

6 files changed

+75
-22
lines changed

6 files changed

+75
-22
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ Fix corner case failing to calculate lgamma values and the correspoinding log er
7272
=== Bug Fixes
7373

7474
Function description for population lat_long results should be lat_long instead of mean ({pull}81[#81])
75-
7675
By-fields should respect model_plot_config.terms ({pull}86[#86])
76+
The trend decomposition state wasn't being correctly upgraded potentially causing the autodetect process to abort ({pull}136[#136])
7777

7878
=== Regressions
7979

include/core/CStateMachine.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <atomic>
1616
#include <cstddef>
1717
#include <list>
18+
#include <map>
1819
#include <vector>
1920

2021
namespace ml {
@@ -67,6 +68,7 @@ class CORE_EXPORT CStateMachine {
6768
using TSizeVec = std::vector<std::size_t>;
6869
using TSizeVecVec = std::vector<TSizeVec>;
6970
using TStrVec = std::vector<std::string>;
71+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
7072

7173
public:
7274
//! Set the number of machines we expect the program to use.
@@ -85,7 +87,8 @@ class CORE_EXPORT CStateMachine {
8587
//! \name Persistence
8688
//@{
8789
//! Initialize by reading state from \p traverser.
88-
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);
90+
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser,
91+
const TSizeSizeMap& mapping = TSizeSizeMap());
8992

9093
//! Persist state by passing information to the supplied inserter.
9194
void acceptPersistInserter(CStatePersistInserter& inserter) const;

lib/core/CStateMachine.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace core {
2525
namespace {
2626

2727
// CStateMachine
28-
const std::string MACHINE_TAG("a");
28+
//const std::string MACHINE_TAG("a"); No longer used
2929
const std::string STATE_TAG("b");
3030

3131
// CStateMachine::SMachine
@@ -88,17 +88,26 @@ CStateMachine CStateMachine::create(const TStrVec& alphabet,
8888
return result;
8989
}
9090

91-
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
91+
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser,
92+
const TSizeSizeMap& mapping) {
9293
do {
9394
const std::string& name = traverser.name();
94-
RESTORE_BUILT_IN(MACHINE_TAG, m_Machine)
9595
RESTORE_BUILT_IN(STATE_TAG, m_State)
9696
} while (traverser.next());
97+
if (mapping.size() > 0) {
98+
auto mapped = mapping.find(m_State);
99+
if (mapped != mapping.end()) {
100+
m_State = mapped->second;
101+
} else {
102+
LOG_ERROR(<< "Bad mapping '" << core::CContainerPrinter::print(mapping)
103+
<< "' state = " << m_State);
104+
return false;
105+
}
106+
}
97107
return true;
98108
}
99109

100110
void CStateMachine::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
101-
inserter.insertValue(MACHINE_TAG, m_Machine);
102111
inserter.insertValue(STATE_TAG, m_State);
103112
}
104113

@@ -201,14 +210,15 @@ void CStateMachine::CMachineDeque::capacity(std::size_t capacity) {
201210
m_Capacity = capacity;
202211
}
203212

204-
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos) const {
213+
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos_) const {
214+
std::size_t pos{pos_};
205215
for (const auto& machines : m_Machines) {
206216
if (pos < machines.size()) {
207217
return machines[pos];
208218
}
209219
pos -= machines.size();
210220
}
211-
LOG_ABORT(<< "Invalid index '" << pos << "'");
221+
LOG_ABORT(<< "Invalid index '" << pos_ << "'");
212222
}
213223

214224
std::size_t CStateMachine::CMachineDeque::size() const {

lib/core/unittest/CStateMachineTest.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ void CStateMachineTest::testPersist() {
173173
core::CRapidXmlStateRestoreTraverser traverser(parser);
174174

175175
core::CStateMachine restored = core::CStateMachine::create(
176-
machine[1].s_Alphabet, machine[1].s_States, machine[1].s_TransitionFunction,
176+
machine[0].s_Alphabet, machine[0].s_States, machine[0].s_TransitionFunction,
177177
0); // initial state
178-
traverser.traverseSubLevel(
179-
boost::bind(&core::CStateMachine::acceptRestoreTraverser, &restored, _1));
178+
traverser.traverseSubLevel([&restored](core::CStateRestoreTraverser& traverser_) {
179+
return restored.acceptRestoreTraverser(traverser_);
180+
});
180181

181182
CPPUNIT_ASSERT_EQUAL(original.checksum(), restored.checksum());
182183
std::string newXml;

lib/maths/CTimeSeriesDecompositionDetail.cc

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
#include <algorithm>
4848
#include <cmath>
49+
#include <map>
4950
#include <numeric>
5051
#include <string>
5152
#include <vector>
@@ -61,6 +62,7 @@ using TBoolVec = std::vector<bool>;
6162
using TDoubleVec = std::vector<double>;
6263
using TSizeVec = std::vector<std::size_t>;
6364
using TSizeVecVec = std::vector<TSizeVec>;
65+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
6466
using TStrVec = std::vector<std::string>;
6567
using TTimeVec = std::vector<core_t::TTime>;
6668
using TTimeTimePr = std::pair<core_t::TTime, core_t::TTime>;
@@ -319,7 +321,7 @@ const std::string LAST_UPDATE_OLD_TAG{"j"};
319321

320322
//////////////////////// Upgrade to Version 6.3 ////////////////////////
321323

322-
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3{48.0};
324+
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3{48.0};
323325

324326
bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
325327
CTrendComponent& trend,
@@ -342,7 +344,7 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
342344

343345
// Generate some samples from the old trend model.
344346

345-
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3 *
347+
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3 *
346348
static_cast<double>(bucketLength) / static_cast<double>(4 * WEEK)};
347349

348350
CPRNG::CXorOShiro128Plus rng;
@@ -355,6 +357,18 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
355357
return true;
356358
}
357359

360+
// This implements the mapping from restored states to their best
361+
// equivalents; specifically:
362+
// SC_NEW_COMPONENTS |-> SC_NEW_COMPONENTS
363+
// SC_NORMAL |-> SC_NORMAL
364+
// SC_FORECASTING |-> SC_NORMAL
365+
// SC_DISABLED |-> SC_DISABLED
366+
// SC_ERROR |-> SC_ERROR
367+
// Note that we don't try and restore the periodicity test state
368+
// (see CTimeSeriesDecomposition::acceptRestoreTraverser) and the
369+
// calendar test state is unchanged.
370+
const TSizeSizeMap SC_STATES_UPGRADING_TO_VERSION_6_3{{0, 0}, {1, 1}, {2, 1}, {3, 2}, {4, 3}};
371+
358372
////////////////////////////////////////////////////////////////////////
359373

360374
// Constants
@@ -490,8 +504,9 @@ bool CTimeSeriesDecompositionDetail::CPeriodicityTest::acceptRestoreTraverser(
490504
do {
491505
const std::string& name{traverser.name()};
492506
RESTORE(PERIODICITY_TEST_MACHINE_6_3_TAG,
493-
traverser.traverseSubLevel(boost::bind(
494-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
507+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
508+
return m_Machine.acceptRestoreTraverser(traverser_);
509+
}))
495510
RESTORE_SETUP_TEARDOWN(
496511
SHORT_WINDOW_6_3_TAG, m_Windows[E_Short].reset(this->newWindow(E_Short)),
497512
m_Windows[E_Short] && traverser.traverseSubLevel(boost::bind(
@@ -792,8 +807,9 @@ bool CTimeSeriesDecompositionDetail::CCalendarTest::acceptRestoreTraverser(core:
792807
do {
793808
const std::string& name{traverser.name()};
794809
RESTORE(CALENDAR_TEST_MACHINE_6_3_TAG,
795-
traverser.traverseSubLevel(boost::bind(
796-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
810+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
811+
return m_Machine.acceptRestoreTraverser(traverser_);
812+
}))
797813
RESTORE_BUILT_IN(LAST_MONTH_6_3_TAG, m_LastMonth);
798814
RESTORE_SETUP_TEARDOWN(
799815
CALENDAR_TEST_6_3_TAG,
@@ -999,8 +1015,9 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
9991015
while (traverser.next()) {
10001016
const std::string& name{traverser.name()};
10011017
RESTORE(COMPONENTS_MACHINE_6_3_TAG,
1002-
traverser.traverseSubLevel(boost::bind(
1003-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
1018+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
1019+
return m_Machine.acceptRestoreTraverser(traverser_);
1020+
}))
10041021
RESTORE_BUILT_IN(DECAY_RATE_6_3_TAG, m_DecayRate);
10051022
RESTORE(GAIN_CONTROLLER_6_3_TAG,
10061023
traverser.traverseSubLevel(boost::bind(&CGainController::acceptRestoreTraverser,
@@ -1035,8 +1052,10 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
10351052
do {
10361053
const std::string& name{traverser.name()};
10371054
RESTORE(COMPONENTS_MACHINE_OLD_TAG,
1038-
traverser.traverseSubLevel(boost::bind(
1039-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
1055+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
1056+
return m_Machine.acceptRestoreTraverser(
1057+
traverser_, SC_STATES_UPGRADING_TO_VERSION_6_3);
1058+
}))
10401059
RESTORE_SETUP_TEARDOWN(TREND_OLD_TAG,
10411060
/**/,
10421061
traverser.traverseSubLevel(boost::bind(
@@ -1057,7 +1076,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
10571076
/**/)
10581077
} while (traverser.next());
10591078

1060-
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3);
1079+
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3);
10611080
}
10621081
return true;
10631082
}
@@ -1951,13 +1970,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CSeasonal::acceptRestoreTraver
19511970
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
19521971
m_Components.emplace_back(decayRate, bucketLength, traverser))
19531972
}
1973+
m_PredictionErrors.resize(m_Components.size());
19541974
} else {
19551975
// There is no version string this is historic state.
19561976
do {
19571977
const std::string& name{traverser.name()};
19581978
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
19591979
m_Components.emplace_back(decayRate, bucketLength, traverser))
19601980
} while (traverser.next());
1981+
m_PredictionErrors.resize(m_Components.size());
19611982
}
19621983
return true;
19631984
}
@@ -2253,13 +2274,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CCalendar::acceptRestoreTraver
22532274
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
22542275
m_Components.emplace_back(decayRate, bucketLength, traverser))
22552276
}
2277+
m_PredictionErrors.resize(m_Components.size());
22562278
} else {
22572279
// There is no version string this is historic state.
22582280
do {
22592281
const std::string& name{traverser.name()};
22602282
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
22612283
m_Components.emplace_back(decayRate, bucketLength, traverser))
22622284
} while (traverser.next());
2285+
m_PredictionErrors.resize(m_Components.size());
22632286
}
22642287
return true;
22652288
}

lib/maths/unittest/CTimeSeriesDecompositionTest.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,6 +2051,8 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
20512051
// Check we can validly upgrade existing state.
20522052

20532053
using TStrVec = std::vector<std::string>;
2054+
using TDouble3Vec = core::CSmallVector<double, 3>;
2055+
20542056
auto load = [](const std::string& name, std::string& result) {
20552057
std::ifstream file;
20562058
file.open(name);
@@ -2126,6 +2128,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
21262128
CPPUNIT_ASSERT_DOUBLES_EQUAL(expectedScale.second, scale.second,
21272129
0.005 * std::max(expectedScale.second, 0.4));
21282130
}
2131+
2132+
// Check some basic operations on the upgraded model.
2133+
decomposition.forecast(60480000, 60480000 + WEEK, HALF_HOUR, 90.0, 1.0,
2134+
[](core_t::TTime, const TDouble3Vec&) {});
2135+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2136+
decomposition.addPoint(time, 10.0);
2137+
}
21292138
}
21302139

21312140
LOG_DEBUG(<< "*** Trend and Seasonal Components ***");
@@ -2201,6 +2210,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
22012210
LOG_DEBUG(<< "Mean scale error = " << maths::CBasicStatistics::mean(meanScaleError));
22022211
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanValueError) < 0.06);
22032212
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanScaleError) < 0.07);
2213+
2214+
// Check some basic operations on the upgraded model.
2215+
decomposition.forecast(10366200, 10366200 + WEEK, HALF_HOUR, 90.0, 1.0,
2216+
[](core_t::TTime, const TDouble3Vec&) {});
2217+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2218+
decomposition.addPoint(time, 10.0);
2219+
}
22042220
}
22052221
}
22062222

0 commit comments

Comments
 (0)