@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
4141 * Params for linear regression.
4242 */
4343private [regression] trait LinearRegressionParams extends PredictorParams
44- with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
44+ with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with
45+ HasIntercept
4546
4647/**
4748 * :: Experimental ::
@@ -121,8 +122,9 @@ class LinearRegression(override val uid: String)
121122 })
122123
123124 val numFeatures = summarizer.mean.size
124- val yMean = statCounter.mean
125- val yStd = math.sqrt(statCounter.variance)
125+ val yMean = if (hasIntercept) statCounter.mean else 0.0
126+ val yStd = if (hasIntercept) math.sqrt(statCounter.variance) else
127+ // look at glmnet6.m L761 maaaybe that has info
126128
127129 // If the yStd is zero, then the intercept is yMean with zero weights;
128130 // as a result, training is not needed.
@@ -180,6 +182,7 @@ class LinearRegression(override val uid: String)
180182 // The intercept in R's GLMNET is computed using closed form after the coefficients are
181183 // converged. See the following discussion for detail.
182184 // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
185+ // Also see the scikit learn impl at https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/base.py
183186 val intercept = yMean - dot(weights, Vectors .dense(featuresMean))
184187 if (handlePersistence) instances.unpersist()
185188
0 commit comments