Skip to content

Commit bc355e2

Browse files
jkbradley authored and mengxr committed
[SPARK-8736] [ML] GBTRegressor should not threshold prediction
Changed GBTRegressor so it does NOT threshold the prediction. Added test which fails with bug but works after fix. CC: feynmanliang mengxr Author: Joseph K. Bradley <[email protected]> Closes #7134 from jkbradley/gbrt-fix and squashes the following commits: 613b90e [Joseph K. Bradley] Changed GBTRegressor so it does NOT threshold the prediction (cherry picked from commit 3ba23ff) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 80b0fe2 commit bc355e2

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -172,8 +172,7 @@ final class GBTRegressionModel(
172172
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
173173
// Predicts the sum of weighted tree predictions (regression: no thresholding)
174174
val treePredictions = _trees.map(_.rootNode.predict(features))
175-
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
176-
if (prediction > 0.0) 1.0 else 0.0
175+
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
177176
}
178177

179178
override def copy(extra: ParamMap): GBTRegressionModel = {

mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,13 @@ package org.apache.spark.ml.regression
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.ml.impl.TreeTests
22+
import org.apache.spark.mllib.linalg.Vectors
2223
import org.apache.spark.mllib.regression.LabeledPoint
2324
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
2425
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
2526
import org.apache.spark.mllib.util.MLlibTestSparkContext
2627
import org.apache.spark.rdd.RDD
27-
import org.apache.spark.sql.DataFrame
28+
import org.apache.spark.sql.{DataFrame, Row}
2829

2930

3031
/**
@@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
6768
}
6869
}
6970

71+
test("GBTRegressor behaves reasonably on toy data") {
72+
val df = sqlContext.createDataFrame(Seq(
73+
LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
74+
LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
75+
LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
76+
LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
77+
LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
78+
LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
79+
))
80+
val gbt = new GBTRegressor()
81+
.setMaxDepth(2)
82+
.setMaxIter(2)
83+
val model = gbt.fit(df)
84+
val preds = model.transform(df)
85+
val predictions = preds.select("prediction").map(_.getDouble(0))
86+
// Checks based on SPARK-8736 (to ensure it is not doing classification)
87+
assert(predictions.max() > 2)
88+
assert(predictions.min() < -1)
89+
}
90+
7091
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
7192
/*
7293
test("runWithValidation stops early and performs better on a validation dataset") {

0 commit comments

Comments
 (0)