Skip to content

Commit 594a2cf

Browse files
sethah authored and yanboliang committed
[SPARK-17792][ML] L-BFGS solver for linear regression does not accept general numeric label column types
## What changes were proposed in this pull request?

Before, we computed `instances` in LinearRegression in two spots, even though they did the same thing. One of them did not cast the label column to `DoubleType`. This patch consolidates the computation and always casts the label column to `DoubleType`.

## How was this patch tested?

Added a unit test to check all solvers. This test failed before this patch.

Author: sethah <[email protected]>

Closes #15364 from sethah/linreg_numeric_type.

(cherry picked from commit 3713bb1)
Signed-off-by: Yanbo Liang <[email protected]>
1 parent b1a9c41 commit 594a2cf

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -163,17 +163,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
163163
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
164164
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
165165

166+
val instances: RDD[Instance] = dataset.select(
167+
col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
168+
case Row(label: Double, weight: Double, features: Vector) =>
169+
Instance(label, weight, features)
170+
}
171+
166172
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
167173
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
168174
require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
169175
"solver is used.'")
170176
// For low dimensional data, WeightedLeastSquares is more efficient since the
171177
// training algorithm only requires one pass through the data. (SPARK-10668)
172-
val instances: RDD[Instance] = dataset.select(
173-
col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
174-
case Row(label: Double, weight: Double, features: Vector) =>
175-
Instance(label, weight, features)
176-
}
177178

178179
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
179180
$(standardization), true)
@@ -196,12 +197,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
196197
return lrModel.setSummary(trainingSummary)
197198
}
198199

199-
val instances: RDD[Instance] =
200-
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
201-
case Row(label: Double, weight: Double, features: Vector) =>
202-
Instance(label, weight, features)
203-
}
204-
205200
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
206201
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
207202

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1019,12 +1019,14 @@ class LinearRegressionSuite
10191019
}
10201020

10211021
test("should support all NumericType labels and not support other types") {
1022-
val lr = new LinearRegression().setMaxIter(1)
1023-
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
1024-
lr, spark, isClassification = false) { (expected, actual) =>
1022+
for (solver <- Seq("auto", "l-bfgs", "normal")) {
1023+
val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
1024+
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
1025+
lr, spark, isClassification = false) { (expected, actual) =>
10251026
assert(expected.intercept === actual.intercept)
10261027
assert(expected.coefficients === actual.coefficients)
10271028
}
1029+
}
10281030
}
10291031
}
10301032

0 commit comments

Comments
 (0)