diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index abe2d1febfdf..a5ed4a38a886 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -341,11 +341,12 @@ class GBTClassificationModel private[ml]( * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) * and follows the implementation from scikit-learn. - + * * See `DecisionTreeClassificationModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** Raw prediction for the positive class. 
*/ private def margin(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9a5b7d59e9ae..9f0f567a5b53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -285,7 +285,8 @@ class GBTRegressionModel private[ml]( * @see `DecisionTreeRegressionModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 51d5d5c58c57..e95c55f6048f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -135,7 +135,7 @@ private[ml] object TreeEnsembleModel { * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. + * - Normalize importances for tree to sum to 1 (only if `perTreeNormalization` is `true`). * - Normalize feature importance vector to sum to 1. * * References: @@ -145,9 +145,15 @@ private[ml] object TreeEnsembleModel { * @param numFeatures Number of features in model (even if not all are explicitly used by * the model). * If -1, then numFeatures is set based on the max feature index in all trees. 
+ * @param perTreeNormalization By default this is set to `true` and it means that the importances + * of each tree are normalized before being summed. If set to `false`, + * the normalization is skipped. * @return Feature importance values, of length numFeatures. */ - def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel]( + trees: Array[M], + numFeatures: Int, + perTreeNormalization: Boolean = true): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -155,10 +161,19 @@ private[ml] object TreeEnsembleModel { computeFeatureImportance(tree.rootNode, importances) // Normalize importance vector for this tree, and add it to total. // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? - val treeNorm = importances.map(_._2).sum + val treeNorm = if (perTreeNormalization) { + importances.map(_._2).sum + } else { + // Never used as a divisor; NaN != 0 is true, so the guard below still passes. + Double.NaN + } if (treeNorm != 0) { importances.foreach { case (idx, impt) => - val normImpt = impt / treeNorm + val normImpt = if (perTreeNormalization) { + impt / treeNorm + } else { + impt + } totalImportances.changeValue(idx, normImpt, _ + normImpt) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index cedbaf1858ef..cd59900c521c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -363,7 +363,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== 
mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index b145c7a3dc95..46fa3767efdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -200,7 +200,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") {