From 77549a9b74c510f480df31fad314d3395313812d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 19 Feb 2015 02:53:33 +0530 Subject: [PATCH 1/8] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation --- .../mllib/tree/GradientBoostedTrees.scala | 83 +++++++++++++++++-- .../tree/configuration/BoostingStrategy.scala | 10 ++- .../tree/GradientBoostedTreesSuite.scala | 26 ++++++ 3 files changed, 111 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 61f6b1313f82e..fdef3b012e162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, + remappedInput, boostingStrategy, validate=false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -76,8 +77,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } -} + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Validation dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return a gradient boosted trees model that can be used for prediction + */ + def runWithValidation( + trainInput: RDD[LabeledPoint], + validateInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => GradientBoostedTrees.boost( + trainInput, validateInput, boostingStrategy, validate=true) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedTrainInput = trainInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidateInput = trainInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedTrainInput, remappedValidateInput, boostingStrategy, + validate=true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. + */ + def runWithValidation( + trainInput: JavaRDD[LabeledPoint], + validateInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(trainInput.rdd, validateInput.rdd) + } +} object GradientBoostedTrees extends Logging { @@ -108,12 +143,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset + * @param validateInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + validateInput: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy, + validate: Boolean = false): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") @@ -129,6 +168,7 @@ object GradientBoostedTrees extends Logging { val learningRate = boostingStrategy.learningRate // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol treeStrategy.algo = Regression treeStrategy.impurity = Variance treeStrategy.assertValid() @@ -151,14 +191,25 @@ object GradientBoostedTrees extends Logging { baseLearners(0) = firstTreeModel baseLearnerWeights(0) = 1.0 val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) + val errorModel = loss.computeError(startingModel, input) + logDebug("error of gbt = " + errorModel) + // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") + // Just so that it can be accessed below. This error is ignored if validate is set to false. + var prevValidateError = { + if (validate) { + loss.computeError(startingModel, validateInput) + } + else { + errorModel + } + } + // psuedo-residual for second iteration data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), point.features)) - var m = 1 while (m < numIterations) { timer.start(s"building tree $m") @@ -176,7 +227,24 @@ object GradientBoostedTrees extends Logging { // Note: A model of type regression is used since we require raw prediction val partialModel = new GradientBoostedTreesModel( Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) + val errorModel = loss.computeError(partialModel, input) + logDebug("error of gbt = " + errorModel) + + if (validate) { + // Stop training early if + // 1. Reduction in error is lesser than the validationTol or + // 2. If the error increases, that is if the model is overfit. + val currentValidateError = loss.computeError(partialModel, validateInput) + if (prevValidateError - currentValidateError < validationTol) { + return new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, m), + baseLearnerWeights.slice(0, m)) + } + else { + prevValidateError = currentValidateError + } + } // Update data with pseudo-residuals data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), point.features)) @@ -191,4 +259,5 @@ object GradientBoostedTrees extends Logging { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index ed8e6a796f8c4..e191702dc57a2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,6 +34,12 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] + * @param validationTol Useful when runWithValidation is used. If the error rate between two + iterations is lesser than the validationTol, then stop. If run + is used, then this parameter is ignored. + + a pair of RDD's are supplied to run. If the error rate + * between two iterations is lesser than convergenceTol, then training stops. */ @Experimental case class BoostingStrategy( @@ -42,7 +48,8 @@ case class BoostingStrategy( @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1) extends Serializable { + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. @@ -62,6 +69,7 @@ case class BoostingStrategy( } require(learningRate > 0 && learningRate <= 1, "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + require(validationTol >= 0, s"validationTol $validationTol should be greater than zero.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index bde47606eb001..3d00347bb1f66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -158,6 +158,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } } + + test("Early stopping when validation data is provided.") { + // Set numIterations large enough so that it early stops. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + Array(SquaredError, AbsoluteError).foreach { error => + val boostingStrategy = + new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) + + val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( + trainRdd, validateRdd) + assert(gbtValidate.numTrees != numIterations) + + val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val errorWithoutValidation = error.computeError(gbt, validateRdd) + val errorWithValidation = error.computeError(gbtValidate, validateRdd) + assert(errorWithValidation < errorWithoutValidation) + } + + } } private object GradientBoostedTreesSuite { @@ -166,4 +190,6 @@ private object GradientBoostedTreesSuite { val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) } From 3e743723f128f5d87f66dc05a92d3f68e6fd01cb Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 19 Feb 2015 14:26:35 +0530 Subject: [PATCH 2/8] TST: Add test for classification --- .../mllib/tree/GradientBoostedTrees.scala | 3 +- .../tree/GradientBoostedTreesSuite.scala | 37 ++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index fdef3b012e162..37dbb3fb44518 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -227,8 +227,7 @@ object GradientBoostedTrees extends Logging { // Note: A model of type regression is used since we require raw prediction val partialModel = new GradientBoostedTreesModel( Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - val errorModel = loss.computeError(partialModel, input) - logDebug("error of gbt = " + errorModel) + logDebug("error of gbt = " + loss.computeError(partialModel, input)) if (validate) { // Stop training early if diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3d00347bb1f66..3726ca3eff3bf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -159,7 +159,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } - test("Early stopping when validation data is provided.") { + test("runWithValidation performs better on a validation dataset (Regression)") { // Set numIterations large enough so that it early stops. val numIterations = 20 val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) @@ -180,8 +180,41 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { val errorWithValidation = error.computeError(gbtValidate, validateRdd) assert(errorWithValidation < errorWithoutValidation) } - } + + test("runWithValidation performs better on a validation dataset (Classification)") { + // Set numIterations large enough so that it early stops. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) + + // Test that it stops early. + val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( + trainRdd, validateRdd) + assert(gbtValidate.numTrees != numIterations) + + // Remap labels to {-1, 1} + val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + + // The error checked for internally in the GradientBoostedTrees is based on Regression. + // Hence for the validation model, the Classification error need not be strictly less than + // that done with validation. + val gbtValidateRegressor = new GradientBoostedTreesModel( + Regression, gbtValidate.trees, gbtValidate.treeWeights) + val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput) + + val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights) + val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) + + assert(errorWithValidation < errorWithoutValidation) + } + } private object GradientBoostedTreesSuite { From 55e5c3b22c39ef1f71e35eca0fda06a9080d22d7 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 19 Feb 2015 17:40:39 +0530 Subject: [PATCH 3/8] One liner for prevValidateError --- .../spark/mllib/tree/GradientBoostedTrees.scala | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 37dbb3fb44518..c387a88b32f78 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -191,21 +191,12 @@ object GradientBoostedTrees extends Logging { baseLearners(0) = firstTreeModel baseLearnerWeights(0) = 1.0 val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - val errorModel = loss.computeError(startingModel, input) - logDebug("error of gbt = " + errorModel) + logDebug("error of gbt = " + loss.computeError(startingModel, input)) // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - // Just so that it can be accessed below. This error is ignored if validate is set to false. - var prevValidateError = { - if (validate) { - loss.computeError(startingModel, validateInput) - } - else { - errorModel - } - } + var prevValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0 // psuedo-residual for second iteration data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), From fad9b6e0ef8d760158410069341e285bebc52c58 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 20 Feb 2015 11:47:10 +0530 Subject: [PATCH 4/8] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance. --- docs/mllib-ensembles.md | 12 ++++++++++++ .../spark/mllib/tree/GradientBoostedTrees.scala | 17 ++++++++++------- .../tree/configuration/BoostingStrategy.scala | 1 - 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5b..b2f57a5053104 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -384,6 +384,18 @@ On each iteration, the algorithm uses the current ensemble to predict the label The specific mechanism for re-labeling instances is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data. +#### Validation while training + +Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might +be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this +option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset. + +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error +decreases with the increase in number of trees and then increases as the model starts to overfit. There might +be cases, in which the validation error does not change monotonically, and the user is advised to set a large +enough negative tolerance and examine the validation curve to make further inference. + #### Losses The table below lists the losses currently supported by GBTs in MLlib. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index c387a88b32f78..23a7db93314d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -105,7 +105,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) } /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. */ def runWithValidation( trainInput: JavaRDD[LabeledPoint], @@ -196,7 +196,8 @@ object GradientBoostedTrees extends Logging { // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - var prevValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0 + var bestValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0 + var bestM = 1 // psuedo-residual for second iteration data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), @@ -224,15 +225,17 @@ object GradientBoostedTrees extends Logging { // Stop training early if // 1. Reduction in error is lesser than the validationTol or // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. val currentValidateError = loss.computeError(partialModel, validateInput) - if (prevValidateError - currentValidateError < validationTol) { + if (bestValidateError - currentValidateError < validationTol) { return new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, m), - baseLearnerWeights.slice(0, m)) + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) } - else { - prevValidateError = currentValidateError + else if (currentValidateError < bestValidateError){ + bestValidateError = currentValidateError + bestM = m + 1 } } // Update data with pseudo-residuals diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index e191702dc57a2..77e2176be9251 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -69,7 +69,6 @@ case class BoostingStrategy( } require(learningRate > 0 && learningRate <= 1, "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") - require(validationTol >= 0, s"validationTol $validationTol should be greater than zero.") } } From b928a19d4592fb36d2bda0ae600e53d2b240b980 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 20 Feb 2015 12:05:12 +0530 Subject: [PATCH 5/8] Move validation while training section under usage tips --- docs/mllib-ensembles.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index b2f57a5053104..ce349936bd8a9 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -384,18 +384,6 @@ On each iteration, the algorithm uses the current ensemble to predict the label The specific mechanism for re-labeling instances is defined by a loss function (discussed below). With each iteration, GBTs further reduce this loss function on the training data. -#### Validation while training - -Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might -be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this -option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset. - -The training is stopped when the improvement in the validation error is not more than a certain tolerance -(supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error -decreases with the increase in number of trees and then increases as the model starts to overfit. There might -be cases, in which the validation error does not change monotonically, and the user is advised to set a large -enough negative tolerance and examine the validation curve to make further inference. - #### Losses The table below lists the losses currently supported by GBTs in MLlib. @@ -439,6 +427,18 @@ We omit some decision tree parameters since those are covered in the [decision t * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. +#### Validation while training + +Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might +be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this +option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset. + +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error +decreases with the increase in number of trees and then increases as the model starts to overfit. There might +be cases, in which the validation error does not change monotonically, and the user is advised to set a large +enough negative tolerance and examine the validation curve to make further inference. + ### Examples From b48a70fc168ddc673bda57208582d9af0f06d8b4 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 20 Feb 2015 16:52:31 +0530 Subject: [PATCH 6/8] COSMIT --- .../org/apache/spark/mllib/tree/GradientBoostedTrees.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 23a7db93314d0..1cf172263cdba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -80,8 +80,10 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @param input Validation dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validateInput Validation dataset: + RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + Should follow same distribution as trainInput. * @return a gradient boosted trees model that can be used for prediction */ def runWithValidation( From e4d799b679c519125a07851fc90a53244177a196 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 23 Feb 2015 09:44:03 +0530 Subject: [PATCH 7/8] Addresses indentation and doc comments --- docs/mllib-ensembles.md | 15 +++---- .../mllib/tree/GradientBoostedTrees.scala | 45 ++++++++++--------- .../tree/configuration/BoostingStrategy.scala | 10 ++--- .../tree/GradientBoostedTreesSuite.scala | 18 ++++---- 4 files changed, 43 insertions(+), 45 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index ce349936bd8a9..1af465f5d1df4 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -429,16 +429,15 @@ We omit some decision tree parameters since those are covered in the [decision t #### Validation while training -Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might -be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this -option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset. +Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while +training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDD's as arguments, the +first one being the training dataset and the second being the validation dataset. The training is stopped when the improvement in the validation error is not more than a certain tolerance -(supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error -decreases with the increase in number of trees and then increases as the model starts to overfit. There might -be cases, in which the validation error does not change monotonically, and the user is advised to set a large -enough negative tolerance and examine the validation curve to make further inference. - +(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error +decreases initially and later increases. There might be cases in which the validation error does not change monotonically, +and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of +iterations. ### Examples diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1cf172263cdba..65459707b9188 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -80,26 +80,28 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Method to validate a gradient boosting model - * @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @param validateInput Validation dataset: + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - Should follow same distribution as trainInput. + Should be different from and follow the same distribution as input. + e.g., these two datasets could be created from an original dataset + by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ def runWithValidation( - trainInput: RDD[LabeledPoint], - validateInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { case Regression => GradientBoostedTrees.boost( - trainInput, validateInput, boostingStrategy, validate=true) + input, validationInput, boostingStrategy, validate=true) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedTrainInput = trainInput.map( + val remappedInput = input.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) - val remappedValidateInput = trainInput.map( + val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedTrainInput, remappedValidateInput, boostingStrategy, + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, validate=true) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") @@ -110,9 +112,9 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. */ def runWithValidation( - trainInput: JavaRDD[LabeledPoint], - validateInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { - runWithValidation(trainInput.rdd, validateInput.rdd) + input: JavaRDD[LabeledPoint], + validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(input.rdd, validationInput.rdd) } } @@ -145,16 +147,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset - * @param validateInput validation dataset, ignored if validate is set to false. + * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters * @param validate whether or not to use the validation dataset. * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - validateInput: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint], boostingStrategy: BoostingStrategy, - validate: Boolean = false): GradientBoostedTreesModel = { + validate: Boolean): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") @@ -198,7 +200,7 @@ object GradientBoostedTrees extends Logging { // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - var bestValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0 + var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 var bestM = 1 // psuedo-residual for second iteration @@ -225,19 +227,18 @@ object GradientBoostedTrees extends Logging { if (validate) { // Stop training early if - // 1. Reduction in error is lesser than the validationTol or + // 1. Reduction in error is less than the validationTol or // 2. If the error increases, that is if the model is overfit. // We want the model returned corresponding to the best validation error. - val currentValidateError = loss.computeError(partialModel, validateInput) + val currentValidateError = loss.computeError(partialModel, validationInput) if (bestValidateError - currentValidateError < validationTol) { return new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) - } - else if (currentValidateError < bestValidateError){ - bestValidateError = currentValidateError - bestM = m + 1 + } else if (currentValidateError < bestValidateError){ + bestValidateError = currentValidateError + bestM = m + 1 } } // Update data with pseudo-residuals diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 77e2176be9251..35b479fac5280 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,13 +34,11 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param validationTol Useful when runWithValidation is used. If the error rate between two - iterations is lesser than the validationTol, then stop. If run - is used, then this parameter is ignored. - - a pair of RDD's are supplied to run. If the error rate - * between two iterations is lesser than convergenceTol, then training stops. + * @param validationTol Useful when runWithValidation is used. If the error rate on the + * validation input between two iterations is less than the validationTol + * then stop. Ignored when [[run]] is used. */ + @Experimental case class BoostingStrategy( // Required boosting parameters diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3726ca3eff3bf..b4732a381f54a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -160,7 +160,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } test("runWithValidation performs better on a validation dataset (Regression)") { - // Set numIterations large enough so that it early stops. + // Set numIterations large enough so that it stops early. val numIterations = 20 val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) @@ -171,9 +171,9 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { val boostingStrategy = new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( - trainRdd, validateRdd) - assert(gbtValidate.numTrees != numIterations) + val gbtValidate = new GradientBoostedTrees(boostingStrategy). + runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) val errorWithoutValidation = error.computeError(gbt, validateRdd) @@ -183,7 +183,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } test("runWithValidation performs better on a validation dataset (Classification)") { - // Set numIterations large enough so that it early stops. + // Set numIterations large enough so that it stops early. val numIterations = 20 val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) @@ -194,9 +194,9 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) // Test that it stops early. - val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( - trainRdd, validateRdd) - assert(gbtValidate.numTrees != numIterations) + val gbtValidate = new GradientBoostedTrees(boostingStrategy). + runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) // Remap labels to {-1, 1} val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) @@ -213,7 +213,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) assert(errorWithValidation < errorWithoutValidation) - } + } } From 1bb21d410f7bd59855b58cf606f36a43a19a7ee0 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 24 Feb 2015 16:45:25 +0530 Subject: [PATCH 8/8] Combine regression and classification tests into a single one --- docs/mllib-ensembles.md | 2 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../tree/configuration/BoostingStrategy.scala | 1 - .../tree/GradientBoostedTreesSuite.scala | 73 +++++++------------ 4 files changed, 27 insertions(+), 51 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 1af465f5d1df4..902ba01b08ebb 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -434,7 +434,7 @@ training. The method runWithValidation has been provided to make use of this opt first one being the training dataset and the second being the validation dataset. The training is stopped when the improvement in the validation error is not more than a certain tolerance -(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error decreases initially and later increases. There might be cases in which the validation error does not change monotonically, and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of iterations. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 65459707b9188..b4466ff40937f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -236,7 +236,7 @@ object GradientBoostedTrees extends Logging { boostingStrategy.treeStrategy.algo, baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) - } else if (currentValidateError < bestValidateError){ + } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError bestM = m + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 35b479fac5280..664c8df019233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -38,7 +38,6 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * validation input between two iterations is less than the validationTol * then stop. Ignored when [[run]] is used. */ - @Experimental case class BoostingStrategy( // Required boosting parameters diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index b4732a381f54a..b437aeaaf0547 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -159,62 +159,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } - test("runWithValidation performs better on a validation dataset (Regression)") { + test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. val numIterations = 20 val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - Array(SquaredError, AbsoluteError).foreach { error => - val boostingStrategy = - new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) - - val gbtValidate = new GradientBoostedTrees(boostingStrategy). - runWithValidation(trainRdd, validateRdd) - assert(gbtValidate.numTrees !== numIterations) - - val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) - val errorWithoutValidation = error.computeError(gbt, validateRdd) - val errorWithValidation = error.computeError(gbtValidate, validateRdd) - assert(errorWithValidation < errorWithoutValidation) + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + (algos zip losses) map { + case (algo, loss) => { + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + } } } - test("runWithValidation performs better on a validation dataset (Classification)") { - // Set numIterations large enough so that it stops early. - val numIterations = 20 - val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) - val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - - val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) - - // Test that it stops early. - val gbtValidate = new GradientBoostedTrees(boostingStrategy). - runWithValidation(trainRdd, validateRdd) - assert(gbtValidate.numTrees !== numIterations) - - // Remap labels to {-1, 1} - val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - - // The error checked for internally in the GradientBoostedTrees is based on Regression. - // Hence for the validation model, the Classification error need not be strictly less than - // that done with validation. - val gbtValidateRegressor = new GradientBoostedTreesModel( - Regression, gbtValidate.trees, gbtValidate.treeWeights) - val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput) - - val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) - val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights) - val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) - - assert(errorWithValidation < errorWithoutValidation) - } - } private object GradientBoostedTreesSuite {