From 283f0930120a4c5b9fd9b2ae0d10672081a3dff6 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 13 Feb 2019 14:59:54 +0100
Subject: [PATCH 1/3] [SPARK-26721][ML] Avoid per-tree normalization in
 featureImportance for GBT

---
 .../ml/classification/GBTClassifier.scala          |  5 +++--
 .../spark/ml/regression/GBTRegressor.scala         |  3 ++-
 .../org/apache/spark/ml/tree/treeModels.scala      | 18 +++++++++++++++---
 .../ml/classification/GBTClassifierSuite.scala     |  3 ++-
 4 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index abe2d1febfdf..a5ed4a38a886 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -341,11 +341,12 @@ class GBTClassificationModel private[ml](
    * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
    * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
    * and follows the implementation from scikit-learn.
-
+   *
    * See `DecisionTreeClassificationModel.featureImportances`
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector =
+    TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false)
 
   /** Raw prediction for the positive class. */
   private def margin(features: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 9a5b7d59e9ae..9f0f567a5b53 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -285,7 +285,8 @@ class GBTRegressionModel private[ml](
    * @see `DecisionTreeRegressionModel.featureImportances`
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector =
+    TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 51d5d5c58c57..a416c566d37f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -147,7 +147,10 @@ private[ml] object TreeEnsembleModel {
    *                    If -1, then numFeatures is set based on the max feature index in all trees.
    * @return Feature importance values, of length numFeatures.
    */
-  def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
+  def featureImportances[M <: DecisionTreeModel](
+      trees: Array[M],
+      numFeatures: Int,
+      perTreeNormalization: Boolean = true): Vector = {
     val totalImportances = new OpenHashMap[Int, Double]()
     trees.foreach { tree =>
       // Aggregate feature importance vector for this tree
@@ -155,10 +158,19 @@
       computeFeatureImportance(tree.rootNode, importances)
       // Normalize importance vector for this tree, and add it to total.
       // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
-      val treeNorm = importances.map(_._2).sum
+      val treeNorm = if (perTreeNormalization) {
+        importances.map(_._2).sum
+      } else {
+        // We won't use it
+        Double.NaN
+      }
       if (treeNorm != 0) {
         importances.foreach { case (idx, impt) =>
-          val normImpt = impt / treeNorm
+          val normImpt = if (perTreeNormalization) {
+            impt / treeNorm
+          } else {
+            impt
+          }
           totalImportances.changeValue(idx, normImpt, _ + normImpt)
         }
       }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index cedbaf1858ef..cd59900c521c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -363,7 +363,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
     val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
     val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
     val mostIF = importanceFeatures.argmax
-    assert(mostImportantFeature !== mostIF)
+    assert(mostIF === 1)
+    assert(importances(mostImportantFeature) !== importanceFeatures(mostIF))
   }
 
   test("model evaluateEachIteration") {

From d1f15bf5b900645d9c130e821c541dac782517eb Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 13 Feb 2019 16:07:23 +0100
Subject: [PATCH 2/3] address comment: improve doc

---
 .../src/main/scala/org/apache/spark/ml/tree/treeModels.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index a416c566d37f..e95c55f6048f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -135,7 +135,7 @@ private[ml] object TreeEnsembleModel {
    *  - Average over trees:
    *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
    *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
+   *     - Normalize importances for tree to sum to 1 (only if `perTreeNormalization` is `true`).
    *  - Normalize feature importance vector to sum to 1.
    *
    * References:
@@ -145,6 +145,9 @@
    * @param numFeatures Number of features in model (even if not all are explicitly used by
    *                    the model).
    *                    If -1, then numFeatures is set based on the max feature index in all trees.
+   * @param perTreeNormalization By default this is set to `true` and it means that the importances
+   *                             of each tree are normalized before being summed. If set to `false`,
+   *                             the normalization is skipped.
    * @return Feature importance values, of length numFeatures.
    */
   def featureImportances[M <: DecisionTreeModel](

From 678277a803fbdd45217c4677ea71b01239384941 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 13 Feb 2019 16:07:37 +0100
Subject: [PATCH 3/3] fix UTs

---
 .../org/apache/spark/ml/regression/GBTRegressorSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index b145c7a3dc95..46fa3767efdc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -200,7 +200,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
     val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
     val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
     val mostIF = importanceFeatures.argmax
-    assert(mostImportantFeature !== mostIF)
+    assert(mostIF === 1)
+    assert(importances(mostImportantFeature) !== importanceFeatures(mostIF))
   }
 
   test("model evaluateEachIteration") {
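Note for reviewers: the sketch below is a minimal, self-contained illustration of why this change matters, not Spark code. `FeatureImportanceSketch`, `aggregate`, and the toy gain maps are hypothetical names invented for this example; the control flow (the `treeNorm` sentinel, the `treeNorm != 0` guard, the final normalization) mirrors the patched `TreeEnsembleModel.featureImportances` above, but a plain mutable `Map` stands in for Spark's `OpenHashMap`, and precomputed per-tree gain maps replace the tree traversal done by `computeFeatureImportance`.

```scala
import scala.collection.mutable

// Minimal sketch (not Spark internals): each tree contributes a map of
// featureIndex -> raw importance (sum of split gains for that feature).
object FeatureImportanceSketch {

  def aggregate(
      perTreeImportances: Seq[Map[Int, Double]],
      perTreeNormalization: Boolean): Map[Int, Double] = {
    val total = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
    perTreeImportances.foreach { importances =>
      // As in the patch: the per-tree norm is only computed when requested.
      // NaN != 0 evaluates to true, so the loop still runs in the raw mode.
      val treeNorm =
        if (perTreeNormalization) importances.values.sum else Double.NaN
      if (treeNorm != 0) {
        importances.foreach { case (idx, impt) =>
          val normImpt = if (perTreeNormalization) impt / treeNorm else impt
          total(idx) += normImpt
        }
      }
    }
    // Both modes normalize the final, ensemble-level vector to sum to 1.
    val totalSum = total.values.sum
    total.map { case (idx, impt) => idx -> impt / totalSum }.toMap
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical gains: in a GBT, later trees fit residuals, so their raw
    // gains are typically much smaller than the first tree's.
    val trees = Seq(Map(0 -> 10.0, 1 -> 2.0), Map(1 -> 0.1))
    // Old behavior: tree 2's tiny gain is rescaled to total weight 1.0, so
    // feature 1 spuriously dominates: Map(0 -> ~0.42, 1 -> ~0.58)
    println(aggregate(trees, perTreeNormalization = true))
    // New GBT behavior: raw gains are preserved across trees, so feature 0
    // dominates as expected: Map(0 -> ~0.83, 1 -> ~0.17)
    println(aggregate(trees, perTreeNormalization = false))
  }
}
```

The asymmetry the example prints is the bug being fixed: under per-tree normalization, a late GBT tree whose best split carries a negligible raw gain gets scaled up to the same total weight as the first tree, so whichever feature it happened to split on looks far more important than it is. Random forest callers are unchanged and keep the default `perTreeNormalization = true`, since their trees are trained independently on comparable targets.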