
Commit 6a827d5

DB Tsai authored and mengxr committed
[SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN
Author: DB Tsai <[email protected]>

Closes apache#4259 from dbtsai/lir and squashes the following commits:

a81c201 [DB Tsai] add import org.apache.spark.util.Utils back
9fc48ed [DB Tsai] rebase
2178b63 [DB Tsai] add comments
9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit.
fcbaefe [DB Tsai] Refactoring
4eb078d [DB Tsai] first commit
1 parent 268c419 commit 6a827d5
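
For orientation, a minimal usage sketch of the reworked estimator (not part of the commit itself; the SQLContext, the `training` DataFrame, and its "label"/"features" columns are assumed for illustration):

import org.apache.spark.ml.regression.LinearRegression

// `training` is a hypothetical DataFrame with "label" and "features" columns.
val lr = new LinearRegression()
  .setRegParam(0.3)        // overall regularization strength
  .setElasticNetParam(0.8) // 0.0 = pure L2, 1.0 = pure L1, in between = ElasticNet
  .setMaxIter(100)
  .setTol(1e-6)
val model = lr.fit(training)
println(s"weights = ${model.weights}, intercept = ${model.intercept}")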

File tree

8 files changed: +508 -64 lines changed

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("outputCol", "output column name"),
       ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
-      ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
+      ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
+      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+      ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))

     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 34 additions & 0 deletions
@@ -276,4 +276,38 @@ trait HasSeed extends Params {
   /** @group getParam */
   final def getSeed: Long = getOrDefault(seed)
 }
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param elasticNetParam.
+ */
+@DeveloperApi
+trait HasElasticNetParam extends Params {
+
+  /**
+   * Param for the ElasticNet mixing parameter.
+   * @group param
+   */
+  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+
+  /** @group getParam */
+  final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param tol.
+ */
+@DeveloperApi
+trait HasTol extends Params {
+
+  /**
+   * Param for the convergence tolerance for iterative algorithms.
+   * @group param
+   */
+  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+
+  /** @group getParam */
+  final def getTol: Double = getOrDefault(tol)
+}
 // scalastyle:on
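
The generated traits above are meant to be mixed into an algorithm's params trait and read through the usual Params machinery, exactly as the LinearRegression changes below do. A minimal sketch of that pattern (the MyRegressionParams name is hypothetical; the defaults mirror the ones LinearRegression sets):

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasTol}

// Hypothetical params trait combining the two new shared params.
private trait MyRegressionParams extends Params with HasElasticNetParam with HasTol {
  // Register defaults so getElasticNetParam / getTol resolve via getOrDefault.
  setDefault(elasticNetParam -> 0.0)
  setDefault(tol -> 1E-6)
}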

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 281 additions & 23 deletions
@@ -17,21 +17,29 @@

 package org.apache.spark.ml.regression

+import scala.collection.mutable
+
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
+import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction}
+
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{BLAS, Vector}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.util.StatCounter

 /**
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasRegParam with HasMaxIter
-
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol

 /**
  * :: AlphaComponent ::
@@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
   with LinearRegressionParams {

-  setDefault(regParam -> 0.1, maxIter -> 100)
-
-  /** @group setParam */
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   * @group setParam
+   */
   def setRegParam(value: Double): this.type = set(regParam, value)
+  setDefault(regParam -> 0.0)
+
+  /**
+   * Set the ElasticNet mixing parameter.
+   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+   * Default is 0.0 which is an L2 penalty.
+   * @group setParam
+   */
+  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+  setDefault(elasticNetParam -> 0.0)

-  /** @group setParam */
+  /**
+   * Set the maximal number of iterations.
+   * Default is 100.
+   * @group setParam
+   */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Default is 1E-6.
+   * @group setParam
+   */
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)

   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
-    // Extract columns from data. If dataset is persisted, do not persist oldDataset.
-    val oldDataset = extractLabeledPoints(dataset, paramMap)
+    // Extract columns from data. If dataset is persisted, do not persist instances.
+    val instances = extractLabeledPoints(dataset, paramMap).map {
+      case LabeledPoint(label: Double, features: Vector) => (label, features)
+    }
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) {
-      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    val (summarizer, statCounter) = instances.treeAggregate(
+      (new MultivariateOnlineSummarizer, new StatCounter))( {
+        case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
+          (label: Double, features: Vector)) =>
+            (summarizer.add(features), statCounter.merge(label))
+      }, {
+        case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
+          (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
+            (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
+      })
+
+    val numFeatures = summarizer.mean.size
+    val yMean = statCounter.mean
+    val yStd = math.sqrt(statCounter.variance)
+
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+    // Since we implicitly do the feature scaling when we compute the cost function
+    // to improve the convergence, the effective regParam will be changed.
+    val effectiveRegParam = paramMap(regParam) / yStd
+    val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
+    val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
+
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+      featuresStd, featuresMean, effectiveL2RegParam)
+
+    val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+    } else {
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+    }
+
+    val initialWeights = Vectors.zeros(numFeatures)
+    val states =
+      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+    var state = states.next()
+    val lossHistory = mutable.ArrayBuilder.make[Double]
+
+    while (states.hasNext) {
+      lossHistory += state.value
+      state = states.next()
+    }
+    lossHistory += state.value
+
+    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
+    // The weights are trained in the scaled space; we're converting them back to
+    // the original space.
+    val weights = {
+      val rawWeights = state.x.toArray.clone()
+      var i = 0
+      while (i < rawWeights.length) {
+        rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
+        i += 1
+      }
+      Vectors.dense(rawWeights)
     }

-    // Train model
-    val lr = new LinearRegressionWithSGD()
-    lr.optimizer
-      .setRegParam(paramMap(regParam))
-      .setNumIterations(paramMap(maxIter))
-    val model = lr.run(oldDataset)
-    val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+    // The intercept in R's GLMNET is computed using closed form after the coefficients are
+    // converged. See the following discussion for detail.
+    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))

     if (handlePersistence) {
-      oldDataset.unpersist()
+      instances.unpersist()
     }
-    lrm
+    new LinearRegressionModel(this, paramMap, weights, intercept)
   }
 }

@@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] (
   with LinearRegressionParams {

   override protected def predict(features: Vector): Double = {
-    BLAS.dot(features, weights) + intercept
+    dot(features, weights) + intercept
   }

   override protected def copy(): LinearRegressionModel = {
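
A short note on the regularization scaling in the train() method above (this restates the code rather than adding behavior): with lambda = regParam, alpha = elasticNetParam, and sigma_y the standard deviation of the label, the optimizer works on standardized features and labels, so the penalty strengths become

  effectiveRegParam   = lambda / sigma_y
  effectiveL1RegParam = alpha * lambda / sigma_y          (handled by OWLQN's L1 step)
  effectiveL2RegParam = (1 - alpha) * lambda / sigma_y    (added to the loss and gradient in LeastSquaresCostFun)

When alpha = 0 or lambda = 0 there is no L1 term, which is why that branch falls back to plain L-BFGS. After convergence, each weight is mapped back to the original scale by multiplying by sigma_y / sigma_j (zero for constant features), and the intercept is recovered in closed form as yMean - dot(weights, featuresMean), following GLMNET.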
@@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] (
     m
   }
 }
+
+/**
+ * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
+ * as used in linear regression for samples in sparse or dense vector in a online fashion.
+ *
+ * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of
+ * the corresponding joint dataset.
+ *
+
+ * * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
+ * This is correct for the averaged least squares loss function (mean squared error)
+ *     L = 1/2n ||A weights-y||^2
+ * See also the documentation for the precise formulation.
+ *
+ * @param weights weights/coefficients corresponding to features
+ *
+ * @param updater Updater to be used to update weights after every iteration.
+ */
+private class LeastSquaresAggregator(
+    weights: Vector,
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double]) extends Serializable {
+
+  private var totalCnt: Long = 0
+  private var lossSum = 0.0
+  private var diffSum = 0.0
+
+  private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = {
+    val weightsArray = weights.toArray.clone()
+    var sum = 0.0
+    var i = 0
+    while (i < weightsArray.length) {
+      if (featuresStd(i) != 0.0) {
+        weightsArray(i) /= featuresStd(i)
+        sum += weightsArray(i) * featuresMean(i)
+      } else {
+        weightsArray(i) = 0.0
+      }
+      i += 1
+    }
+    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+  }
+  private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
+
+  private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+  /**
+   * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * @param label The label for this data point.
+   * @param data The features for one data point in dense/sparse vector format to be added
+   *             into this aggregator.
+   * @return This LeastSquaresAggregator object.
+   */
+  def add(label: Double, data: Vector): this.type = {
+    require(dim == data.size, s"Dimensions mismatch when adding new sample." +
+      s" Expecting $dim but got ${data.size}.")
+
+    val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset
+
+    if (diff != 0) {
+      val localGradientSumArray = gradientSumArray
+      data.foreachActive { (index, value) =>
+        if (featuresStd(index) != 0.0 && value != 0.0) {
+          localGradientSumArray(index) += diff * value / featuresStd(index)
+        }
+      }
+      lossSum += diff * diff / 2.0
+      diffSum += diff
+    }
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another LeastSquaresAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other LeastSquaresAggregator to be merged.
+   * @return This LeastSquaresAggregator object.
+   */
+  def merge(other: LeastSquaresAggregator): this.type = {
+    require(dim == other.dim, s"Dimensions mismatch when merging with another " +
+      s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
+
+    if (other.totalCnt != 0) {
+      totalCnt += other.totalCnt
+      lossSum += other.lossSum
+      diffSum += other.diffSum
+
+      var i = 0
+      val localThisGradientSumArray = this.gradientSumArray
+      val localOtherGradientSumArray = other.gradientSumArray
+      while (i < dim) {
+        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+        i += 1
+      }
+    }
+    this
+  }
+
+  def count: Long = totalCnt
+
+  def loss: Double = lossSum / totalCnt
+
+  def gradient: Vector = {
+    val result = Vectors.dense(gradientSumArray.clone())
+
+    val correction = {
+      val temp = effectiveWeightsArray.clone()
+      var i = 0
+      while (i < temp.length) {
+        temp(i) *= featuresMean(i)
+        i += 1
+      }
+      Vectors.dense(temp)
+    }
+
+    axpy(-diffSum, correction, result)
+    scal(1.0 / totalCnt, result)
+    result
+  }
+}
+
+/**
+ * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost.
+ * It returns the loss and gradient with L2 regularization at a particular point (weights).
+ * It's used in Breeze's convex optimization routines.
+ */
+private class LeastSquaresCostFun(
+    data: RDD[(Double, Vector)],
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double],
+    effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+
+  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val w = Vectors.fromBreeze(weights)
+
+    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+      labelMean, featuresStd, featuresMean))(
+        seqOp = (c, v) => (c, v) match {
+          case (aggregator, (label, features)) => aggregator.add(label, features)
+        },
+        combOp = (c1, c2) => (c1, c2) match {
+          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+        })
+
+    // regVal is the sum of weight squares for L2 regularization
+    val norm = brzNorm(weights, 2.0)
+    val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+    val loss = leastSquaresAggregator.loss + regVal
+    val gradient = leastSquaresAggregator.gradient
+    axpy(effectiveL2regParam, w, gradient)
+
+    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+  }
+}
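
For readers unfamiliar with the Breeze optimization API that train() and LeastSquaresCostFun plug into, here is a self-contained toy sketch of the same DiffFunction / CachedDiffFunction / iterations pattern, minimizing a simple quadratic instead of the RDD-backed least-squares cost (the objective and constants below are invented for illustration):

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}

object ToyLBFGSExample {
  def main(args: Array[String]): Unit = {
    val target = BDV(1.0, -2.0, 3.0)

    // Toy objective: f(w) = 0.5 * ||w - target||^2, gradient = w - target.
    val costFun = new DiffFunction[BDV[Double]] {
      override def calculate(w: BDV[Double]): (Double, BDV[Double]) = {
        val diff = w - target
        (0.5 * (diff dot diff), diff)
      }
    }

    // Same constructor arguments as in train(): maxIter, memory size m, tolerance.
    val optimizer = new BreezeLBFGS[BDV[Double]](100, 10, 1e-6)

    // Same driver loop as in train(): walk the iterator of states, keep the loss history.
    val states = optimizer.iterations(new CachedDiffFunction(costFun), BDV.zeros[Double](3))
    var state = states.next()
    val lossHistory = mutable.ArrayBuilder.make[Double]
    while (states.hasNext) {
      lossHistory += state.value
      state = states.next()
    }
    lossHistory += state.value

    // Should print a vector close to (1.0, -2.0, 3.0).
    println(s"solution = ${state.x}, final loss = ${state.value}")
  }
}

The real cost function differs only in that calculate() aggregates the loss and gradient over the RDD with treeAggregate and adds the L2 term analytically, and that OWLQN is substituted for LBFGS when an L1 term is present.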
