Skip to content

Commit 3e74372

Browse files
committed
TST: Add test for classification
1 parent 77549a9 commit 3e74372

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ object GradientBoostedTrees extends Logging {
227227
// Note: A model of type regression is used since we require raw prediction
228228
val partialModel = new GradientBoostedTreesModel(
229229
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
230-
val errorModel = loss.computeError(partialModel, input)
231-
logDebug("error of gbt = " + errorModel)
230+
logDebug("error of gbt = " + loss.computeError(partialModel, input))
232231

233232
if (validate) {
234233
// Stop training early if

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
159159
}
160160
}
161161

162-
test("Early stopping when validation data is provided.") {
162+
test("runWithValidation performs better on a validation dataset (Regression)") {
163163
// Set numIterations large enough so that training stops early.
164164
val numIterations = 20
165165
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
@@ -180,8 +180,41 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
180180
val errorWithValidation = error.computeError(gbtValidate, validateRdd)
181181
assert(errorWithValidation < errorWithoutValidation)
182182
}
183-
184183
}
184+
185+
test("runWithValidation performs better on a validation dataset (Classification)") {
186+
// Set numIterations large enough that validation-based early stopping can trigger before the last iteration.
187+
val numIterations = 20
188+
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
189+
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
190+
191+
val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
192+
categoricalFeaturesInfo = Map.empty)
193+
val boostingStrategy =
194+
new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0)
195+
196+
// Test that training stops early: the resulting model must not contain numIterations trees.
197+
val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation(
198+
trainRdd, validateRdd)
199+
assert(gbtValidate.numTrees != numIterations)
200+
201+
// Remap labels to {-1, 1} via 2 * label - 1 (assumes the original labels are in {0, 1}; confirm against validateData).
202+
val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
203+
204+
// The error checked internally by GradientBoostedTrees is computed on a Regression-typed model.
205+
// Hence, for the model trained with validation, the Classification error need not be strictly less than
206+
// that of the model trained without validation; both are therefore rewrapped as Regression models before comparing.
207+
val gbtValidateRegressor = new GradientBoostedTreesModel(
208+
Regression, gbtValidate.trees, gbtValidate.treeWeights)
209+
val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput)
210+
211+
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
212+
val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights)
213+
val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput)
214+
215+
assert(errorWithValidation < errorWithoutValidation)
216+
}
217+
185218
}
186219

187220
private object GradientBoostedTreesSuite {

0 commit comments

Comments
 (0)