Skip to content

Commit 3d394c6

Browse files
author
Hendrik Muhs
authored
[ML] Fix #51 - fix lgamma calculation for x-means clustering (#126) (#131)
replacing boost::lgamma with std::lgamma and guarding the output of lgamma for finite values. fixes #51
1 parent 45a7c24 commit 3d394c6

File tree

6 files changed

+81
-7
lines changed

6 files changed

+81
-7
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ by unnecessary reference counting ({pull}108[#108])
4040
Age seasonal components in proportion to the fraction of values with which they're updated ({pull}88[#88])
4141
Persist and restore was missing some of the trend model state ({pull}#99[#99])
4242
Stop zero variance data generating a log error in the forecast confidence interval calculation ({pull}#107[#107])
43+
Fix corner case failing to calculate lgamma values and the correspoinding log errors ({pull}#126[#126])
4344

4445
=== Regressions
4546

include/maths/CTools.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,9 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
693693
//! A custom implementation of \f$\log(1 - x)\f$ which handles the
694694
//! cancellation error for small x.
695695
static double logOneMinusX(double x);
696+
697+
//! A wrapper around lgamma which handles corner cases if requested
698+
static bool lgamma(double value, double& result, bool checkForFinite = false);
696699
};
697700
}
698701
}

lib/maths/CTools.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,5 +2005,10 @@ double CTools::logOneMinusX(double x) {
20052005

20062006
return result;
20072007
}
2008+
2009+
bool CTools::lgamma(double value, double& result, bool checkForFinite) {
2010+
result = std::lgamma(value);
2011+
return checkForFinite ? std::isfinite(result) : true;
2012+
}
20082013
}
20092014
}

lib/maths/CXMeansOnline1d.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
#include <boost/math/constants/constants.hpp>
3131
#include <boost/math/distributions/normal.hpp>
3232
#include <boost/math/special_functions/digamma.hpp>
33-
#include <boost/math/special_functions/gamma.hpp>
3433

