@@ -17,7 +17,7 @@

 package org.apache.spark.ml.regression

-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable

 import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
 import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
@@ -79,11 +79,11 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
   /**
    * Set the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
-   * Default is 1E-9.
+   * Default is 1E-6.
    * @group setParam
    */
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-9)
+  setDefault(tol -> 1E-6)

   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist instances.
@@ -119,12 +119,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
     val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam

-    val costFun = new LeastSquaresCostFun(
-      instances,
-      yStd, yMean,
-      featuresStd,
-      featuresMean,
-      effectiveL2RegParam)
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+      featuresStd, featuresMean, effectiveL2RegParam)

     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
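For context on the hunk above: the elastic-net mixing parameter splits one effective regularization strength into an L1 part and an L2 part, and the optimizer is chosen to match; with no L1 term the objective is smooth and plain L-BFGS suffices, otherwise OWL-QN is used because it handles the non-differentiable L1 penalty. A minimal standalone sketch of that split (the parameter values are invented):

```scala
// Sketch only: how one regularization strength is split by the elastic-net
// mixing parameter. The numbers below are made up for illustration.
object ElasticNetSplit {
  def split(effectiveRegParam: Double, elasticNetParam: Double): (Double, Double) = {
    val l1 = elasticNetParam * effectiveRegParam          // non-smooth part -> OWL-QN
    val l2 = (1.0 - elasticNetParam) * effectiveRegParam  // smooth part -> plain L-BFGS is enough
    (l1, l2)
  }

  def main(args: Array[String]): Unit = {
    val (l1, l2) = split(effectiveRegParam = 0.3, elasticNetParam = 0.4)
    println(s"L1 = $l1, L2 = $l2")  // L1 = 0.12, L2 = 0.18 (up to rounding)
  }
}
```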
@@ -137,12 +133,13 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
       optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)

     var state = states.next()
-    val lossHistory = new ArrayBuffer[Double](paramMap(maxIter))
-    while (states.hasNext) {
-      lossHistory.append(state.value)
+    val lossHistory = mutable.ArrayBuilder.make[Double]
+
+    while (states.hasNext) {
+      lossHistory += state.value
       state = states.next()
     }
-    lossHistory.append(state.value)
+    lossHistory += state.value

     // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
     // The weights are trained in the scaled space; we're converting them back to
@@ -151,7 +148,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
       val rawWeights = state.x.toArray.clone()
       var i = 0
       while (i < rawWeights.length) {
-        rawWeights(i) = if (featuresStd(i) != 0.0) rawWeights(i) * yStd / featuresStd(i) else 0.0
+        rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
         i += 1
       }
       Vectors.dense(rawWeights)
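The descaling hunk just above is a pure rewrite of the same arithmetic: each weight trained on standardized features is mapped back to the original scale by multiplying with yStd / featuresStd(i), and a feature with zero variance gets weight 0. A small self-contained sketch of that mapping (all values invented):

```scala
// Sketch of mapping weights from the standardized space back to the original
// feature scale: w(i) *= yStd / featuresStd(i), or 0 for constant columns.
object DescaleWeights {
  def main(args: Array[String]): Unit = {
    val rawWeights  = Array(0.5, 1.2, 3.0)   // weights in the scaled space (made up)
    val featuresStd = Array(2.0, 0.0, 4.0)   // second feature is constant
    val yStd = 10.0

    var i = 0
    while (i < rawWeights.length) {
      rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
      i += 1
    }
    println(rawWeights.mkString(", "))       // 2.5, 0.0, 7.5
  }
}
```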
@@ -218,7 +215,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum, weightsArray.length)
+    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
   }
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)

@@ -237,7 +234,7 @@ private class LeastSquaresAggregator(
     require(dim == data.size, s"Dimensions mismatch when adding new sample." +
       s" Expecting $dim but got ${data.size}.")

-    val diff = dot(data, effectiveWeightsVector) - (label - labelMean) / labelStd + offset
+    val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset

     if (diff != 0) {
       val localGradientSumArray = gradientSumArray
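The two hunks above are a single change: the constant labelMean / labelStd is hoisted out of the per-sample residual in add() and folded into the precomputed offset, so every sample does a little less arithmetic while the residual keeps the same value. A quick equivalence check on invented numbers (this is a sketch, not Spark code):

```scala
// Verifies that moving labelMean / labelStd from the per-sample residual into
// the precomputed offset does not change the residual. All numbers are made up.
object OffsetFoldCheck {
  def main(args: Array[String]): Unit = {
    val dotWx = 1.7                              // stands in for dot(data, effectiveWeightsVector)
    val (label, labelMean, labelStd) = (5.0, 3.0, 2.0)
    val sum = 0.4                                // the aggregator's accumulated sum term

    val before = dotWx - (label - labelMean) / labelStd + (-sum)
    val after  = dotWx - label / labelStd + (-sum + labelMean / labelStd)
    println(math.abs(before - after) < 1e-12)    // true: only a constant term moved
  }
}
```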
@@ -266,7 +263,7 @@ private class LeastSquaresAggregator(
     require(dim == other.dim, s"Dimensions mismatch when merging with another" +
       s" LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")

-    if (this.totalCnt != 0 && other.totalCnt != 0) {
+    if (other.totalCnt != 0) {
       totalCnt += other.totalCnt
       lossSum += other.lossSum
       diffSum += other.diffSum
@@ -278,11 +275,6 @@ private class LeastSquaresAggregator(
         localThisGradientSumArray(i) += localOtherGradientSumArray(i)
         i += 1
       }
-    } else if (totalCnt == 0 && other.totalCnt != 0) {
-      this.totalCnt = other.totalCnt
-      this.lossSum = other.lossSum
-      this.diffSum = other.diffSum
-      System.arraycopy(other.gradientSumArray, 0, this.gradientSumArray, 0, dim)
     }
     this
   }
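The else-if branch removed above copied other's state into an empty this, but a freshly constructed aggregator already holds a zero count, zero sums, and a zero-filled gradient array, so plain accumulation reproduces the copy; the only case worth skipping is an empty other. A much-simplified, hypothetical stand-in for the aggregator illustrating why:

```scala
// Hypothetical, stripped-down stand-in for LeastSquaresAggregator showing why
// merging into a zero-initialized accumulator by addition equals copying.
class SumAgg(var totalCnt: Long = 0L, var lossSum: Double = 0.0) {
  def merge(other: SumAgg): this.type = {
    if (other.totalCnt != 0) {    // an empty `other` contributes nothing
      totalCnt += other.totalCnt
      lossSum += other.lossSum
    }
    this
  }
}

object SumAggDemo {
  def main(args: Array[String]): Unit = {
    val merged = new SumAgg().merge(new SumAgg(3L, 2.5))
    println((merged.totalCnt, merged.lossSum))  // (3,2.5), same as copying the state over
  }
}
```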
@@ -304,7 +296,7 @@ private class LeastSquaresAggregator(
       Vectors.dense(temp)
     }

-    axpy(-diffSum, result, correction)
+    axpy(-diffSum, correction, result)
     scal(1.0 / totalCnt, result)
     result
   }
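The hunk above swaps the last two arguments of axpy: in the convention used here, axpy(a, x, y) accumulates a * x into y in place, so the destination has to be result (the gradient sum), not correction. A tiny illustration with Breeze, which the surrounding file already imports (the vectors are made up):

```scala
// axpy(a, x, y) means y += a * x, with y as the in-place destination.
import breeze.linalg.{axpy, DenseVector}

object AxpyOrder {
  def main(args: Array[String]): Unit = {
    val correction = DenseVector(1.0, 2.0)
    val result     = DenseVector(10.0, 10.0)
    val diffSum    = 3.0

    axpy(-diffSum, correction, result)   // result += (-3.0) * correction
    println(result)                      // DenseVector(7.0, 4.0)
    println(correction)                  // DenseVector(1.0, 2.0), left untouched
  }
}
```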
@@ -319,27 +311,25 @@ private class LeastSquaresCostFun(
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {

   override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
-    val w = Vectors.fromBreeze(weights)
-
-    val leastSquaresAggregator = data.treeAggregate(
-      new LeastSquaresAggregator(w, labelStd, labelMean, featuresStd, featuresMean))(
-      seqOp = (c, v) => (c, v) match {
-        case (aggregator, (label, features)) => aggregator.add(label, features)
-      },
-      combOp = (c1, c2) => (c1, c2) match {
-        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-      })
-
-    // regVal is sum of weight squares for L2 regularization
-    val norm = brzNorm(weights, 2.0)
-    val regVal = 0.5 * effectiveL2regParam * norm * norm
-
-    val loss = leastSquaresAggregator.loss + regVal
-
-    val gradientTotal = w.copy
-    scal(effectiveL2regParam, gradientTotal)
-    axpy(1.0, leastSquaresAggregator.gradient, gradientTotal)
-
-    (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
-  }
+    val w = Vectors.fromBreeze(weights)
+
+    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+      labelMean, featuresStd, featuresMean))(
+        seqOp = (c, v) => (c, v) match {
+          case (aggregator, (label, features)) => aggregator.add(label, features)
+        },
+        combOp = (c1, c2) => (c1, c2) match {
+          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+        })
+
+    // regVal is sum of weight squares for L2 regularization
+    val norm = brzNorm(weights, 2.0)
+    val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+    val loss = leastSquaresAggregator.loss + regVal
+    val gradient = leastSquaresAggregator.gradient
+    axpy(effectiveL2regParam, w, gradient)
+
+    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+  }
 }
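In the rewritten calculate above, the L2 contribution is added directly into the aggregator's gradient with one axpy(effectiveL2regParam, w, gradient) call instead of building a separate gradientTotal, while the loss keeps the matching 0.5 * lambda * ||w||^2 term. A sketch with Breeze and invented numbers checking that the old three-step formulation and the new one agree:

```scala
// Checks that folding the L2 term into the aggregator gradient with one axpy
// matches the old copy/scale/axpy sequence. Numbers are made up for illustration.
import breeze.linalg.{axpy, norm, DenseVector}

object L2GradientFold {
  def main(args: Array[String]): Unit = {
    val w       = DenseVector(1.0, -2.0)     // current weights
    val aggGrad = DenseVector(0.5, 0.25)     // stands in for leastSquaresAggregator.gradient
    val lambda  = 0.1                        // stands in for effectiveL2regParam

    // Old style: scale a copy of w, then add the aggregator gradient.
    val gradientTotal = w.copy
    gradientTotal *= lambda
    axpy(1.0, aggGrad, gradientTotal)

    // New style: accumulate lambda * w into the aggregator gradient in place.
    val gradient = aggGrad.copy
    axpy(lambda, w, gradient)

    println(gradientTotal == gradient)       // true

    // The corresponding loss term, as in regVal above.
    val regVal = 0.5 * lambda * math.pow(norm(w), 2)
    println(regVal)                          // ~0.25
  }
}
```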