@@ -159,62 +159,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
159159 }
160160 }
161161
162- test(" runWithValidation performs better on a validation dataset (Regression) " ) {
162+ test(" runWithValidation stops early and performs better on a validation dataset" ) {
163163 // Set numIterations large enough that runWithValidation is expected to stop before using all of them.
164164 val numIterations = 20
165165 val trainRdd = sc.parallelize(GradientBoostedTreesSuite .trainData, 2 )
166166 val validateRdd = sc.parallelize(GradientBoostedTreesSuite .validateData, 2 )
167167
168- val treeStrategy = new Strategy (algo = Regression , impurity = Variance , maxDepth = 2 ,
169- categoricalFeaturesInfo = Map .empty)
170- Array (SquaredError , AbsoluteError ).foreach { error =>
171- val boostingStrategy =
172- new BoostingStrategy (treeStrategy, error, numIterations, validationTol = 0.0 )
173-
174- val gbtValidate = new GradientBoostedTrees (boostingStrategy).
175- runWithValidation(trainRdd, validateRdd)
176- assert(gbtValidate.numTrees !== numIterations)
177-
178- val gbt = GradientBoostedTrees .train(trainRdd, boostingStrategy)
179- val errorWithoutValidation = error.computeError(gbt, validateRdd)
180- val errorWithValidation = error.computeError(gbtValidate, validateRdd)
181- assert(errorWithValidation < errorWithoutValidation)
168+ val algos = Array (Regression , Regression , Classification )
169+ val losses = Array (SquaredError , AbsoluteError , LogLoss )
170+ (algos zip losses) map {
171+ case (algo, loss) => {
172+ val treeStrategy = new Strategy (algo = algo, impurity = Variance , maxDepth = 2 ,
173+ categoricalFeaturesInfo = Map .empty)
174+ val boostingStrategy =
175+ new BoostingStrategy (treeStrategy, loss, numIterations, validationTol = 0.0 )
176+ val gbtValidate = new GradientBoostedTrees (boostingStrategy)
177+ .runWithValidation(trainRdd, validateRdd)
178+ assert(gbtValidate.numTrees !== numIterations)
179+
180+ // Test that the model trained with validation has an error on the validation dataset no worse (<=) than the model trained without it.
181+ val gbt = GradientBoostedTrees .train(trainRdd, boostingStrategy)
182+ val (errorWithoutValidation, errorWithValidation) = {
183+ if (algo == Classification ) {
184+ val remappedRdd = validateRdd.map(x => new LabeledPoint (2 * x.label - 1 , x.features))
185+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
186+ } else {
187+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
188+ }
189+ }
190+ assert(errorWithValidation <= errorWithoutValidation)
191+ }
182192 }
183193 }
184194
185- test(" runWithValidation performs better on a validation dataset (Classification)" ) {
186- // Set numIterations large enough that runWithValidation is expected to stop before using all of them.
187- val numIterations = 20
188- val trainRdd = sc.parallelize(GradientBoostedTreesSuite .trainData, 2 )
189- val validateRdd = sc.parallelize(GradientBoostedTreesSuite .validateData, 2 )
190-
191- val treeStrategy = new Strategy (algo = Classification , impurity = Variance , maxDepth = 2 ,
192- categoricalFeaturesInfo = Map .empty)
193- val boostingStrategy =
194- new BoostingStrategy (treeStrategy, LogLoss , numIterations, validationTol = 0.0 )
195-
196- // Test that it stops early, i.e. the number of trees built differs from numIterations.
197- val gbtValidate = new GradientBoostedTrees (boostingStrategy).
198- runWithValidation(trainRdd, validateRdd)
199- assert(gbtValidate.numTrees !== numIterations)
200-
201- // Remap labels to {-1, 1} via 2 * label - 1 (assumes the original labels are in {0, 1}).
202- val remappedInput = validateRdd.map(x => new LabeledPoint (2 * x.label - 1 , x.features))
203-
204- // The error checked for internally in the GradientBoostedTrees is based on Regression.
205- // Hence the Classification error of the model trained with validation need not be strictly
206- // less than that of the model trained without validation, so Regression models are compared.
207- val gbtValidateRegressor = new GradientBoostedTreesModel (
208- Regression , gbtValidate.trees, gbtValidate.treeWeights)
209- val errorWithValidation = LogLoss .computeError(gbtValidateRegressor, remappedInput)
210-
211- val gbt = GradientBoostedTrees .train(trainRdd, boostingStrategy)
212- val gbtRegressor = new GradientBoostedTreesModel (Regression , gbt.trees, gbt.treeWeights)
213- val errorWithoutValidation = LogLoss .computeError(gbtRegressor, remappedInput)
214-
215- assert(errorWithValidation < errorWithoutValidation)
216- }
217-
218195}
219196
220197private object GradientBoostedTreesSuite {
0 commit comments