@@ -159,7 +159,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
159159 }
160160 }
161161
162- test(" Early stopping when validation data is provided. " ) {
162+ test(" runWithValidation performs better on a validation dataset (Regression) " ) {
163163 // Set numIterations large enough so that it early stops.
164164 val numIterations = 20
165165 val trainRdd = sc.parallelize(GradientBoostedTreesSuite .trainData, 2 )
@@ -180,8 +180,41 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
180180 val errorWithValidation = error.computeError(gbtValidate, validateRdd)
181181 assert(errorWithValidation < errorWithoutValidation)
182182 }
183-
184183 }
184+
185+ test(" runWithValidation performs better on a validation dataset (Classification)" ) {
186+ // Set numIterations large enough that runWithValidation can stop before reaching it.
187+ val numIterations = 20
188+ val trainRdd = sc.parallelize(GradientBoostedTreesSuite .trainData, 2 )
189+ val validateRdd = sc.parallelize(GradientBoostedTreesSuite .validateData, 2 )
190+
191+ val treeStrategy = new Strategy (algo = Classification , impurity = Variance , maxDepth = 2 ,
192+ categoricalFeaturesInfo = Map .empty)
193+ val boostingStrategy =
194+ new BoostingStrategy (treeStrategy, LogLoss , numIterations, validationTol = 0.0 )
195+
196+ // Test that it stops early: the validated model must not contain numIterations trees.
197+ val gbtValidate = new GradientBoostedTrees (boostingStrategy).runWithValidation(
198+ trainRdd, validateRdd)
199+ assert(gbtValidate.numTrees != numIterations)
200+
201+ // Remap labels to {-1, 1} via 2 * label - 1 (assumes the original labels are in {0, 1} -- TODO confirm).
202+ val remappedInput = validateRdd.map(x => new LabeledPoint (2 * x.label - 1 , x.features))
203+
204+ // The error checked for internally in the GradientBoostedTrees is based on Regression.
205+ // Hence for the validation model, the Classification error need not be strictly less than
206+ // that done without validation, so both models are re-wrapped as Regression models before LogLoss is computed.
207+ val gbtValidateRegressor = new GradientBoostedTreesModel (
208+ Regression , gbtValidate.trees, gbtValidate.treeWeights)
209+ val errorWithValidation = LogLoss .computeError(gbtValidateRegressor, remappedInput)
210+
211+ val gbt = GradientBoostedTrees .train(trainRdd, boostingStrategy)
212+ val gbtRegressor = new GradientBoostedTreesModel (Regression , gbt.trees, gbt.treeWeights)
213+ val errorWithoutValidation = LogLoss .computeError(gbtRegressor, remappedInput)
214+
215+ assert(errorWithValidation < errorWithoutValidation)
216+ }
217+
185218}
186219
187220private object GradientBoostedTreesSuite {
0 commit comments