
Commit 9988ca8
Author: DB Tsai

addressed feedback and fixed a bug. TODO: documentation and build another
synthetic dataset which can catch the bug fixed in this commit.

1 parent: fcbaefe

File tree: 2 files changed, +46 −53 lines

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 36 additions & 46 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.regression
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
 
 import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
 import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
@@ -79,11 +79,11 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
   /**
    * Set the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
-   * Default is 1E-9.
+   * Default is 1E-6.
    * @group setParam
    */
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-9)
+  setDefault(tol -> 1E-6)
 
   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist instances.
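The default convergence tolerance is loosened here from 1E-9 to 1E-6. Callers who want the old, tighter behavior can still request it through the setter shown above; a minimal sketch, where the `training` DataFrame is a hypothetical placeholder:

import org.apache.spark.ml.regression.LinearRegression

// Sketch only: `training` is assumed to be an existing DataFrame of (label, features) rows.
val lr = new LinearRegression()
  .setMaxIter(100)
  .setTol(1E-9)  // opt back into the tighter pre-change tolerance
val model = lr.fit(training)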
@@ -119,12 +119,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
     val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
 
-    val costFun = new LeastSquaresCostFun(
-      instances,
-      yStd, yMean,
-      featuresStd,
-      featuresMean,
-      effectiveL2RegParam)
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+      featuresStd, featuresMean, effectiveL2RegParam)
 
     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
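For context, the two `effective*RegParam` values above implement the standard elastic-net split, which also drives the optimizer choice: a pure-L2 (or unregularized) objective stays smooth, so plain L-BFGS applies; any L1 weight routes to OWL-QN. A toy sketch with hypothetical values:

// Hypothetical values, for illustration only.
val regParam = 0.1          // overall regularization strength
val elasticNetParam = 0.3   // mixing: 0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso)

val effectiveL1 = elasticNetParam * regParam          // ~0.03, handled by OWL-QN
val effectiveL2 = (1.0 - elasticNetParam) * regParam  // ~0.07, folded into LeastSquaresCostFun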
@@ -137,12 +133,13 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
     optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
 
     var state = states.next()
-    val lossHistory = new ArrayBuffer[Double](paramMap(maxIter))
-    while(states.hasNext) {
-      lossHistory.append(state.value)
+    val lossHistory = mutable.ArrayBuilder.make[Double]
+
+    while (states.hasNext) {
+      lossHistory += state.value
       state = states.next()
     }
-    lossHistory.append(state.value)
+    lossHistory += state.value
 
     // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
     // The weights are trained in the scaled space; we're converting them back to
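The switch from ArrayBuffer[Double] to mutable.ArrayBuilder.make[Double] avoids boxing each loss value: for a primitive element type the factory returns a specialized builder backed by a raw Array[Double]. A self-contained sketch of the pattern:

import scala.collection.mutable

val losses = mutable.ArrayBuilder.make[Double]  // specialized builder, no per-element boxing
losses += 0.93
losses += 0.41
val history: Array[Double] = losses.result()    // materialize a primitive array once, at the end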
@@ -151,7 +148,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
     val rawWeights = state.x.toArray.clone()
     var i = 0
     while (i < rawWeights.length) {
-      rawWeights(i) = if (featuresStd(i) != 0.0) rawWeights(i) * yStd / featuresStd(i) else 0.0
+      rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
       i += 1
     }
     Vectors.dense(rawWeights)
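The rewritten loop body is equivalent to the old one (multiply-in-place rather than reassign, with a zero factor for constant features). The rule itself falls out of training on standardized data: assuming a fit of the form (y − yMean)/yStd ≈ Σᵢ wᵢ · (xᵢ − μᵢ)/σᵢ, the original-scale weight is wᵢ · yStd / σᵢ. A toy check for one feature:

// Toy check of the unscaling rule for a single feature:
// scaled-space fit: (y - yMean) / yStd ≈ wScaled * (x - xMean) / xStd
val (wScaled, yStd, xStd) = (0.8, 2.0, 4.0)
val wOriginal = wScaled * yStd / xStd  // = 0.4, the slope in original units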
@@ -218,7 +215,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum, weightsArray.length)
+    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
   }
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
 
@@ -237,7 +234,7 @@ private class LeastSquaresAggregator(
     require(dim == data.size, s"Dimensions mismatch when adding new sample." +
       s" Expecting $dim but got ${data.size}.")
 
-    val diff = dot(data, effectiveWeightsVector) - (label - labelMean) / labelStd + offset
+    val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset
 
     if (diff != 0) {
       val localGradientSumArray = gradientSumArray
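This hunk and the offset change above are two halves of the same refactoring, and they cancel algebraically:

    dot(w, x̂) − (label − labelMean) / labelStd
      = dot(w, x̂) − label / labelStd + labelMean / labelStd

Folding the constant labelMean / labelStd into the precomputed offset therefore leaves diff mathematically unchanged while removing a subtraction from the per-sample hot path.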
@@ -266,7 +263,7 @@ private class LeastSquaresAggregator(
     require(dim == other.dim, s"Dimensions mismatch when merging with another " +
       s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
 
-    if (this.totalCnt != 0 && other.totalCnt != 0) {
+    if (other.totalCnt != 0) {
       totalCnt += other.totalCnt
       lossSum += other.lossSum
       diffSum += other.diffSum
@@ -278,11 +275,6 @@ private class LeastSquaresAggregator(
         localThisGradientSumArray(i) += localOtherGradientSumArray(i)
         i += 1
       }
-    } else if (totalCnt == 0 && other.totalCnt != 0) {
-      this.totalCnt = other.totalCnt
-      this.lossSum = other.lossSum
-      this.diffSum = other.diffSum
-      System.arraycopy(other.gradientSumArray, 0, this.gradientSumArray, 0, dim)
     }
     this
   }
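The deleted else-if branch was redundant: a freshly constructed aggregator carries all-zero sums, so adding other into this already reproduces the copy-into-empty case, and only an empty other (a no-op merge) still needs guarding. A toy model of the invariant, using a hypothetical Acc class:

// The empty accumulator is the additive identity, so one branch suffices.
final class Acc(var cnt: Long = 0L, var sum: Double = 0.0) {
  def merge(other: Acc): this.type = {
    if (other.cnt != 0) {  // an empty `other` contributes nothing, so skip it
      cnt += other.cnt     // when `this` is empty, these additions reproduce
      sum += other.sum     // exactly the removed copy-into-empty branch
    }
    this
  }
}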
@@ -304,7 +296,7 @@ private class LeastSquaresAggregator(
       Vectors.dense(temp)
     }
 
-    axpy(-diffSum, result, correction)
+    axpy(-diffSum, correction, result)
     scal(1.0 / totalCnt, result)
     result
   }
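This argument swap is plausibly the bug the commit message refers to. Under the usual BLAS convention, which these private helpers follow, axpy(a, x, y) mutates y in place as y += a * x; the old call therefore wrote the update into the scratch correction vector and discarded it, leaving result uncorrected. A self-contained sketch of the convention on plain arrays:

// axpy(a, x, y): y := a * x + y, mutating y in place (x is read-only).
def axpy(a: Double, x: Array[Double], y: Array[Double]): Unit = {
  var i = 0
  while (i < x.length) { y(i) += a * x(i); i += 1 }
}

val result = Array(1.0, 2.0)
val correction = Array(0.5, 0.5)
axpy(-2.0, correction, result)  // result is now Array(0.0, 1.0); correction untouched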
@@ -319,27 +311,25 @@ private class LeastSquaresCostFun(
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
 
   override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
-    val w = Vectors.fromBreeze(weights)
-
-    val leastSquaresAggregator = data.treeAggregate(
-      new LeastSquaresAggregator(w, labelStd, labelMean, featuresStd, featuresMean))(
-      seqOp = (c, v) => (c, v) match {
-        case (aggregator, (label, features)) => aggregator.add(label, features)
-      },
-      combOp = (c1, c2) => (c1, c2) match {
-        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-      })
-
-    // regVal is sum of weight squares for L2 regularization
-    val norm = brzNorm(weights, 2.0)
-    val regVal = 0.5 * effectiveL2regParam * norm * norm
-
-    val loss = leastSquaresAggregator.loss + regVal
-
-    val gradientTotal = w.copy
-    scal(effectiveL2regParam, gradientTotal)
-    axpy(1.0, leastSquaresAggregator.gradient, gradientTotal)
-
-    (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
-  }
+    val w = Vectors.fromBreeze(weights)
+
+    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+      labelMean, featuresStd, featuresMean))(
+      seqOp = (c, v) => (c, v) match {
+        case (aggregator, (label, features)) => aggregator.add(label, features)
+      },
+      combOp = (c1, c2) => (c1, c2) match {
+        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+      })
+
+    // regVal is sum of weight squares for L2 regularization
+    val norm = brzNorm(weights, 2.0)
+    val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+    val loss = leastSquaresAggregator.loss + regVal
+    val gradient = leastSquaresAggregator.gradient
+    axpy(effectiveL2regParam, w, gradient)
+
+    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+  }
 }
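For readers new to the pattern used in calculate: each call makes one distributed pass with RDD.treeAggregate, where seqOp folds a sample into a partition-local aggregator and combOp merges aggregators pairwise up a tree, reducing merge pressure on the driver compared with plain aggregate. A stripped-down sketch of the same shape, with a hypothetical SumAgg standing in for LeastSquaresAggregator:

import org.apache.spark.rdd.RDD

// Hypothetical stand-in for LeastSquaresAggregator.
class SumAgg(var cnt: Long = 0L, var sum: Double = 0.0) extends Serializable {
  def add(v: Double): this.type = { cnt += 1; sum += v; this }
  def merge(other: SumAgg): this.type = { cnt += other.cnt; sum += other.sum; this }
}

def mean(data: RDD[Double]): Double = {
  val agg = data.treeAggregate(new SumAgg)(
    seqOp = (acc, v) => acc.add(v),
    combOp = (a, b) => a.merge(b))
  agg.sum / agg.cnt
}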

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 10 additions & 7 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.optimization
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import breeze.linalg.{DenseVector => BDV}
@@ -164,7 +165,7 @@ object LBFGS extends Logging {
       regParam: Double,
       initialWeights: Vector): (Vector, Array[Double]) = {
 
-    val lossHistory = new ArrayBuffer[Double](maxNumIterations)
+    val lossHistory = mutable.ArrayBuilder.make[Double]
 
     val numExamples = data.count()
 
@@ -181,24 +182,26 @@ object LBFGS extends Logging {
      * and regVal is the regularization value computed in the previous iteration as well.
      */
     var state = states.next()
-    while(states.hasNext) {
-      lossHistory.append(state.value)
+    while (states.hasNext) {
+      lossHistory += state.value
       state = states.next()
     }
-    lossHistory.append(state.value)
+    lossHistory += state.value
     val weights = Vectors.fromBreeze(state.x)
 
+    val lossHistoryArray = lossHistory.result()
+
     logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
-      lossHistory.takeRight(10).mkString(", ")))
+      lossHistoryArray.takeRight(10).mkString(", ")))
 
-    (weights, lossHistory.toArray)
+    (weights, lossHistoryArray)
   }
 
   /**
    * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
    * at a particular point (weights). It's used in Breeze's convex optimization routines.
    */
-  private[spark] class CostFun(
+  private class CostFun(
     data: RDD[(Double, Vector)],
     gradient: Gradient,
     updater: Updater,
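One consequence of the builder switch is visible in this hunk: ArrayBuilder exposes no element reads (no takeRight, no indexing), only += and result(), so the losses are materialized into lossHistoryArray once before both the log line and the return. A sketch of the constraint:

import scala.collection.mutable

val b = mutable.ArrayBuilder.make[Double]
b += 1.0; b += 2.0; b += 3.0
// b.takeRight(2)          // would not compile: builders are write-only
val arr = b.result()       // materialize first...
arr.takeRight(2)           // ...then use ordinary Array ops: Array(2.0, 3.0)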