3534
#include <algorithm>
3635
#include <cmath>
@@ -285,6 +284,7 @@ void BICGain(maths_t::EDataType dataType,
285284
double mr = mean(dataType, mvr);
286285
double vr = std::max(variance(dataType, mvr), vmin);
287286

287+
bool haveGamma = distributions.haveGamma();
288288
try {
289289
// Mixture of log-normals (method of moments)
290290
double sl = std::log(1.0 + vl / CTools::pow2(ml + logNormalOffset));
@@ -299,7 +299,21 @@ void BICGain(maths_t::EDataType dataType,
299299

300300
double log2piv = std::log(boost::math::double_constants::two_pi * v);
301301
double log2pis = std::log(boost::math::double_constants::two_pi * s);
302-
double loggn = boost::math::lgamma(a) - a * std::log(b);
302+
303+
double loggn = 0.0;
304+
double loggnl = 0.0;
305+
double loggnr = 0.0;
306+
307+
if (haveGamma && maths::CTools::lgamma(a, loggn, true) &&
308+
maths::CTools::lgamma(al, loggnl, true) &&
309+
maths::CTools::lgamma(ar, loggnr, true)) {
310+
loggn -= a * std::log(b);
311+
loggnl -= al * std::log(bl) + std::log(wl);
312+
loggnr -= ar * std::log(br) + std::log(wr);
313+
} else {
314+
haveGamma = false;
315+
}
316+
303317
double log2pivl =
304318
std::log(boost::math::double_constants::two_pi * vl / CTools::pow2(wl));
305319
double log2pivr =
@@ -308,8 +322,6 @@ void BICGain(maths_t::EDataType dataType,
308322
std::log(boost::math::double_constants::two_pi * sl / CTools::pow2(wl));
309323
double log2pisr =
310324
std::log(boost::math::double_constants::two_pi * sr / CTools::pow2(wr));
311-
double loggnl = boost::math::lgamma(al) - al * std::log(bl) - std::log(wl);
312-
double loggnr = boost::math::lgamma(ar) - ar * std::log(br) - std::log(wr);
313325

314326
for (std::size_t i = start; i < split; ++i) {
315327
double ni = CBasicStatistics::count(categories[i]);
@@ -372,15 +384,15 @@ void BICGain(maths_t::EDataType dataType,
372384
double ll1 =
373385
min(distributions.haveNormal() ? ll1n : boost::numeric::bounds<double>::highest(),
374386
distributions.haveLogNormal() ? ll1l : boost::numeric::bounds<double>::highest(),
375-
distributions.haveGamma() ? ll1g : boost::numeric::bounds<double>::highest()) +
387+
haveGamma ? ll1g : boost::numeric::bounds<double>::highest()) +
376388
distributions.parameters() * logn;
377389
double ll2 =
378390
min(distributions.haveNormal() ? ll2nl : boost::numeric::bounds<double>::highest(),
379391
distributions.haveLogNormal() ? ll2ll : boost::numeric::bounds<double>::highest(),
380-
distributions.haveGamma() ? ll2gl : boost::numeric::bounds<double>::highest()) +
392+
haveGamma ? ll2gl : boost::numeric::bounds<double>::highest()) +
381393
min(distributions.haveNormal() ? ll2nr : boost::numeric::bounds<double>::highest(),
382394
distributions.haveLogNormal() ? ll2lr : boost::numeric::bounds<double>::highest(),
383-
distributions.haveGamma() ? ll2gr : boost::numeric::bounds<double>::highest()) +
395+
haveGamma ? ll2gr : boost::numeric::bounds<double>::highest()) +
384396
(2.0 * distributions.parameters() + 1.0) * logn;
385397

386398
LOG_TRACE(<< "BIC(1) = " << ll1 << ", BIC(2) = " << ll2);

lib/maths/unittest/CToolsTest.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <boost/optional.hpp>
2727
#include <boost/range.hpp>
2828

29+
#include <array>
30+
2931
using namespace ml;
3032
using namespace maths;
3133
using namespace test;
@@ -1168,6 +1170,54 @@ void CToolsTest::testMiscellaneous() {
11681170
}
11691171
}
11701172

1173+
void CToolsTest::testLgamma() {
1174+
std::array<double, 8> testData = {3.5,
1175+
0.125,
1176+
-0.125,
1177+
0.000244140625,
1178+
1.3552527156068805e-20,
1179+
4.95547e+25,
1180+
5.01753e+25,
1181+
8.64197e+25};
1182+
1183+
std::array<double, 8> expectedData = {
1184+
1.2009736023470742248160218814507129957702389154682,
1185+
2.0194183575537963453202905211670995899482809521344,
1186+
2.1653002489051702517540619481440174064962195287626,
1187+
8.3176252939431805089043336920440196990966796875000,
1188+
45.7477139169563926657247066032141447067260742187500,
1189+
2.882355039447984e+27,
1190+
2.919076782442754e+27,
1191+
5.074673490557339e+27};
1192+
1193+
for (std::size_t i = 0u; i < testData.size(); ++i) {
1194+
double actual;
1195+
double expected = expectedData[i];
1196+
CPPUNIT_ASSERT(maths::CTools::lgamma(testData[i], actual, true));
1197+
CPPUNIT_ASSERT_DOUBLES_EQUAL(expected, actual, 1e-5 * expected);
1198+
}
1199+
1200+
double result;
1201+
CPPUNIT_ASSERT(maths::CTools::lgamma(0, result));
1202+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1203+
1204+
CPPUNIT_ASSERT((maths::CTools::lgamma(0, result, true) == false));
1205+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1206+
1207+
CPPUNIT_ASSERT((maths::CTools::lgamma(-1, result)));
1208+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1209+
1210+
CPPUNIT_ASSERT((maths::CTools::lgamma(-1, result, true) == false));
1211+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1212+
1213+
CPPUNIT_ASSERT((maths::CTools::lgamma(std::numeric_limits<double>::max() - 1, result)));
1214+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1215+
1216+
CPPUNIT_ASSERT((maths::CTools::lgamma(std::numeric_limits<double>::max() - 1,
1217+
result, true) == false));
1218+
CPPUNIT_ASSERT_EQUAL(result, std::numeric_limits<double>::infinity());
1219+
}
1220+
11711221
CppUnit::Test* CToolsTest::suite() {
11721222
CppUnit::TestSuite* suiteOfTests = new CppUnit::TestSuite("CToolsTest");
11731223

@@ -1187,6 +1237,8 @@ CppUnit::Test* CToolsTest::suite() {
11871237
"CToolsTest::testFastLog", &CToolsTest::testFastLog));
11881238
suiteOfTests->addTest(new CppUnit::TestCaller<CToolsTest>(
11891239
"CToolsTest::testMiscellaneous", &CToolsTest::testMiscellaneous));
1240+
suiteOfTests->addTest(new CppUnit::TestCaller<CToolsTest>(
1241+
"CToolsTest::testLgamma", &CToolsTest::testLgamma));
11901242

11911243
return suiteOfTests;
11921244
}

lib/maths/unittest/CToolsTest.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class CToolsTest : public CppUnit::TestFixture {
1818
void testSpread();
1919
void testFastLog();
2020
void testMiscellaneous();
21+
void testLgamma();
2122

2223
static CppUnit::Test* suite();
2324
};

0 commit comments

Comments
 (0)