Skip to content

Commit 1bb21d4

Browse files
committed
Combine regression and classification tests into a single one
1 parent e4d799b commit 1bb21d4

File tree

4 files changed

+27
-51
lines changed

4 files changed

+27
-51
lines changed

docs/mllib-ensembles.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ training. The method runWithValidation has been provided to make use of this opt
434434
first one being the training dataset and the second being the validation dataset.
435435

436436
The training is stopped when the improvement in the validation error is not more than a certain tolerance
437-
(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error
437+
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
438438
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
439439
and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
440440
iterations.

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ object GradientBoostedTrees extends Logging {
236236
boostingStrategy.treeStrategy.algo,
237237
baseLearners.slice(0, bestM),
238238
baseLearnerWeights.slice(0, bestM))
239-
} else if (currentValidateError < bestValidateError){
239+
} else if (currentValidateError < bestValidateError) {
240240
bestValidateError = currentValidateError
241241
bestM = m + 1
242242
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
3838
* validation input between two iterations is less than the validationTol
3939
* then stop. Ignored when [[run]] is used.
4040
*/
41-
4241
@Experimental
4342
case class BoostingStrategy(
4443
// Required boosting parameters

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 25 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -159,62 +159,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
159159
}
160160
}
161161

162-
test("runWithValidation performs better on a validation dataset (Regression)") {
162+
test("runWithValidation stops early and performs better on a validation dataset") {
163163
// Set numIterations large enough so that it stops early.
164164
val numIterations = 20
165165
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
166166
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
167167

168-
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
169-
categoricalFeaturesInfo = Map.empty)
170-
Array(SquaredError, AbsoluteError).foreach { error =>
171-
val boostingStrategy =
172-
new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0)
173-
174-
val gbtValidate = new GradientBoostedTrees(boostingStrategy).
175-
runWithValidation(trainRdd, validateRdd)
176-
assert(gbtValidate.numTrees !== numIterations)
177-
178-
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
179-
val errorWithoutValidation = error.computeError(gbt, validateRdd)
180-
val errorWithValidation = error.computeError(gbtValidate, validateRdd)
181-
assert(errorWithValidation < errorWithoutValidation)
168+
val algos = Array(Regression, Regression, Classification)
169+
val losses = Array(SquaredError, AbsoluteError, LogLoss)
170+
(algos zip losses) map {
171+
case (algo, loss) => {
172+
val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
173+
categoricalFeaturesInfo = Map.empty)
174+
val boostingStrategy =
175+
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
176+
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
177+
.runWithValidation(trainRdd, validateRdd)
178+
assert(gbtValidate.numTrees !== numIterations)
179+
180+
// Test that it performs better on the validation dataset.
181+
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
182+
val (errorWithoutValidation, errorWithValidation) = {
183+
if (algo == Classification) {
184+
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
185+
(loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
186+
} else {
187+
(loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
188+
}
189+
}
190+
assert(errorWithValidation <= errorWithoutValidation)
191+
}
182192
}
183193
}
184194

185-
test("runWithValidation performs better on a validation dataset (Classification)") {
186-
// Set numIterations large enough so that it stops early.
187-
val numIterations = 20
188-
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
189-
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
190-
191-
val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
192-
categoricalFeaturesInfo = Map.empty)
193-
val boostingStrategy =
194-
new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0)
195-
196-
// Test that it stops early.
197-
val gbtValidate = new GradientBoostedTrees(boostingStrategy).
198-
runWithValidation(trainRdd, validateRdd)
199-
assert(gbtValidate.numTrees !== numIterations)
200-
201-
// Remap labels to {-1, 1}
202-
val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
203-
204-
// The error checked for internally in the GradientBoostedTrees is based on Regression.
205-
// Hence for the validation model, the Classification error need not be strictly less than
206-
// that done without validation.
207-
val gbtValidateRegressor = new GradientBoostedTreesModel(
208-
Regression, gbtValidate.trees, gbtValidate.treeWeights)
209-
val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput)
210-
211-
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
212-
val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights)
213-
val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput)
214-
215-
assert(errorWithValidation < errorWithoutValidation)
216-
}
217-
218195
}
219196

220197
private object GradientBoostedTreesSuite {

0 commit comments

Comments
 (0)