@@ -81,7 +81,6 @@ class LinearRegression(override val uid: String)
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)
 
-
   /**
    * Set the ElasticNet mixing parameter.
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
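Since the doc comment above defines the mixing behavior, here is a small standalone sketch of the blended penalty for reference. The helper name and the glmnet-style 1/2 factor on the L2 term are my assumptions, not part of this patch:

```scala
// Illustrative elastic-net penalty: alpha blends the L1 (lasso) and L2 (ridge) terms.
// elasticNetPenalty is a hypothetical helper, not Spark API.
def elasticNetPenalty(w: Array[Double], alpha: Double, lambda: Double): Double = {
  val l1 = w.map(math.abs).sum             // alpha = 1.0: pure L1 penalty
  val l2 = w.map(x => x * x).sum / 2.0     // alpha = 0.0: pure L2 penalty
  lambda * (alpha * l1 + (1.0 - alpha) * l2)
}
```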
@@ -131,7 +130,7 @@ class LinearRegression(override val uid: String)
     })
 
     val numFeatures = summarizer.mean.size
-    val yMean = if ($(fitIntercept)) statCounter.mean else 0.0
+    val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
     // Look at glmnet5.m L761; it may have relevant information.
 
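The point of this hunk: yMean is no longer zeroed when fitIntercept is false; the flag is instead threaded into the cost function below, which decides whether to apply the offset. For context, a simplified sketch of where the label statistics come from, assuming labels: RDD[Double] (the actual patch computes feature and label statistics together in a single treeAggregate pass):

```scala
// Simplified sketch: label mean/variance via StatCounter in one pass.
val statCounter = labels.stats()
val yMean = statCounter.mean                // always the true label mean now
val yStd = math.sqrt(statCounter.variance)  // intercept handling moved into the loss
```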
@@ -144,11 +143,7 @@ class LinearRegression(override val uid: String)
       return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean)
     }
 
-    val featuresMean = if ($(fitIntercept)) {
-      summarizer.mean.toArray
-    } else {
-      new Array[Double](numFeatures)
-    }
+    val featuresMean = summarizer.mean.toArray
     val featuresStd = summarizer.variance.toArray.map(math.sqrt)
 
     // Since we implicitly do the feature scaling when we compute the cost function
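Likewise, featuresMean is now computed unconditionally, and the scaling the comment refers to stays implicit inside the loss. A plain-array sketch of that per-feature transform (the standardize helper is hypothetical):

```scala
// Sketch of the implicit feature scaling: (x_j - mean_j) / std_j,
// with constant features (std_j == 0) mapped to zero.
def standardize(x: Array[Double], mean: Array[Double], std: Array[Double]): Array[Double] =
  Array.tabulate(x.length) { j =>
    if (std(j) != 0.0) (x(j) - mean(j)) / std(j) else 0.0
  }
```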
@@ -157,7 +152,7 @@ class LinearRegression(override val uid: String)
     val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
-    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
       featuresStd, featuresMean, effectiveL2RegParam)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
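The L1/L2 split above also drives the optimizer choice on the last line: with no L1 share, plain LBFGS suffices; otherwise OWLQN handles the non-smooth L1 term. A worked example of the split (the numbers are mine):

```scala
// Worked example, assuming effectiveRegParam = 0.1 and elasticNetParam (alpha) = 0.3.
val effectiveRegParam = 0.1
val elasticNetParam = 0.3
val effectiveL1RegParam = elasticNetParam * effectiveRegParam          // 0.03, goes to OWLQN
val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam  // 0.07, folded into the cost function
```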
@@ -195,7 +190,7 @@ class LinearRegression(override val uid: String)
     // The intercept in R's GLMNET is computed using closed form after the coefficients are
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
     if (handlePersistence) instances.unpersist()
 
     // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
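A plain-Scala restatement of that closed form, using arrays instead of MLlib Vectors (illustrative only; closedFormIntercept is a hypothetical helper):

```scala
// After the coefficients converge: intercept = mean(y) - w . mean(X),
// forced to 0.0 when no intercept is being fit.
def closedFormIntercept(
    weights: Array[Double],
    featuresMean: Array[Double],
    yMean: Double,
    fitIntercept: Boolean): Double = {
  if (fitIntercept) {
    yMean - weights.zip(featuresMean).map { case (w, m) => w * m }.sum
  } else {
    0.0
  }
}
```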
@@ -320,6 +315,7 @@ private class LeastSquaresAggregator(
     weights: Vector,
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double]) extends Serializable {
@@ -340,7 +336,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
   }
 
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
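The second tuple element is the offset of the standardized model: labelMean/labelStd minus the accumulated sum of w_i * mean_i / std_i, now zeroed when fitIntercept is off. Restated as a self-contained sketch with plain arrays (hypothetical helper, mirroring the loop above):

```scala
// Offset in standardized space; zero-variance features contribute nothing,
// matching the guard in the aggregator's loop.
def standardizedOffset(
    weights: Array[Double],
    featuresMean: Array[Double],
    featuresStd: Array[Double],
    labelMean: Double,
    labelStd: Double,
    fitIntercept: Boolean): Double = {
  if (!fitIntercept) return 0.0
  var sum = 0.0
  var i = 0
  while (i < weights.length) {
    if (featuresStd(i) != 0.0) sum += weights(i) * featuresMean(i) / featuresStd(i)
    i += 1
  }
  labelMean / labelStd - sum
}
```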
@@ -423,6 +419,7 @@ private class LeastSquaresCostFun(
     data: RDD[(Double, Vector)],
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
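For reference, DiffFunction is Breeze's differentiable-function contract: calculate returns the (loss, gradient) pair that LBFGS and OWLQN consume. A minimal toy instance (a quadratic, not the least-squares loss itself):

```scala
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.DiffFunction

// Toy DiffFunction: f(w) = ||w||^2 / 2 with gradient w, just to show the
// (loss, gradient) shape that LeastSquaresCostFun must return.
val toyCostFun = new DiffFunction[BDV[Double]] {
  override def calculate(w: BDV[Double]): (Double, BDV[Double]) =
    (0.5 * (w dot w), w.copy)
}
```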
@@ -431,7 +428,7 @@ private class LeastSquaresCostFun(
     val w = Vectors.fromBreeze(weights)
 
     val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-      labelMean, featuresStd, featuresMean))(
+      labelMean, fitIntercept, featuresStd, featuresMean))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, (label, features)) => aggregator.add(label, features)
       },
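The seqOp/combOp pair follows Spark's standard treeAggregate pattern: seqOp folds each (label, features) record into a per-partition aggregator, and combOp (cut off by this hunk) merges the partition results pairwise up a tree. A simpler self-contained instance over the same RDD[(Double, Vector)] shape, computing a label sum and count (illustrative only):

```scala
// Generic treeAggregate sketch: seqOp builds per-partition accumulators,
// combOp merges them tree-wise (default depth 2), easing the driver-side merge.
val (labelSum, count) = data.treeAggregate((0.0, 0L))(
  seqOp = (acc, point) => (acc._1 + point._1, acc._2 + 1L),
  combOp = (a, b) => (a._1 + b._1, a._2 + b._2))
val labelMean = labelSum / count
```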