Skip to content

Commit a341300

Browse files
committed
Fix issues upgrading state leading to SIGSEGV
1 parent e16816e commit a341300

File tree

5 files changed

+64
-21
lines changed

5 files changed

+64
-21
lines changed

include/core/CStateMachine.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <atomic>
1616
#include <cstddef>
1717
#include <list>
18+
#include <map>
1819
#include <vector>
1920

2021
namespace ml {
@@ -67,6 +68,7 @@ class CORE_EXPORT CStateMachine {
6768
using TSizeVec = std::vector<std::size_t>;
6869
using TSizeVecVec = std::vector<TSizeVec>;
6970
using TStrVec = std::vector<std::string>;
71+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
7072

7173
public:
7274
//! Set the number of machines we expect the program to use.
@@ -85,7 +87,8 @@ class CORE_EXPORT CStateMachine {
8587
//! \name Persistence
8688
//@{
8789
//! Initialize by reading state from \p traverser.
88-
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);
90+
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser,
91+
const TSizeSizeMap& mapping = TSizeSizeMap());
8992

9093
//! Persist state by passing information to the supplied inserter.
9194
void acceptPersistInserter(CStatePersistInserter& inserter) const;

lib/core/CStateMachine.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace core {
2525
namespace {
2626

2727
// CStateMachine
28-
const std::string MACHINE_TAG("a");
28+
//const std::string MACHINE_TAG("a"); No longer used
2929
const std::string STATE_TAG("b");
3030

3131
// CStateMachine::SMachine
@@ -88,17 +88,26 @@ CStateMachine CStateMachine::create(const TStrVec& alphabet,
8888
return result;
8989
}
9090

91-
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
91+
bool CStateMachine::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser,
92+
const TSizeSizeMap& mapping) {
9293
do {
9394
const std::string& name = traverser.name();
94-
RESTORE_BUILT_IN(MACHINE_TAG, m_Machine)
9595
RESTORE_BUILT_IN(STATE_TAG, m_State)
9696
} while (traverser.next());
97+
if (mapping.size() > 0) {
98+
auto mapped = mapping.find(m_State);
99+
if (mapped != mapping.end()) {
100+
m_State = mapped->second;
101+
} else {
102+
LOG_ERROR(<< "Bad mapping '" << core::CContainerPrinter::print(mapping)
103+
<< "' state = " << m_State);
104+
return false;
105+
}
106+
}
97107
return true;
98108
}
99109

100110
void CStateMachine::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
101-
inserter.insertValue(MACHINE_TAG, m_Machine);
102111
inserter.insertValue(STATE_TAG, m_State);
103112
}
104113

@@ -201,14 +210,15 @@ void CStateMachine::CMachineDeque::capacity(std::size_t capacity) {
201210
m_Capacity = capacity;
202211
}
203212

204-
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos) const {
213+
const CStateMachine::SMachine& CStateMachine::CMachineDeque::operator[](std::size_t pos_) const {
214+
std::size_t pos{pos_};
205215
for (const auto& machines : m_Machines) {
206216
if (pos < machines.size()) {
207217
return machines[pos];
208218
}
209219
pos -= machines.size();
210220
}
211-
LOG_ABORT(<< "Invalid index '" << pos << "'");
221+
LOG_ABORT(<< "Invalid index '" << pos_ << "'");
212222
}
213223

214224
std::size_t CStateMachine::CMachineDeque::size() const {

lib/core/unittest/CStateMachineTest.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ void CStateMachineTest::testPersist() {
173173
core::CRapidXmlStateRestoreTraverser traverser(parser);
174174

175175
core::CStateMachine restored = core::CStateMachine::create(
176-
machine[1].s_Alphabet, machine[1].s_States, machine[1].s_TransitionFunction,
176+
machine[0].s_Alphabet, machine[0].s_States, machine[0].s_TransitionFunction,
177177
0); // initial state
178-
traverser.traverseSubLevel(
179-
boost::bind(&core::CStateMachine::acceptRestoreTraverser, &restored, _1));
178+
traverser.traverseSubLevel([&restored](core::CStateRestoreTraverser& traverser_) {
179+
return restored.acceptRestoreTraverser(traverser_);
180+
});
180181

181182
CPPUNIT_ASSERT_EQUAL(original.checksum(), restored.checksum());
182183
std::string newXml;

lib/maths/CTimeSeriesDecompositionDetail.cc

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
#include <algorithm>
4848
#include <cmath>
49+
#include <map>
4950
#include <numeric>
5051
#include <string>
5152
#include <vector>
@@ -61,6 +62,7 @@ using TBoolVec = std::vector<bool>;
6162
using TDoubleVec = std::vector<double>;
6263
using TSizeVec = std::vector<std::size_t>;
6364
using TSizeVecVec = std::vector<TSizeVec>;
65+
using TSizeSizeMap = std::map<std::size_t, std::size_t>;
6466
using TStrVec = std::vector<std::string>;
6567
using TTimeVec = std::vector<core_t::TTime>;
6668
using TTimeTimePr = std::pair<core_t::TTime, core_t::TTime>;
@@ -319,7 +321,7 @@ const std::string LAST_UPDATE_OLD_TAG{"j"};
319321

320322
//////////////////////// Upgrade to Version 6.3 ////////////////////////
321323

322-
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3{48.0};
324+
const double MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3{48.0};
323325

324326
bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
325327
CTrendComponent& trend,
@@ -342,7 +344,7 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
342344

343345
// Generate some samples from the old trend model.
344346

345-
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3 *
347+
double weight{MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3 *
346348
static_cast<double>(bucketLength) / static_cast<double>(4 * WEEK)};
347349

348350
CPRNG::CXorOShiro128Plus rng;
@@ -355,6 +357,8 @@ bool upgradeTrendModelToVersion6p3(const core_t::TTime bucketLength,
355357
return true;
356358
}
357359

360+
const TSizeSizeMap SC_STATES_UPGRADING_TO_VERSION_6_3{{0, 0}, {1, 1}, {2, 1}, {3, 2}, {4, 3}};
361+
358362
////////////////////////////////////////////////////////////////////////
359363

360364
// Constants
@@ -490,8 +494,9 @@ bool CTimeSeriesDecompositionDetail::CPeriodicityTest::acceptRestoreTraverser(
490494
do {
491495
const std::string& name{traverser.name()};
492496
RESTORE(PERIODICITY_TEST_MACHINE_6_3_TAG,
493-
traverser.traverseSubLevel(boost::bind(
494-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
497+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
498+
return m_Machine.acceptRestoreTraverser(traverser_);
499+
}))
495500
RESTORE_SETUP_TEARDOWN(
496501
SHORT_WINDOW_6_3_TAG, m_Windows[E_Short].reset(this->newWindow(E_Short)),
497502
m_Windows[E_Short] && traverser.traverseSubLevel(boost::bind(
@@ -792,8 +797,9 @@ bool CTimeSeriesDecompositionDetail::CCalendarTest::acceptRestoreTraverser(core:
792797
do {
793798
const std::string& name{traverser.name()};
794799
RESTORE(CALENDAR_TEST_MACHINE_6_3_TAG,
795-
traverser.traverseSubLevel(boost::bind(
796-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)))
800+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
801+
return m_Machine.acceptRestoreTraverser(traverser_);
802+
}))
797803
RESTORE_BUILT_IN(LAST_MONTH_6_3_TAG, m_LastMonth);
798804
RESTORE_SETUP_TEARDOWN(
799805
CALENDAR_TEST_6_3_TAG,
@@ -999,8 +1005,9 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
9991005
while (traverser.next()) {
10001006
const std::string& name{traverser.name()};
10011007
RESTORE(COMPONENTS_MACHINE_6_3_TAG,
1002-
traverser.traverseSubLevel(boost::bind(
1003-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
1008+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
1009+
return m_Machine.acceptRestoreTraverser(traverser_);
1010+
}))
10041011
RESTORE_BUILT_IN(DECAY_RATE_6_3_TAG, m_DecayRate);
10051012
RESTORE(GAIN_CONTROLLER_6_3_TAG,
10061013
traverser.traverseSubLevel(boost::bind(&CGainController::acceptRestoreTraverser,
@@ -1035,8 +1042,10 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
10351042
do {
10361043
const std::string& name{traverser.name()};
10371044
RESTORE(COMPONENTS_MACHINE_OLD_TAG,
1038-
traverser.traverseSubLevel(boost::bind(
1039-
&core::CStateMachine::acceptRestoreTraverser, &m_Machine, _1)));
1045+
traverser.traverseSubLevel([this](core::CStateRestoreTraverser& traverser_) {
1046+
return m_Machine.acceptRestoreTraverser(
1047+
traverser_, SC_STATES_UPGRADING_TO_VERSION_6_3);
1048+
}))
10401049
RESTORE_SETUP_TEARDOWN(TREND_OLD_TAG,
10411050
/**/,
10421051
traverser.traverseSubLevel(boost::bind(
@@ -1057,7 +1066,7 @@ bool CTimeSeriesDecompositionDetail::CComponents::acceptRestoreTraverser(
10571066
/**/)
10581067
} while (traverser.next());
10591068

1060-
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6p3);
1069+
m_MeanVarianceScale.add(1.0, MODEL_WEIGHT_UPGRADING_TO_VERSION_6_3);
10611070
}
10621071
return true;
10631072
}
@@ -1951,13 +1960,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CSeasonal::acceptRestoreTraver
19511960
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
19521961
m_Components.emplace_back(decayRate, bucketLength, traverser))
19531962
}
1963+
m_PredictionErrors.resize(m_Components.size());
19541964
} else {
19551965
// There is no version string this is historic state.
19561966
do {
19571967
const std::string& name{traverser.name()};
19581968
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
19591969
m_Components.emplace_back(decayRate, bucketLength, traverser))
19601970
} while (traverser.next());
1971+
m_PredictionErrors.resize(m_Components.size());
19611972
}
19621973
return true;
19631974
}
@@ -2253,13 +2264,15 @@ bool CTimeSeriesDecompositionDetail::CComponents::CCalendar::acceptRestoreTraver
22532264
RESTORE_NO_ERROR(COMPONENT_6_3_TAG,
22542265
m_Components.emplace_back(decayRate, bucketLength, traverser))
22552266
}
2267+
m_PredictionErrors.resize(m_Components.size());
22562268
} else {
22572269
// There is no version string this is historic state.
22582270
do {
22592271
const std::string& name{traverser.name()};
22602272
RESTORE_NO_ERROR(COMPONENT_OLD_TAG,
22612273
m_Components.emplace_back(decayRate, bucketLength, traverser))
22622274
} while (traverser.next());
2275+
m_PredictionErrors.resize(m_Components.size());
22632276
}
22642277
return true;
22652278
}

lib/maths/unittest/CTimeSeriesDecompositionTest.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,6 +2051,8 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
20512051
// Check we can validly upgrade existing state.
20522052

20532053
using TStrVec = std::vector<std::string>;
2054+
using TDouble3Vec = core::CSmallVector<double, 3>;
2055+
20542056
auto load = [](const std::string& name, std::string& result) {
20552057
std::ifstream file;
20562058
file.open(name);
@@ -2126,6 +2128,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
21262128
CPPUNIT_ASSERT_DOUBLES_EQUAL(expectedScale.second, scale.second,
21272129
0.005 * std::max(expectedScale.second, 0.4));
21282130
}
2131+
2132+
// Check some basic operations on the upgraded model.
2133+
decomposition.forecast(60480000, 60480000 + WEEK, HALF_HOUR, 90.0, 1.0,
2134+
[](core_t::TTime, const TDouble3Vec&) {});
2135+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2136+
decomposition.addPoint(time, 10.0);
2137+
}
21292138
}
21302139

21312140
LOG_DEBUG(<< "*** Trend and Seasonal Components ***");
@@ -2201,6 +2210,13 @@ void CTimeSeriesDecompositionTest::testUpgrade() {
22012210
LOG_DEBUG(<< "Mean scale error = " << maths::CBasicStatistics::mean(meanScaleError));
22022211
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanValueError) < 0.06);
22032212
CPPUNIT_ASSERT(maths::CBasicStatistics::mean(meanScaleError) < 0.07);
2213+
2214+
// Check some basic operations on the upgraded model.
2215+
decomposition.forecast(10366200, 10366200 + WEEK, HALF_HOUR, 90.0, 1.0,
2216+
[](core_t::TTime, const TDouble3Vec&) {});
2217+
for (core_t::TTime time = 60480000; time < 60480000 + WEEK; time += HALF_HOUR) {
2218+
decomposition.addPoint(time, 10.0);
2219+
}
22042220
}
22052221
}
22062222

0 commit comments

Comments
 (0)