 
 #include <maths/CBasicStatistics.h>
 #include <maths/CDataFramePredictiveModel.h>
+#include <maths/CSampling.h>
 #include <maths/CTools.h>
+#include <maths/CToolsDetail.h>
 #include <maths/CTreeShapFeatureImportance.h>
 
 #include <api/CDataFrameAnalyzer.h>
+#include <api/CDataFrameTrainBoostedTreeRunner.h>
 
 #include <test/CDataFrameAnalysisSpecificationFactory.h>
 #include <test/CRandomNumbers.h>
@@ -27,12 +30,14 @@ using namespace ml;
 
 namespace {
 using TDoubleVec = std::vector<double>;
+using TVector = maths::CDenseVector<double>;
 using TStrVec = std::vector<std::string>;
 using TRowItr = core::CDataFrame::TRowItr;
 using TRowRef = core::CDataFrame::TRowRef;
 using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<double>::TAccumulator;
 using TMeanAccumulatorVec = std::vector<TMeanAccumulator>;
 using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
+using TMemoryMappedMatrix = maths::CMemoryMappedDenseMatrix<double>;
 
 void setupLinearRegressionData(const TStrVec& fieldNames,
                                TStrVec& fieldValues,
@@ -128,6 +133,47 @@ void setupBinaryClassificationData(const TStrVec& fieldNames,
     }
 }
 
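+// Generates a multi-class classification data set whose labels follow a linear
+// softmax model: for each feature row x the class probabilities are
+// softmax(W x), with the rows of W given by the supplied weights, and the
+// label is drawn from the resulting categorical distribution.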
+void setupMultiClassClassificationData(const TStrVec& fieldNames,
+                                       TStrVec& fieldValues,
+                                       api::CDataFrameAnalyzer& analyzer,
+                                       const TDoubleVec& weights,
+                                       const TDoubleVec& values) {
+    TStrVec classes{"foo", "bar", "baz"};
+    maths::CPRNG::CXorOShiro128Plus rng;
+    std::uniform_real_distribution<double> u01;
+    int numberFeatures{static_cast<int>(weights.size())};
+    TDoubleVec w{weights};
+    int numberClasses{static_cast<int>(classes.size())};
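+    // Class probabilities are the softmax of the linear response W * x.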
+    auto probability = [&](const TDoubleVec& row) {
+        TMemoryMappedMatrix W(&w[0], numberClasses, numberFeatures);
+        TVector x(numberFeatures);
+        for (int i = 0; i < numberFeatures; ++i) {
+            x(i) = row[i];
+        }
+        TVector logit{W * x};
+        return maths::CTools::softmax(std::move(logit));
+    };
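+    // The target class is sampled from the categorical distribution defined
+    // by these probabilities.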
+    auto target = [&](const TDoubleVec& row) {
+        TDoubleVec probabilities{probability(row).to<TDoubleVec>()};
+        return classes[maths::CSampling::categoricalSample(rng, probabilities)];
+    };
+
+    for (std::size_t i = 0; i < values.size(); i += weights.size()) {
+        TDoubleVec row(weights.size());
+        for (std::size_t j = 0; j < weights.size(); ++j) {
+            row[j] = values[i + j];
+        }
+
+        fieldValues[0] = target(row);
+        for (std::size_t j = 0; j < row.size(); ++j) {
+            fieldValues[j + 1] = core::CStringUtils::typeToStringPrecise(
+                row[j], core::CIEEE754::E_DoublePrecision);
+        }
+
+        analyzer.handleRecord(fieldNames, fieldValues);
+    }
+}
+
 struct SFixture {
     rapidjson::Document
     runRegression(std::size_t shapValues, TDoubleVec weights, double noiseVar = 0.0) {
@@ -231,6 +277,57 @@ struct SFixture {
         return results;
     }
 
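+    // Runs a multi-class classification analysis with SHAP feature importance
+    // enabled and returns the parsed JSON results.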
+    rapidjson::Document runMultiClassClassification(std::size_t shapValues,
+                                                    TDoubleVec&& weights) {
+        auto outputWriterFactory = [&]() {
+            return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
+        };
+        test::CDataFrameAnalysisSpecificationFactory specFactory;
+        api::CDataFrameAnalyzer analyzer{
+            specFactory.rows(s_Rows)
+                .memoryLimit(26000000)
+                .predictionCategoricalFieldNames({"target"})
+                .predictionAlpha(s_Alpha)
+                .predictionLambda(s_Lambda)
+                .predictionGamma(s_Gamma)
+                .predictionSoftTreeDepthLimit(s_SoftTreeDepthLimit)
+                .predictionSoftTreeDepthTolerance(s_SoftTreeDepthTolerance)
+                .predictionEta(s_Eta)
+                .predictionMaximumNumberTrees(s_MaximumNumberTrees)
+                .predictionFeatureBagFraction(s_FeatureBagFraction)
+                .predictionNumberTopShapValues(shapValues)
+                .numberClasses(3)
+                .numberTopClasses(3)
+                .predictionSpec(test::CDataFrameAnalysisSpecificationFactory::classification(), "target"),
+            outputWriterFactory};
+        TStrVec fieldNames{"target", "c1", "c2", "c3", "c4", ".", "."};
+        TStrVec fieldValues{"", "", "", "", "", "0", ""};
+        test::CRandomNumbers rng;
+
+        TDoubleVec values;
+        rng.generateUniformSamples(-10.0, 10.0, weights.size() * s_Rows, values);
+
+        setupMultiClassClassificationData(fieldNames, fieldValues, analyzer, weights, values);
+
+        analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});
+
+        LOG_DEBUG(<< "estimated memory usage = "
+                  << core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
+        LOG_DEBUG(<< "peak memory = "
+                  << core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage));
+        LOG_DEBUG(<< "time to train = " << core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain)
+                  << "ms");
+
+        BOOST_TEST_REQUIRE(
+            core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) <
+            core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
+
+        rapidjson::Document results;
+        rapidjson::ParseResult ok(results.Parse(s_Output.str()));
+        BOOST_TEST_REQUIRE(static_cast<bool>(ok) == true);
+        return results;
+    }
+
     rapidjson::Document runRegressionWithMissingFeatures(std::size_t shapValues) {
         auto outputWriterFactory = [&]() {
             return std::make_unique<core::CJsonOutputStreamWrapper>(s_Output);
@@ -289,9 +386,48 @@ struct SFixture {
 
 template<typename RESULTS>
 double readShapValue(const RESULTS& results, std::string shapField) {
-    shapField = maths::CTreeShapFeatureImportance::SHAP_PREFIX + shapField;
-    if (results["row_results"]["results"]["ml"].HasMember(shapField)) {
-        return results["row_results"]["results"]["ml"][shapField].GetDouble();
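+    // Feature importance is now reported as an array of per-feature objects;
+    // look up the entry whose feature name matches shapField.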
+    if (results["row_results"]["results"]["ml"].HasMember(
+            api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
+        for (const auto& shapResult :
+             results["row_results"]["results"]["ml"][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
+                 .GetArray()) {
+            if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
+                    .GetString() == shapField) {
+                return shapResult[api::CDataFrameTrainBoostedTreeRunner::IMPORTANCE_FIELD_NAME]
+                    .GetDouble();
+            }
+        }
+    }
+    return 0.0;
+}
+
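+// Reads the SHAP value of shapField for a single class className, returning
+// 0.0 if it is not present.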
+template<typename RESULTS>
+double readShapValue(const RESULTS& results, std::string shapField, std::string className) {
+    if (results["row_results"]["results"]["ml"].HasMember(
+            api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME)) {
+        for (const auto& shapResult :
+             results["row_results"]["results"]["ml"][api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME]
+                 .GetArray()) {
+            if (shapResult[api::CDataFrameTrainBoostedTreeRunner::FEATURE_NAME_FIELD_NAME]
+                    .GetString() == shapField) {
+                if (shapResult.HasMember(className)) {
+                    return shapResult[className].GetDouble();
+                }
+            }
+        }
+    }
+    return 0.0;
+}
+
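+// Reads the predicted probability of className from the top_classes array,
+// returning 0.0 if it is not present.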
+template<typename RESULTS>
+double readClassProbability(const RESULTS& results, std::string className) {
+    if (results["row_results"]["results"]["ml"].HasMember("top_classes")) {
+        for (const auto& topClasses :
+             results["row_results"]["results"]["ml"]["top_classes"].GetArray()) {
+            if (topClasses["class_name"].GetString() == className) {
+                return topClasses["class_probability"].GetDouble();
+            }
+        }
     }
     return 0.0;
 }
@@ -324,9 +460,7 @@ BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceAllShap, SFixture) {
             c3Sum += std::fabs(c3);
             c4Sum += std::fabs(c4);
             // assert that no SHAP value for the dependent variable is returned
-            BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
-                                   maths::CTreeShapFeatureImportance::SHAP_PREFIX +
-                                   "target") == false);
+            BOOST_REQUIRE_EQUAL(readShapValue(result, "target"), 0.0);
         }
     }
 
@@ -421,25 +555,58 @@ BOOST_FIXTURE_TEST_CASE(testClassificationFeatureImportanceAllShap, SFixture) {
     BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
 }
 
+BOOST_FIXTURE_TEST_CASE(testMultiClassClassificationFeatureImportanceAllShap, SFixture) {
+
+    std::size_t topShapValues{4};
+    auto results{runMultiClassClassification(topShapValues, {0.5, -0.7, 0.2, -0.2})};
+
+    for (const auto& result : results.GetArray()) {
+        if (result.HasMember("row_results")) {
+            double c1Sum{readShapValue(result, "c1")};
+            double c2Sum{readShapValue(result, "c2")};
+            double c3Sum{readShapValue(result, "c3")};
+            double c4Sum{readShapValue(result, "c4")};
+            // We should have at least one feature that is important.
+            BOOST_TEST_REQUIRE((c1Sum > 0.0 || c2Sum > 0.0 || c3Sum > 0.0 || c4Sum > 0.0));
+
+            // The per-class SHAP values should sum in absolute value to the overall feature importance.
+            double c1f{readShapValue(result, "c1", "foo")};
+            double c1bar{readShapValue(result, "c1", "bar")};
+            double c1baz{readShapValue(result, "c1", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c1Sum, std::abs(c1f) + std::abs(c1bar) + std::abs(c1baz), 1e-6);
+
+            double c2f{readShapValue(result, "c2", "foo")};
+            double c2bar{readShapValue(result, "c2", "bar")};
+            double c2baz{readShapValue(result, "c2", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c2Sum, std::abs(c2f) + std::abs(c2bar) + std::abs(c2baz), 1e-6);
+
+            double c3f{readShapValue(result, "c3", "foo")};
+            double c3bar{readShapValue(result, "c3", "bar")};
+            double c3baz{readShapValue(result, "c3", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c3Sum, std::abs(c3f) + std::abs(c3bar) + std::abs(c3baz), 1e-6);
+
+            double c4f{readShapValue(result, "c4", "foo")};
+            double c4bar{readShapValue(result, "c4", "bar")};
+            double c4baz{readShapValue(result, "c4", "baz")};
+            BOOST_REQUIRE_CLOSE(
+                c4Sum, std::abs(c4f) + std::abs(c4bar) + std::abs(c4baz), 1e-6);
+        }
+    }
+}
+
 BOOST_FIXTURE_TEST_CASE(testRegressionFeatureImportanceNoShap, SFixture) {
     // Test that if topShapValues is set to 0, no feature importance values are returned.
     std::size_t topShapValues{0};
     auto results{runRegression(topShapValues, {50.0, 150.0, 50.0, -50.0})};
 
     for (const auto& result : results.GetArray()) {
         if (result.HasMember("row_results")) {
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c1") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c2") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c3") == false);
-            BOOST_TEST_REQUIRE(
-                result["row_results"]["results"]["ml"].HasMember(
-                    maths::CTreeShapFeatureImportance::SHAP_PREFIX + "c4") == false);
+            BOOST_TEST_REQUIRE(result["row_results"]["results"]["ml"].HasMember(
+                                   api::CDataFrameTrainBoostedTreeRunner::FEATURE_IMPORTANCE_FIELD_NAME) ==
+                               false);
         }
     }
 }