
Commit ae5baa8

CR feedback: move the fitIntercept handling down rather than changing yMean etc. above
1 parent f34971c commit ae5baa8

File tree

2 files changed: +21 −26 lines

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 8 additions & 11 deletions
@@ -81,7 +81,6 @@ class LinearRegression(override val uid: String)
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)
 
-
   /**
    * Set the ElasticNet mixing parameter.
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
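
For context, a minimal caller-side sketch (not part of this commit; `training` is an assumed DataFrame with the usual "label" and "features" columns): fitIntercept defaults to true via the setDefault above, so it only needs to be set when forcing the fit through the origin.

    import org.apache.spark.ml.regression.LinearRegression

    // Hypothetical usage; `training` is assumed to be in scope.
    val lr = new LinearRegression().setFitIntercept(false)
    val model = lr.fit(training)
    println(s"weights = ${model.weights}, intercept = ${model.intercept}")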
@@ -131,7 +130,7 @@ class LinearRegression(override val uid: String)
     })
 
     val numFeatures = summarizer.mean.size
-    val yMean = if ($(fitIntercept)) statCounter.mean else 0.0
+    val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
     // look at glmnet5.m L761 maaaybe that has info
 
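With this change, yMean is always the true label mean; the fitIntercept special-casing now lives where the intercept and the aggregator offset are computed (see the hunks below). A toy re-computation (plain Scala, made-up labels) of the two statistics read off statCounter, noting that StatCounter.variance is the population variance:

    val labels = Seq(5.8, 6.1, 6.7, 6.4)  // assumed toy labels
    val yMean = labels.sum / labels.size  // statCounter.mean
    val yStd = math.sqrt(
      labels.map(y => math.pow(y - yMean, 2)).sum / labels.size)  // sqrt(variance)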

@@ -144,11 +143,7 @@ class LinearRegression(override val uid: String)
       return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
     }
 
-    val featuresMean = if ($(fitIntercept)) {
-      summarizer.mean.toArray
-    } else {
-      new Array[Double](numFeatures)
-    }
+    val featuresMean = summarizer.mean.toArray
     val featuresStd = summarizer.variance.toArray.map(math.sqrt)
 
     // Since we implicitly do the feature scaling when we compute the cost function
@@ -157,7 +152,7 @@ class LinearRegression(override val uid: String)
     val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
-    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
       featuresStd, featuresMean, effectiveL2RegParam)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
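
The two effective penalties above are the elastic-net split of the regularization strength, which is why the optimizer can fall back to plain L-BFGS whenever the L1 part vanishes. A worked sketch with assumed numbers (the real code first divides regParam by yStd):

    val effectiveRegParam = 1.6  // assumed, after the yStd scaling
    val elasticNetParam = 0.3    // assumed alpha
    val effectiveL1RegParam = elasticNetParam * effectiveRegParam          // 0.48 -> OWLQN needed
    val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam  // 1.12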
@@ -195,7 +190,7 @@ class LinearRegression(override val uid: String)
     // The intercept in R's GLMNET is computed using closed form after the coefficients are
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
     if (handlePersistence) instances.unpersist()
 
     // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
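
A minimal sketch (plain Scala, made-up values) of that closed-form step: once the coefficients have converged on the centered data, the intercept is yMean minus the dot product of the weights with the feature means, and with this commit it is pinned to 0.0 when fitIntercept is false:

    val weights = Array(4.7, 7.2)        // converged coefficients (assumed)
    val featuresMean = Array(0.9, -1.3)  // feature column means (assumed)
    val yMean = 6.3                      // label mean (assumed)
    val fitIntercept = true
    val dotWM = weights.zip(featuresMean).map { case (w, m) => w * m }.sum
    val intercept = if (fitIntercept) yMean - dotWM else 0.0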
@@ -320,6 +315,7 @@ private class LeastSquaresAggregator(
     weights: Vector,
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double]) extends Serializable {
 

@@ -340,7 +336,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
   }
 
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
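
A minimal sketch (plain Scala, made-up values) of the tuple's middle element: with each feature standardized to (x_i - m_i) / s_i, the constant part of the residual collapses to labelMean / labelStd - sum_i (w_i / s_i) * m_i, which this commit forces to 0.0 when no intercept is fit:

    val rawWeights = Array(4.7, 7.2)        // assumed
    val featuresStd = Array(0.84, 1.10)     // assumed
    val featuresMean = Array(0.9, -1.3)     // assumed
    val (labelMean, labelStd) = (6.3, 2.0)  // assumed
    val fitIntercept = true
    // Effective weights, as in the aggregator: w_i / s_i.
    val weightsArray = rawWeights.zip(featuresStd).map { case (w, s) => w / s }
    val sum = weightsArray.zip(featuresMean).map { case (w, m) => w * m }.sum
    val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0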
@@ -423,6 +419,7 @@ private class LeastSquaresCostFun(
     data: RDD[(Double, Vector)],
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -431,7 +428,7 @@ private class LeastSquaresCostFun(
     val w = Vectors.fromBreeze(weights)
 
     val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-      labelMean, featuresStd, featuresMean))(
+      labelMean, fitIntercept, featuresStd, featuresMean))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, (label, features)) => aggregator.add(label, features)
       },
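
A minimal sketch (plain Scala, no Spark) of the seqOp/combOp contract treeAggregate leans on here: seqOp folds one (label, features) record into a partition-local aggregator, combOp merges two aggregators, so only small aggregator state crosses partitions. A toy squared-loss aggregator stands in for LeastSquaresAggregator:

    final case class Agg(count: Long, lossSum: Double) {
      def add(label: Double, prediction: Double): Agg =  // seqOp's per-record step
        Agg(count + 1, lossSum + math.pow(label - prediction, 2) / 2.0)
      def merge(other: Agg): Agg =                       // combOp
        Agg(count + other.count, lossSum + other.lossSum)
    }
    val data = Seq((1.0, 0.9), (2.0, 2.2), (3.0, 2.8))   // (label, prediction), assumed
    val agg = data
      .grouped(2).toSeq                                  // stand-in for partitions
      .map(_.foldLeft(Agg(0L, 0.0)) { case (a, (y, p)) => a.add(y, p) })
      .reduce(_ merge _)
    println(s"mean loss = ${agg.lossSum / agg.count}")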

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 13 additions & 15 deletions
@@ -26,7 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   @transient var dataset: DataFrame = _
-  @transient var datasetNR: DataFrame = _
+  @transient var datasetWithoutIntercept: DataFrame = _
 
   /**
    * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -45,7 +45,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
-    datasetNR = sqlContext.createDataFrame(
+    /**
+     * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
+     * training model without intercept
+     */
+    datasetWithoutIntercept = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
 
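Reading the generator arguments (intercept, weights, feature means, feature variances, number of points, seed, noise level): the first dataset follows label = 6.3 + 4.7 * x1 + 7.2 * x2 + noise, and the second is the same model with a true intercept of 0.0, which is what makes it a fair fixture for setFitIntercept(false). A toy analogue (plain Scala, Gaussian features assumed rather than the generator's exact sampling scheme):

    import scala.util.Random
    val rnd = new Random(42)
    // One hypothetical sample from the zero-intercept model above.
    val x1 = 0.9 + math.sqrt(0.7) * rnd.nextGaussian()   // mean 0.9, variance 0.7
    val x2 = -1.3 + math.sqrt(1.2) * rnd.nextGaussian()  // mean -1.3, variance 1.2
    val label = 0.0 + 4.7 * x1 + 7.2 * x2 + 0.1 * rnd.nextGaussian()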

@@ -88,15 +92,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("linear regression without intercept without regularization") {
     val trainer = (new LinearRegression).setFitIntercept(false)
     val model = trainer.fit(dataset)
-    val modelNR = trainer.fit(datasetNR)
+    val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
 
     /**
-     * Using the following R code to load the data and train the model using glmnet package.
-     *
-     * library("glmnet")
-     * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
-     * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
-     * label <- as.numeric(data$V1)
      * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
      *   intercept = FALSE))
      * > weights
@@ -113,18 +111,18 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
     /**
      * Then again with the data with no intercept:
-     * > weightsNR
+     * > weightsWithoutIntercept
      * 3 x 1 sparse Matrix of class "dgCMatrix"
      *                           s0
      * (Intercept)               .
      * as.numeric.data3.V2.      4.70011
      * as.numeric.data3.V3.      7.19943
      */
-    val weightsRNR = Array(4.70011, 7.19943)
+    val weightsWithoutInterceptR = Array(4.70011, 7.19943)
 
-    assert(modelNR.intercept ~== 0 relTol 1E-3)
-    assert(modelNR.weights(0) ~== weightsRNR(0) relTol 1E-3)
-    assert(modelNR.weights(1) ~== weightsRNR(1) relTol 1E-3)
+    assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
+    assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
+    assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
   }
 
 

@@ -186,7 +184,6 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
-
   test("linear regression with intercept with L2 regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
     val model = trainer.fit(dataset)
@@ -272,6 +269,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }
+
   test("linear regression without intercept with ElasticNet regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
       .setFitIntercept(false)

0 commit comments