From 68e2e34d4eac509c01ccb3bdb1e272bde423e674 Mon Sep 17 00:00:00 2001 From: Valeriy Khakhutskyy <1292899+valeriy42@users.noreply.github.com> Date: Fri, 14 Aug 2020 19:56:57 +0200 Subject: [PATCH] [ML] Activate model metadata output (#1456) Activate the output of the model metadata and the corresponding unit tests for total feature importance. The implementation itself was introduced in #1387 however, I need to fix the documentation, it was originally attributed to v7.10. Hence, I mark this PR as enhancement to rectify the docs. --- docs/CHANGELOG.asciidoc | 5 +- lib/api/CDataFrameAnalyzer.cc | 3 +- ...CDataFrameAnalyzerFeatureImportanceTest.cc | 107 +++++++++--------- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 3402977f95..3b4ba28927 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -30,6 +30,10 @@ == {es} version 7.10.0 +=== Enhancements + +* Calculate total feature importance to store with model metadata. (See {ml-pull}1387[#1387].) + === Bug Fixes * Fix progress on resume after final training has completed for classification and regression. @@ -64,7 +68,6 @@ regression. (See {ml-pull}1340[#1340].) * Improvement in handling large inference model definitions. (See {ml-pull}1349[#1349].) * Add a peak_model_bytes field to model_size_stats. (See {ml-pull}1389[#1389].) -* Calculate total feature importance as a new result type. (See {ml-pull}1387[#1387].) === Bug Fixes diff --git a/lib/api/CDataFrameAnalyzer.cc b/lib/api/CDataFrameAnalyzer.cc index 90fe79858e..2c17a7f12a 100644 --- a/lib/api/CDataFrameAnalyzer.cc +++ b/lib/api/CDataFrameAnalyzer.cc @@ -144,8 +144,7 @@ void CDataFrameAnalyzer::run() { analysisRunner->waitToFinish(); this->writeInferenceModel(*analysisRunner, outputWriter); this->writeResultsOf(*analysisRunner, outputWriter); - // TODO reactivate once Java parsing is ready - // this->writeInferenceModelMetadata(*analysisRunner, outputWriter); + this->writeInferenceModelMetadata(*analysisRunner, outputWriter); } } diff --git a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc index 31e98379fe..a95ac3f0ac 100644 --- a/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc +++ b/lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc @@ -537,16 +537,15 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) { BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 5.0); // c3 and c4 within 5% of each other // make sure the local approximation differs from the prediction always by the same bias (up to a numeric error) BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6); - // TODO reactivate once Java parsing is ready - // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); - // BOOST_REQUIRE_CLOSE(c1TotalShapActual, - // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2TotalShapActual, - // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3TotalShapActual, - // maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4TotalShapActual, - // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); + BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + BOOST_REQUIRE_CLOSE(c1TotalShapActual, + maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2TotalShapActual, + maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3TotalShapActual, + maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4TotalShapActual, + maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); } BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoImportance, SFixture) { @@ -629,7 +628,6 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) { if (result["model_metadata"].HasMember("total_feature_importance")) { hasTotalFeatureImportance = true; } - // TODO reactivate once Java parsing is ready c1FooTotalShapActual = readTotalShapValue(result, "c1", "foo"); c2FooTotalShapActual = readTotalShapValue(result, "c2", "foo"); c3FooTotalShapActual = readTotalShapValue(result, "c3", "foo"); @@ -650,24 +648,23 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) { BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 40.0); // c3 and c4 within 40% of each other // make sure the local approximation differs from the prediction always by the same bias (up to a numeric error) BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6); - // TODO reactivate once Java parsing is ready - // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); - // BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, - // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, - // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, - // maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, - // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, - // maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, - // maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, - // maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, - // maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); + BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, + maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, + maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, + maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, + maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, + maths::CBasicStatistics::mean(c1TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, + maths::CBasicStatistics::mean(c2TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, + maths::CBasicStatistics::mean(c3TotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, + maths::CBasicStatistics::mean(c4TotalShapExpected), 1.0); } BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) { @@ -734,7 +731,6 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF if (result["model_metadata"].HasMember("total_feature_importance")) { hasTotalFeatureImportance = true; } - // TODO reactivate once Java parsing is ready c1FooTotalShapActual = readTotalShapValue(result, "c1", "foo"); c2FooTotalShapActual = readTotalShapValue(result, "c2", "foo"); c3FooTotalShapActual = readTotalShapValue(result, "c3", "foo"); @@ -749,32 +745,31 @@ BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SF c4BazTotalShapActual = readTotalShapValue(result, "c4", "baz"); } } - // TODO reactivate once Java parsing is ready - // BOOST_TEST_REQUIRE(hasTotalFeatureImportance); - // BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, - // maths::CBasicStatistics::mean(c1FooTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, - // maths::CBasicStatistics::mean(c2FooTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, - // maths::CBasicStatistics::mean(c3FooTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, - // maths::CBasicStatistics::mean(c4FooTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, - // maths::CBasicStatistics::mean(c1BarTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, - // maths::CBasicStatistics::mean(c2BarTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, - // maths::CBasicStatistics::mean(c3BarTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, - // maths::CBasicStatistics::mean(c4BarTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c1BazTotalShapActual, - // maths::CBasicStatistics::mean(c1BazTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c2BazTotalShapActual, - // maths::CBasicStatistics::mean(c2BazTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c3BazTotalShapActual, - // maths::CBasicStatistics::mean(c3BazTotalShapExpected), 1.0); - // BOOST_REQUIRE_CLOSE(c4BazTotalShapActual, - // maths::CBasicStatistics::mean(c4BazTotalShapExpected), 1.0); + BOOST_TEST_REQUIRE(hasTotalFeatureImportance); + BOOST_REQUIRE_CLOSE(c1FooTotalShapActual, + maths::CBasicStatistics::mean(c1FooTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2FooTotalShapActual, + maths::CBasicStatistics::mean(c2FooTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3FooTotalShapActual, + maths::CBasicStatistics::mean(c3FooTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4FooTotalShapActual, + maths::CBasicStatistics::mean(c4FooTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c1BarTotalShapActual, + maths::CBasicStatistics::mean(c1BarTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2BarTotalShapActual, + maths::CBasicStatistics::mean(c2BarTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3BarTotalShapActual, + maths::CBasicStatistics::mean(c3BarTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4BarTotalShapActual, + maths::CBasicStatistics::mean(c4BarTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c1BazTotalShapActual, + maths::CBasicStatistics::mean(c1BazTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c2BazTotalShapActual, + maths::CBasicStatistics::mean(c2BazTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c3BazTotalShapActual, + maths::CBasicStatistics::mean(c3BazTotalShapExpected), 1.0); + BOOST_REQUIRE_CLOSE(c4BazTotalShapActual, + maths::CBasicStatistics::mean(c4BazTotalShapExpected), 1.0); } BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {