From 9babef5732b8482053cc71753aca5c7bd3fb91a6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Nov 2018 16:15:15 +0100 Subject: [PATCH 1/4] [SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading --- mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 2 +- .../org/apache/spark/ml/classification/GBTClassifierSuite.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 00157fe63af41..e5adeb529efaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -538,7 +538,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { +private[ml] trait GBTClassifierParams extends GBTParams with TreeRegressorParams { /** * Loss function which GBT tries to minimize. (case-insensitive) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 304977634189c..cedbaf1858ef4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model2: GBTClassificationModel): Unit = { TreeTests.checkEqual(model, model2) assert(model.numFeatures === model2.numFeatures) + assert(model.featureImportances == model2.featureImportances) } val gbt = new GBTClassifier() From 7d24f33d4d802096b06594e8bf643ed5b4546613 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Nov 2018 16:57:41 +0100 Subject: [PATCH 2/4] fix mima --- project/MimaExcludes.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 350d8ad6942ff..7efd17bf4e6c5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), From 6b9121197ca2490c9a6a6d3f23f1ea4e2a19d99f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 13 Nov 2018 17:27:27 +0100 Subject: [PATCH 3/4] address comments --- .../org/apache/spark/ml/tree/treeParams.scala | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index e5adeb529efaa..11c18c160d41c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -538,7 +538,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeRegressorParams { +private[ml] trait GBTClassifierParams extends GBTParams { /** * Loss function which GBT tries to minimize. (case-insensitive) @@ -566,6 +566,41 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeRegressorParams throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") } } + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => + TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + + setDefault(impurity -> "variance") + + /** + * @deprecated This method is deprecated and will be removed in 3.0.0. + * @group setParam + */ + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setImpurity(value: String): this.type = set(impurity, value) + + /** @group getParam */ + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"GBTClassifierParams was given unrecognized impurity: $impurity") + } + } } private[ml] object GBTRegressorParams { From 4aefc9f9d561f1c5f2c68442681cd0f8cf4dea62 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 14 Nov 2018 10:24:55 +0100 Subject: [PATCH 4/4] address comment --- .../ml/classification/GBTClassifier.scala | 4 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeParams.scala | 54 ++++--------------- project/MimaExcludes.scala | 8 +++ 5 files changed, 23 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39746ff0..62c6bdbdeb285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - metadata.getAndSetParams(model) + // We ignore the impurity while loading models because in previous models it was wrongly + // set to gini (see SPARK-25959). + metadata.getAndSetParams(model, Some(List("impurity"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6fa656275c1fd..c9de85de42fa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("1.4.0") object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities @Since("2.0.0") override def load(path: String): DecisionTreeRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82bf66ff66d8a..66d57ad6c4348 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 11c18c160d41c..f1e3836ebe476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams { private[ml] trait DecisionTreeClassifierParams extends DecisionTreeParams with TreeClassifierParams -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - +private[ml] trait HasVarianceImpurity extends Params { /** * Criterion used for information gain calculation (case-insensitive). * Supported: "variance". @@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}", (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params { } } -private[ml] object TreeRegressorParams { +private[ml] object HasVarianceImpurity { // These options should be lowercase. final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase(Locale.ROOT)) } +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends HasVarianceImpurity + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams { +private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity { /** * Loss function which GBT tries to minimize. (case-insensitive) @@ -566,41 +567,6 @@ private[ml] trait GBTClassifierParams extends GBTParams { throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") } } - - /** - * Criterion used for information gain calculation (case-insensitive). - * Supported: "variance". - * (default = variance) - * @group param - */ - final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + - " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) - - setDefault(impurity -> "variance") - - /** - * @deprecated This method is deprecated and will be removed in 3.0.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - - /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) - - /** Convert new impurity to old impurity. */ - private[ml] def getOldImpurity: OldImpurity = { - getImpurity match { - case "variance" => OldVariance - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException( - s"GBTClassifierParams was given unrecognized impurity: $impurity") - } - } } private[ml] object GBTRegressorParams { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eaefed95a4f54..a8d2b5d1d9cb6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,14 @@ object MimaExcludes { lazy val v30excludes = v24excludes ++ Seq( // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),