Skip to content

Commit d1e43e4

Browse files
authored
[ML] Store feature importance baselines in model metadata (#1524)
With this PR we will be able to store the feature importance baselines explicitly in the model_metadata. Being able to retrieve the baselines will significantly simplify the UI code related to the feature importance visualization.
1 parent 2ae1f3e commit d1e43e4

7 files changed

+179
-123
lines changed

include/api/CInferenceModelMetadata.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace api {
2121
//! (such as total feature importance) into JSON format.
2222
class API_EXPORT CInferenceModelMetadata {
2323
public:
24+
static const std::string JSON_BASELINE_TAG;
25+
static const std::string JSON_FEATURE_IMPORTANCE_BASELINE_TAG;
2426
static const std::string JSON_CLASS_NAME_TAG;
2527
static const std::string JSON_CLASSES_TAG;
2628
static const std::string JSON_FEATURE_NAME_TAG;
@@ -48,20 +50,26 @@ class API_EXPORT CInferenceModelMetadata {
4850
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
4951
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
5052
void addToFeatureImportance(std::size_t i, const TVector& values);
53+
//! Set the feature importance baseline (the individual feature importances are additive corrections
54+
//! to the baseline value).
55+
void featureImportanceBaseline(TVector&& baseline);
5156

5257
private:
5358
using TMeanAccumulator =
5459
std::vector<maths::CBasicStatistics::SSampleMean<double>::TAccumulator>;
5560
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
5661
using TSizeMeanAccumulatorUMap = std::unordered_map<std::size_t, TMeanAccumulator>;
5762
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
63+
using TOptionalVector = boost::optional<TVector>;
5864

5965
private:
6066
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
67+
void writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const;
6168

6269
private:
6370
TSizeMeanAccumulatorUMap m_TotalShapValuesMean;
6471
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
72+
TOptionalVector m_ShapBaseline;
6573
TStrVec m_ColumnNames;
6674
TStrVec m_ClassValues;
6775
TPredictionFieldTypeResolverWriter m_PredictionFieldTypeResolverWriter =

include/maths/CTreeShapFeatureImportance.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
7777
const TStrVec& columnNames() const;
7878

7979
//! Get the baseline.
80-
double baseline(std::size_t classIdx = 0) const;
80+
TVector baseline() const;
8181

8282
private:
8383
//! Collects the elements of the path through decision tree that are updated together

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
169169
[this](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
170170
this->writePredictedCategoryValue(categoryValue, writer);
171171
});
172-
featureImportance->shap(row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
173-
const TStrVec& featureNames,
174-
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
175-
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
176-
writer.StartArray();
177-
TDoubleVec baseline;
178-
baseline.reserve(numberClasses);
179-
for (std::size_t j = 0; j < shap[0].size() && j < numberClasses; ++j) {
180-
baseline.push_back(featureImportance->baseline(j));
181-
}
182-
for (auto i : indices) {
183-
if (shap[i].norm() != 0.0) {
184-
writer.StartObject();
185-
writer.Key(FEATURE_NAME_FIELD_NAME);
186-
writer.String(featureNames[i]);
187-
if (shap[i].size() == 1) {
188-
// output feature importance for individual classes in binary case
189-
writer.Key(CLASSES_FIELD_NAME);
190-
writer.StartArray();
191-
for (std::size_t j = 0; j < numberClasses; ++j) {
192-
writer.StartObject();
193-
writer.Key(CLASS_NAME_FIELD_NAME);
194-
writePredictedCategoryValue(classValues[j], writer);
195-
writer.Key(IMPORTANCE_FIELD_NAME);
196-
if (j == 1) {
197-
writer.Double(shap[i](0));
198-
} else {
199-
writer.Double(-shap[i](0));
172+
featureImportance->shap(
173+
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
174+
const TStrVec& featureNames,
175+
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
176+
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
177+
writer.StartArray();
178+
for (auto i : indices) {
179+
if (shap[i].norm() != 0.0) {
180+
writer.StartObject();
181+
writer.Key(FEATURE_NAME_FIELD_NAME);
182+
writer.String(featureNames[i]);
183+
if (shap[i].size() == 1) {
184+
// output feature importance for individual classes in binary case
185+
writer.Key(CLASSES_FIELD_NAME);
186+
writer.StartArray();
187+
for (std::size_t j = 0; j < numberClasses; ++j) {
188+
writer.StartObject();
189+
writer.Key(CLASS_NAME_FIELD_NAME);
190+
writePredictedCategoryValue(classValues[j], writer);
191+
writer.Key(IMPORTANCE_FIELD_NAME);
192+
if (j == 1) {
193+
writer.Double(shap[i](0));
194+
} else {
195+
writer.Double(-shap[i](0));
196+
}
197+
writer.EndObject();
200198
}
201-
writer.EndObject();
202-
}
203-
writer.EndArray();
204-
} else {
205-
// output feature importance for individual classes in multiclass case
206-
writer.Key(CLASSES_FIELD_NAME);
207-
writer.StartArray();
208-
TDoubleVec featureImportanceSum(numberClasses, 0.0);
209-
for (std::size_t j = 0;
210-
j < shap[i].size() && j < numberClasses; ++j) {
211-
for (auto k : indices) {
212-
featureImportanceSum[j] += shap[k](j);
199+
writer.EndArray();
200+
} else {
201+
// output feature importance for individual classes in multiclass case
202+
writer.Key(CLASSES_FIELD_NAME);
203+
writer.StartArray();
204+
for (std::size_t j = 0;
205+
j < shap[i].size() && j < numberClasses; ++j) {
206+
writer.StartObject();
207+
writer.Key(CLASS_NAME_FIELD_NAME);
208+
writePredictedCategoryValue(classValues[j], writer);
209+
writer.Key(IMPORTANCE_FIELD_NAME);
210+
writer.Double(shap[i](j));
211+
writer.EndObject();
213212
}
213+
writer.EndArray();
214214
}
215-
for (std::size_t j = 0;
216-
j < shap[i].size() && j < numberClasses; ++j) {
217-
writer.StartObject();
218-
writer.Key(CLASS_NAME_FIELD_NAME);
219-
writePredictedCategoryValue(classValues[j], writer);
220-
writer.Key(IMPORTANCE_FIELD_NAME);
221-
double correctedShap{
222-
shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0)};
223-
writer.Double(correctedShap);
224-
writer.EndObject();
225-
}
226-
writer.EndArray();
215+
writer.EndObject();
227216
}
228-
writer.EndObject();
229217
}
230-
}
231-
writer.EndArray();
218+
writer.EndArray();
232219

233-
for (std::size_t i = 0; i < shap.size(); ++i) {
234-
if (shap[i].lpNorm<1>() != 0) {
235-
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
236-
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
220+
for (std::size_t i = 0; i < shap.size(); ++i) {
221+
if (shap[i].lpNorm<1>() != 0) {
222+
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
223+
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
224+
}
237225
}
238-
}
239-
});
226+
});
240227
}
241228
writer.EndObject();
242229
}
@@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
306293

307294
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
308295
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
296+
const auto& featureImportance = this->boostedTree().shap();
297+
if (featureImportance) {
298+
m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
299+
}
309300
return m_InferenceModelMetadata;
310301
}
311302

lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(
155155

156156
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
157157
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
158-
return TOptionalInferenceModelMetadata(m_InferenceModelMetadata);
158+
const auto& featureImportance = this->boostedTree().shap();
159+
if (featureImportance) {
160+
m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
161+
}
162+
return m_InferenceModelMetadata;
159163
}
160164

161165
// clang-format off

lib/api/CInferenceModelMetadata.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace api {
1212

1313
void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const {
1414
this->writeTotalFeatureImportance(writer);
15+
this->writeFeatureImportanceBaseline(writer);
1516
}
1617

1718
void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const {
@@ -88,6 +89,53 @@ void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writ
8889
writer.EndArray();
8990
}
9091

92+
void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const {
93+
if (m_ShapBaseline) {
94+
writer.Key(JSON_FEATURE_IMPORTANCE_BASELINE_TAG);
95+
writer.StartObject();
96+
if (m_ShapBaseline->size() == 1 && m_ClassValues.empty()) {
97+
// Regression
98+
writer.Key(JSON_BASELINE_TAG);
99+
writer.Double(m_ShapBaseline.get()(0));
100+
} else if (m_ShapBaseline->size() == 1 && m_ClassValues.empty() == false) {
101+
// Binary classification
102+
writer.Key(JSON_CLASSES_TAG);
103+
writer.StartArray();
104+
for (std::size_t j = 0; j < m_ClassValues.size(); ++j) {
105+
writer.StartObject();
106+
writer.Key(JSON_CLASS_NAME_TAG);
107+
m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
108+
writer.Key(JSON_BASELINE_TAG);
109+
if (j == 1) {
110+
writer.Double(m_ShapBaseline.get()(0));
111+
} else {
112+
writer.Double(-m_ShapBaseline.get()(0));
113+
}
114+
writer.EndObject();
115+
}
116+
117+
writer.EndArray();
118+
119+
} else {
120+
// Multiclass classification
121+
writer.Key(JSON_CLASSES_TAG);
122+
writer.StartArray();
123+
for (std::size_t j = 0; j < static_cast<std::size_t>(m_ShapBaseline->size()) &&
124+
j < m_ClassValues.size();
125+
++j) {
126+
writer.StartObject();
127+
writer.Key(JSON_CLASS_NAME_TAG);
128+
m_PredictionFieldTypeResolverWriter(m_ClassValues[j], writer);
129+
writer.Key(JSON_BASELINE_TAG);
130+
writer.Double(m_ShapBaseline.get()(j));
131+
writer.EndObject();
132+
}
133+
writer.EndArray();
134+
}
135+
writer.EndObject();
136+
}
137+
}
138+
91139
const std::string& CInferenceModelMetadata::typeString() const {
92140
return JSON_MODEL_METADATA_TAG;
93141
}
@@ -119,7 +167,13 @@ void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVecto
119167
}
120168
}
121169

170+
void CInferenceModelMetadata::featureImportanceBaseline(TVector&& baseline) {
171+
m_ShapBaseline = baseline;
172+
}
173+
122174
// clang-format off
175+
const std::string CInferenceModelMetadata::JSON_BASELINE_TAG{"baseline"};
176+
const std::string CInferenceModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG{"feature_importance_baseline"};
123177
const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"};
124178
const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"};
125179
const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"};

0 commit comments

Comments
 (0)