
Conversation

@mgaido91
Contributor

What changes were proposed in this pull request?

Our feature importance calculation is taken from sklearn's one, which has been recently fixed (in scikit-learn/scikit-learn#11176). Citing the description of that PR:

Because the feature importances are (currently, by default) normalized and then averaged, feature importances from later stages are overweighted.

The PR performs a fix similar to sklearn's: the per-tree normalization of the feature importances is skipped, and for GBT the importance vector is normalized only once, at the end.
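To see why normalize-then-average overweights later trees, here is a minimal standalone Scala sketch (the toy numbers and helper names are illustrative assumptions, not Spark's actual implementation):

```scala
// Two toy trees: tree 1 reduces much more impurity overall than tree 2.
val tree1 = Array(9.0, 1.0) // raw importances of features 0 and 1 in tree 1
val tree2 = Array(0.1, 0.9) // a much weaker tree, dominated by feature 1

def normalize(v: Array[Double]): Array[Double] = { val s = v.sum; v.map(_ / s) }

// Old behavior: normalize each tree, then sum, then normalize again.
// Tree 2 counts as much as tree 1 despite its tiny raw scale.
val old = normalize(
  normalize(tree1).zip(normalize(tree2)).map { case (a, b) => a + b })
// old == [0.5, 0.5]: the weak tree pulls feature 1 up to parity

// Fixed behavior: sum the raw importances, normalize only once at the end.
val fixed = normalize(tree1.zip(tree2).map { case (a, b) => a + b })
// fixed ≈ [0.83, 0.17]: each tree now contributes proportionally to its scale
```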

Credits to Daniel Jumper for clearly pointing out the issue and the sklearn PR.

How was this patch tested?

Modified the existing UT and checked that the computed `featureImportance` in that test is similar to sklearn's (it can't be identical, because the trees may be slightly different).

* Estimate of the importance of each feature.
*
* Each feature's importance is the average of its importance across all trees in the ensemble.
* The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
Member

This comment needs to be updated.

Contributor Author

No, it is still valid: the final vector is still normalized to 1.

Member

Don't you skip the normalization of the importance vector?

Member

Oh, I see. The normalization mentioned here is for the total importance.

Contributor Author

Yes, we skip the normalization of the importance vector for each tree, but at the end the final vector is still normalized. To simplify with a diagram, before the PR it was:
tree importance -> normalization -> sum -> normalization
and now it is:
tree importance -> sum -> normalization
So the final result is still normalized.
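To make the diagram concrete, here is a tiny self-contained Scala sketch (toy vectors and hypothetical helpers, not the Spark code) checking that both pipelines end normalized:

```scala
val trees = Seq(Array(9.0, 1.0), Array(0.1, 0.9)) // toy per-tree importances
def normalize(v: Array[Double]): Array[Double] = { val s = v.sum; v.map(_ / s) }
def add(a: Array[Double], b: Array[Double]): Array[Double] =
  a.zip(b).map { case (x, y) => x + y }

val beforeFix = normalize(trees.map(normalize).reduce(add)) // normalize -> sum -> normalize
val afterFix  = normalize(trees.reduce(add))                // sum -> normalize

// Both final vectors sum to 1; only the weight each tree carried differs.
assert(math.abs(beforeFix.sum - 1.0) < 1e-9)
assert(math.abs(afterFix.sum - 1.0) < 1e-9)
```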

@SparkQA

SparkQA commented Feb 13, 2019

Test build #102296 has finished for PR 23773 at commit 283f093.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Feb 13, 2019

Test build #102297 has finished for PR 23773 at commit 678277a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@mgaido91
Contributor Author

@viirya any more comments? cc @srowen

val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
val mostIF = importanceFeatures.argmax
-assert(mostImportantFeature !== mostIF)
+assert(mostIF === 1)
Member

Previously the two most important features were different. Why are they both 1 now?

Contributor Author
@mgaido91 Feb 15, 2019

Not sure about the exact reason why they were different earlier (of course the behavior changed because of the fix, but this is expected). You can compare the importances vector with the one returned by sklearn: as I mentioned in the PR description, they are very similar (so sklearn too says 1 is the most important feature in both scenarios).

PS: please note that the sklearn version must be >= 0.20.0.

Member

Yeah I should have commented on this; I actually don't know why the test previously asserted the answers must be different. That's actually the thing I'd least expect, though it's possible. Why does it still assert the importances are different? I suspect they won't match exactly, sure, but if there's an assertion here, isn't it that they're close? They may just not be that comparable in which case there's nothing to assert.

Contributor Author

The assertion is there to check that a different subset strategy actually produces different results. In particular, in the first case, the importances vector is [1.0, 0.0, ...], while in the second case more features are used (because the trees can only check one random feature at a time), so the vector is something like [0.7, ...]. Hence this assertion makes sense in order to check that the featureSubset strategy is properly taken into account.

Member

OK, I get it, we just expect something different to happen under the hood, even if we're largely expecting a similar or the same answer. Leave it in; if it failed because it exactly matched, we'd know it, and could easily figure out whether that's actually now expected or a bug.

Member

> In particular, in the first case, the importances vector is [1.0, 0.0, ...], while in the second case more features are used (because the trees can only check one random feature at a time), so the vector is something like [0.7, ...].

Doesn't the second case use just one feature and the first case use all features? What do you mean by more features being used in the second case? Or did I misread the test code?

Contributor Author

In the first case, every tree can choose among all features. Since feature 1 basically is the correct "label" (I mean, they are the same), in the first case all the trees choose feature 1 in the first node and they get 100% accuracy. Hence the importance vector is [1.0, 0.0, ...]. In the second case, only one random feature at a time can be considered. So the trees are more "diverse" and they also consider other features. So the importance vector is the one I mentioned above. You can maybe try to debug this UT if you want to understand better (probably that is more effective than my poor English), or you can try to run the same in sklearn.
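For illustration, here is a hypothetical reconstruction of this scenario (the dataset, the names like `df`, and the parameter values are assumptions, not the actual test code):

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors

val rng = new scala.util.Random(0)
val data = Seq.tabulate(100) { i =>
  val label = (i % 2).toDouble
  // feature 1 is an exact copy of the label; features 0 and 2 are noise
  (label, Vectors.dense(rng.nextDouble(), label, rng.nextDouble()))
}
val df = spark.createDataFrame(data).toDF("label", "features") // assumes an active SparkSession `spark`

// "all": every tree may split on feature 1, so importance concentrates there.
// "1": one random candidate feature per split, so other features get used too.
val gbtAll = new GBTClassifier().setMaxIter(5).setSeed(42L).setFeatureSubsetStrategy("all")
val gbtOne = new GBTClassifier().setMaxIter(5).setSeed(42L).setFeatureSubsetStrategy("1")

println(gbtAll.fit(df).featureImportances) // roughly all mass on feature 1
println(gbtOne.fit(df).featureImportances) // mass spread across several features
```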

Member
@viirya Feb 16, 2019

Thanks @mgaido91.

I don't have a working laptop these days, so it is hard for me to run the unit test. That is why I am asking for more details.

It sounds like the assertion `assert(importances(mostImportantFeature) !== importanceFeatures(mostIF))` makes sense. But for `assert(mostIF === 1)`: because it picks one random feature at a time, are we sure that the most important feature is 1 in all cases? In an extreme case, this feature might not be chosen at all. It is potentially flaky. This assertion doesn't make too much sense to me; maybe we don't need it.

Contributor Author

Well, the seed is fixed, so the UT is actually deterministic and there is no flakiness. Although with a different seed the result may differ, I'd consider it very unlikely that 1 would not be the most important feature in any case, since it is really the ground truth here.

Member

Yea, it is correct since there is a fixed seed.

Anyway, `assert(mostIF === 1)` actually means `assert(mostImportantFeature === mostIF)`. This assertion doesn't make much more sense than the previous one, `assert(mostImportantFeature !== mostIF)`. It doesn't tell us much, except that the most important feature happens to be the same...

OK for me to leave it as is.

@srowen
Member

srowen commented Feb 16, 2019

Merged to master

@srowen srowen closed this in 5d8a934 Feb 16, 2019
jackylee-ch pushed a commit to jackylee-ch/spark that referenced this pull request Feb 18, 2019
…or GBT


Closes apache#23773 from mgaido91/SPARK-26721.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Sean Owen <[email protected]>