Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/api/CInferenceModelMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace api {
//! (such as total feature importance) into JSON format.
class API_EXPORT CInferenceModelMetadata {
public:
static const std::string JSON_BASELINE_TAG;
static const std::string JSON_FEATURE_IMPORTANCE_BASELINE_TAG;
static const std::string JSON_CLASS_NAME_TAG;
static const std::string JSON_CLASSES_TAG;
static const std::string JSON_FEATURE_NAME_TAG;
Expand Down Expand Up @@ -48,20 +50,26 @@ class API_EXPORT CInferenceModelMetadata {
//! Add importances \p values to the feature with index \p i to calculate total feature importance.
//! Total feature importance is the mean of the magnitudes of importances for individual data points.
void addToFeatureImportance(std::size_t i, const TVector& values);
//! Set the feature importance baseline (the individual feature importances are additive corrections
//! to the baseline value).
void featureImportanceBaseline(TVector&& baseline);

private:
using TMeanAccumulator =
std::vector<maths::CBasicStatistics::SSampleMean<double>::TAccumulator>;
using TMinMaxAccumulator = std::vector<maths::CBasicStatistics::CMinMax<double>>;
using TSizeMeanAccumulatorUMap = std::unordered_map<std::size_t, TMeanAccumulator>;
using TSizeMinMaxAccumulatorUMap = std::unordered_map<std::size_t, TMinMaxAccumulator>;
using TOptionalVector = boost::optional<TVector>;

private:
void writeTotalFeatureImportance(TRapidJsonWriter& writer) const;
void writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const;

private:
TSizeMeanAccumulatorUMap m_TotalShapValuesMean;
TSizeMinMaxAccumulatorUMap m_TotalShapValuesMinMax;
TOptionalVector m_ShapBaseline;
TStrVec m_ColumnNames;
TStrVec m_ClassValues;
TPredictionFieldTypeResolverWriter m_PredictionFieldTypeResolverWriter =
Expand Down
2 changes: 1 addition & 1 deletion include/maths/CTreeShapFeatureImportance.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
const TStrVec& columnNames() const;

//! Get the baseline.
double baseline(std::size_t classIdx = 0) const;
TVector baseline() const;

private:
//! Collects the elements of the path through decision tree that are updated together
Expand Down
113 changes: 52 additions & 61 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
[this](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
this->writePredictedCategoryValue(categoryValue, writer);
});
featureImportance->shap(row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
TDoubleVec baseline;
baseline.reserve(numberClasses);
for (std::size_t j = 0; j < shap[0].size() && j < numberClasses; ++j) {
baseline.push_back(featureImportance->baseline(j));
}
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureNames[i]);
if (shap[i].size() == 1) {
// output feature importance for individual classes in binary case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0; j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
if (j == 1) {
writer.Double(shap[i](0));
} else {
writer.Double(-shap[i](0));
featureImportance->shap(
row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
const TStrVec& featureNames,
const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
writer.Key(FEATURE_IMPORTANCE_FIELD_NAME);
writer.StartArray();
for (auto i : indices) {
if (shap[i].norm() != 0.0) {
writer.StartObject();
writer.Key(FEATURE_NAME_FIELD_NAME);
writer.String(featureNames[i]);
if (shap[i].size() == 1) {
// output feature importance for individual classes in binary case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0; j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
if (j == 1) {
writer.Double(shap[i](0));
} else {
writer.Double(-shap[i](0));
}
writer.EndObject();
}
writer.EndObject();
}
writer.EndArray();
} else {
// output feature importance for individual classes in multiclass case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
TDoubleVec featureImportanceSum(numberClasses, 0.0);
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
for (auto k : indices) {
featureImportanceSum[j] += shap[k](j);
writer.EndArray();
} else {
// output feature importance for individual classes in multiclass case
writer.Key(CLASSES_FIELD_NAME);
writer.StartArray();
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
writer.Double(shap[i](j));
writer.EndObject();
}
writer.EndArray();
}
for (std::size_t j = 0;
j < shap[i].size() && j < numberClasses; ++j) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writePredictedCategoryValue(classValues[j], writer);
writer.Key(IMPORTANCE_FIELD_NAME);
double correctedShap{
shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0)};
writer.Double(correctedShap);
writer.EndObject();
}
writer.EndArray();
writer.EndObject();
}
writer.EndObject();
}
}
writer.EndArray();
writer.EndArray();

