@@ -169,74 +169,61 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
169169 [this ](const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) {
170170 this ->writePredictedCategoryValue (categoryValue, writer);
171171 });
172- featureImportance->shap (row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
173- const TStrVec& featureNames,
174- const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
175- writer.Key (FEATURE_IMPORTANCE_FIELD_NAME);
176- writer.StartArray ();
177- TDoubleVec baseline;
178- baseline.reserve (numberClasses);
179- for (std::size_t j = 0 ; j < shap[0 ].size () && j < numberClasses; ++j) {
180- baseline.push_back (featureImportance->baseline (j));
181- }
182- for (auto i : indices) {
183- if (shap[i].norm () != 0.0 ) {
184- writer.StartObject ();
185- writer.Key (FEATURE_NAME_FIELD_NAME);
186- writer.String (featureNames[i]);
187- if (shap[i].size () == 1 ) {
188- // output feature importance for individual classes in binary case
189- writer.Key (CLASSES_FIELD_NAME);
190- writer.StartArray ();
191- for (std::size_t j = 0 ; j < numberClasses; ++j) {
192- writer.StartObject ();
193- writer.Key (CLASS_NAME_FIELD_NAME);
194- writePredictedCategoryValue (classValues[j], writer);
195- writer.Key (IMPORTANCE_FIELD_NAME);
196- if (j == 1 ) {
197- writer.Double (shap[i](0 ));
198- } else {
199- writer.Double (-shap[i](0 ));
172+ featureImportance->shap (
173+ row, [&](const maths::CTreeShapFeatureImportance::TSizeVec& indices,
174+ const TStrVec& featureNames,
175+ const maths::CTreeShapFeatureImportance::TVectorVec& shap) {
176+ writer.Key (FEATURE_IMPORTANCE_FIELD_NAME);
177+ writer.StartArray ();
178+ for (auto i : indices) {
179+ if (shap[i].norm () != 0.0 ) {
180+ writer.StartObject ();
181+ writer.Key (FEATURE_NAME_FIELD_NAME);
182+ writer.String (featureNames[i]);
183+ if (shap[i].size () == 1 ) {
184+ // output feature importance for individual classes in binary case
185+ writer.Key (CLASSES_FIELD_NAME);
186+ writer.StartArray ();
187+ for (std::size_t j = 0 ; j < numberClasses; ++j) {
188+ writer.StartObject ();
189+ writer.Key (CLASS_NAME_FIELD_NAME);
190+ writePredictedCategoryValue (classValues[j], writer);
191+ writer.Key (IMPORTANCE_FIELD_NAME);
192+ if (j == 1 ) {
193+ writer.Double (shap[i](0 ));
194+ } else {
195+ writer.Double (-shap[i](0 ));
196+ }
197+ writer.EndObject ();
200198 }
201- writer.EndObject ();
202- }
203- writer.EndArray ();
204- } else {
205- // output feature importance for individual classes in multiclass case
206- writer.Key (CLASSES_FIELD_NAME);
207- writer.StartArray ();
208- TDoubleVec featureImportanceSum (numberClasses, 0.0 );
209- for (std::size_t j = 0 ;
210- j < shap[i].size () && j < numberClasses; ++j) {
211- for (auto k : indices) {
212- featureImportanceSum[j] += shap[k](j);
199+ writer.EndArray ();
200+ } else {
201+ // output feature importance for individual classes in multiclass case
202+ writer.Key (CLASSES_FIELD_NAME);
203+ writer.StartArray ();
204+ for (std::size_t j = 0 ;
205+ j < shap[i].size () && j < numberClasses; ++j) {
206+ writer.StartObject ();
207+ writer.Key (CLASS_NAME_FIELD_NAME);
208+ writePredictedCategoryValue (classValues[j], writer);
209+ writer.Key (IMPORTANCE_FIELD_NAME);
210+ writer.Double (shap[i](j));
211+ writer.EndObject ();
213212 }
213+ writer.EndArray ();
214214 }
215- for (std::size_t j = 0 ;
216- j < shap[i].size () && j < numberClasses; ++j) {
217- writer.StartObject ();
218- writer.Key (CLASS_NAME_FIELD_NAME);
219- writePredictedCategoryValue (classValues[j], writer);
220- writer.Key (IMPORTANCE_FIELD_NAME);
221- double correctedShap{
222- shap[i](j) * (baseline[j] / featureImportanceSum[j] + 1.0 )};
223- writer.Double (correctedShap);
224- writer.EndObject ();
225- }
226- writer.EndArray ();
215+ writer.EndObject ();
227216 }
228- writer.EndObject ();
229217 }
230- }
231- writer.EndArray ();
218+ writer.EndArray ();
232219
233- for (std::size_t i = 0 ; i < shap.size (); ++i) {
234- if (shap[i].lpNorm <1 >() != 0 ) {
235- const_cast <CDataFrameTrainBoostedTreeClassifierRunner*>(this )
236- ->m_InferenceModelMetadata .addToFeatureImportance (i, shap[i]);
220+ for (std::size_t i = 0 ; i < shap.size (); ++i) {
221+ if (shap[i].lpNorm <1 >() != 0 ) {
222+ const_cast <CDataFrameTrainBoostedTreeClassifierRunner*>(this )
223+ ->m_InferenceModelMetadata .addToFeatureImportance (i, shap[i]);
224+ }
237225 }
238- }
239- });
226+ });
240227 }
241228 writer.EndObject ();
242229}
@@ -306,6 +293,10 @@ CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelDefinition(
306293
// Accessor for this runner's inference model metadata.
//
// Before returning m_InferenceModelMetadata, the function asks the boosted
// tree for its shap() object and, if that object tests true, copies its
// baseline() into the metadata via featureImportanceBaseline(). This refresh
// is repeated on every call; nothing here caches or guards against it.
//
// NOTE(review): this is a const member function that calls what looks like a
// setter on m_InferenceModelMetadata. Presumably the member is declared
// mutable (or the called method is const) - confirm against the class header.
307294CDataFrameAnalysisRunner::TOptionalInferenceModelMetadata
308295CDataFrameTrainBoostedTreeClassifierRunner::inferenceModelMetadata () const {
296+ const auto & featureImportance = this ->boostedTree ().shap ();
// shap() can evaluate to false; in that case the baseline is left untouched.
297+ if (featureImportance) {
298+ m_InferenceModelMetadata.featureImportanceBaseline (featureImportance->baseline ());
299+ }
309300 return m_InferenceModelMetadata;
310301}
311302
0 commit comments