for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
for (std::size_t i = 0; i < shap.size(); ++i) {
if (shap[i].lpNorm<1>() != 0) {
const_cast<CDataFrameTrainBoostedTreeClassifierRunner*>(this)
->m_InferenceModelMetadata.addToFeatureImportance(i, shap[i]);
}
}
}
});
});
}
writer.EndObject();
}
Expand Down Expand Up @@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(

//! Return the model metadata, refreshing the SHAP baseline from the feature
//! importance calculator (if feature importance was computed) beforehand.
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata() const {
// shap() returns an optional-like handle; it is empty when feature
// importance was not requested for this analysis.
const auto& featureImportance = this->boostedTree().shap();
if (featureImportance) {
// NOTE(review): this mutates m_InferenceModelMetadata inside a const
// member function — presumably the member is declared mutable (the
// sibling classifier code uses const_cast for the same purpose);
// confirm in the class header.
m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
}
return m_InferenceModelMetadata;
}

Expand Down
6 changes: 5 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelDefinition(

//! Return the model metadata, refreshing the SHAP baseline from the feature
//! importance calculator (if feature importance was computed) beforehand.
CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
CDataFrameTrainBoostedTreeRegressionRunner::inferenceModelMetadata() const {
    // BUG FIX: the stale `return TOptionalInferenceModelMetadata(...)` that
    // preceded this code made the baseline update below unreachable; it has
    // been removed so the baseline is populated before the metadata is
    // handed out.
    const auto& featureImportance = this->boostedTree().shap();
    if (featureImportance) {
        // NOTE(review): mutating the member in a const function presumably
        // relies on it being declared mutable — confirm in the header.
        m_InferenceModelMetadata.featureImportanceBaseline(featureImportance->baseline());
    }
    return m_InferenceModelMetadata;
}

// clang-format off
Expand Down
54 changes: 54 additions & 0 deletions lib/api/CInferenceModelMetadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace api {

//! Serialize the collected metadata fields into \p writer.
//! The call order fixes the key order of the emitted JSON object, so the
//! two writes below must not be reordered.
void CInferenceModelMetadata::write(TRapidJsonWriter& writer) const {
this->writeTotalFeatureImportance(writer);
this->writeFeatureImportanceBaseline(writer);
}

void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writer) const {
Expand Down Expand Up @@ -88,6 +89,53 @@ void CInferenceModelMetadata::writeTotalFeatureImportance(TRapidJsonWriter& writ
writer.EndArray();
}

//! Write the feature importance baseline object into \p writer.
//! The layout depends on the problem type: a single scalar for regression,
//! or a per-class array for (binary or multiclass) classification.
void CInferenceModelMetadata::writeFeatureImportanceBaseline(TRapidJsonWriter& writer) const {
    // Nothing to emit until a baseline has been supplied.
    if (!m_ShapBaseline) {
        return;
    }
    const auto& baseline = *m_ShapBaseline;
    writer.Key(JSON_FEATURE_IMPORTANCE_BASELINE_TAG);
    writer.StartObject();
    if (baseline.size() == 1 && m_ClassValues.empty()) {
        // Regression: a single scalar baseline value.
        writer.Key(JSON_BASELINE_TAG);
        writer.Double(baseline(0));
    } else if (baseline.size() == 1) {
        // Binary classification: one stored value; class 1 carries it as-is
        // and class 0 carries the negated value.
        writer.Key(JSON_CLASSES_TAG);
        writer.StartArray();
        for (std::size_t i = 0; i < m_ClassValues.size(); ++i) {
            writer.StartObject();
            writer.Key(JSON_CLASS_NAME_TAG);
            m_PredictionFieldTypeResolverWriter(m_ClassValues[i], writer);
            writer.Key(JSON_BASELINE_TAG);
            writer.Double(i == 1 ? baseline(0) : -baseline(0));
            writer.EndObject();
        }
        writer.EndArray();
    } else {
        // Multiclass classification: one baseline entry per class, bounded by
        // both the baseline dimension and the number of known class values.
        writer.Key(JSON_CLASSES_TAG);
        writer.StartArray();
        for (std::size_t i = 0; i < m_ClassValues.size() &&
                                i < static_cast<std::size_t>(baseline.size());
             ++i) {
            writer.StartObject();
            writer.Key(JSON_CLASS_NAME_TAG);
            m_PredictionFieldTypeResolverWriter(m_ClassValues[i], writer);
            writer.Key(JSON_BASELINE_TAG);
            writer.Double(baseline(i));
            writer.EndObject();
        }
        writer.EndArray();
    }
    writer.EndObject();
}

//! Get the identifier under which this metadata is serialized (the outer
//! JSON object key).
const std::string& CInferenceModelMetadata::typeString() const {
return JSON_MODEL_METADATA_TAG;
}
Expand Down Expand Up @@ -119,7 +167,13 @@ void CInferenceModelMetadata::addToFeatureImportance(std::size_t i, const TVecto
}
}

//! Set the feature importance baseline (the individual feature importances
//! are additive corrections to this value).
void CInferenceModelMetadata::featureImportanceBaseline(TVector&& baseline) {
    // BUG FIX: the parameter is an rvalue reference, but a named rvalue
    // reference is itself an lvalue, so the previous plain assignment copied
    // the vector. std::move restores the intended move into the member.
    m_ShapBaseline = std::move(baseline);
}

// clang-format off
const std::string CInferenceModelMetadata::JSON_BASELINE_TAG{"baseline"};
const std::string CInferenceModelMetadata::JSON_FEATURE_IMPORTANCE_BASELINE_TAG{"feature_importance_baseline"};
const std::string CInferenceModelMetadata::JSON_CLASS_NAME_TAG{"class_name"};
const std::string CInferenceModelMetadata::JSON_CLASSES_TAG{"classes"};
const std::string CInferenceModelMetadata::JSON_FEATURE_NAME_TAG{"feature_name"};
Expand Down
Loading