@@ -17,21 +17,29 @@
 
 package org.apache.spark.ml.regression
 
+import scala.collection.mutable
+
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
+import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction}
+
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{BLAS, Vector}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.util.StatCounter
 
 /**
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasRegParam with HasMaxIter
-
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
 
 /**
  * :: AlphaComponent ::
@@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
   with LinearRegressionParams {
 
-  setDefault(regParam -> 0.1, maxIter -> 100)
-
-  /** @group setParam */
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   * @group setParam
+   */
   def setRegParam(value: Double): this.type = set(regParam, value)
+  setDefault(regParam -> 0.0)
+
+  /**
+   * Set the ElasticNet mixing parameter.
+   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+   * Default is 0.0 which is an L2 penalty.
+   * @group setParam
+   */
+  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+  setDefault(elasticNetParam -> 0.0)
 
-  /** @group setParam */
+  /**
+   * Set the maximal number of iterations.
+   * Default is 100.
+   * @group setParam
+   */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Default is 1E-6.
+   * @group setParam
+   */
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)
 
   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
-    // Extract columns from data. If dataset is persisted, do not persist oldDataset.
-    val oldDataset = extractLabeledPoints(dataset, paramMap)
+    // Extract columns from data. If dataset is persisted, do not persist instances.
+    val instances = extractLabeledPoints(dataset, paramMap).map {
+      case LabeledPoint(label: Double, features: Vector) => (label, features)
+    }
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) {
-      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
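+    // Compute the feature statistics (per-column mean and variance) and the label
+    // statistics (mean and variance) in a single pass over the data.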
+    val (summarizer, statCounter) = instances.treeAggregate(
+      (new MultivariateOnlineSummarizer, new StatCounter))({
+        case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
+          (label: Double, features: Vector)) =>
+          (summarizer.add(features), statCounter.merge(label))
+      }, {
+        case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
+          (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
+          (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
+      })
+
+    val numFeatures = summarizer.mean.size
+    val yMean = statCounter.mean
+    val yStd = math.sqrt(statCounter.variance)
+
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+    // Since we implicitly do the feature scaling when we compute the cost function
+    // to improve the convergence, the effective regParam will be changed.
+    val effectiveRegParam = paramMap(regParam) / yStd
+    val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
+    val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
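+    // elasticNetParam then splits the effective penalty into an L1 part (handled by
+    // OWLQN below) and an L2 part (folded into the cost function), following the same
+    // parameterization as R's glmnet.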
+
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+      featuresStd, featuresMean, effectiveL2RegParam)
+
+    val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+    } else {
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+    }
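+    // Plain L-BFGS assumes a smooth objective, so it is only used when there is no L1
+    // term; OWLQN extends L-BFGS to handle the non-smooth L1 penalty. The second
+    // argument (10) is the number of corrections kept for the Hessian approximation.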
+
+    val initialWeights = Vectors.zeros(numFeatures)
+    val states =
+      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+    var state = states.next()
+    val lossHistory = mutable.ArrayBuilder.make[Double]
+
+    while (states.hasNext) {
+      lossHistory += state.value
+      state = states.next()
+    }
+    lossHistory += state.value
+
+    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
+    // The weights are trained in the scaled space; we're converting them back to
+    // the original space.
+    val weights = {
+      val rawWeights = state.x.toArray.clone()
+      var i = 0
+      while (i < rawWeights.length) {
+        rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
+        i += 1
+      }
+      Vectors.dense(rawWeights)
     }
 
-    // Train model
-    val lr = new LinearRegressionWithSGD()
-    lr.optimizer
-      .setRegParam(paramMap(regParam))
-      .setNumIterations(paramMap(maxIter))
-    val model = lr.run(oldDataset)
-    val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+    // The intercept in R's GLMNET is computed using closed form after the coefficients are
+    // converged. See the following discussion for detail.
+    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
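+    // With centered features, the residuals have zero mean at the optimum, giving the
+    // closed form intercept = yMean - dot(weights, featuresMean).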
+    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
 
     if (handlePersistence) {
-      oldDataset.unpersist()
+      instances.unpersist()
     }
-    lrm
+    new LinearRegressionModel(this, paramMap, weights, intercept)
   }
 }
 
@@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] (
   with LinearRegressionParams {
 
   override protected def predict(features: Vector): Double = {
-    BLAS.dot(features, weights) + intercept
+    dot(features, weights) + intercept
   }
 
   override protected def copy(): LinearRegressionModel = {
@@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] (
     m
   }
 }
+
+/**
+ * LeastSquaresAggregator computes the gradient and loss for a least-squares loss function,
+ * as used in linear regression for samples in sparse or dense vectors, in an online fashion.
+ *
+ * Two LeastSquaresAggregators can be merged together to have a summary of loss and gradient
+ * of the corresponding joint dataset.
+ *
+ * This is correct for the averaged least squares loss function (mean squared error):
+ *   L = 1/(2n) ||A weights - y||^2
+ *
+ * @param weights weights/coefficients corresponding to features
+ * @param labelStd the standard deviation of the labels
+ * @param labelMean the mean of the labels
+ * @param featuresStd the standard deviation of each feature
+ * @param featuresMean the mean of each feature
+ */
+private class LeastSquaresAggregator(
+    weights: Vector,
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double]) extends Serializable {
+
+  private var totalCnt: Long = 0
+  private var lossSum = 0.0
+  private var diffSum = 0.0
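+  // diffSum accumulates the raw sum of residuals; gradient() uses it to correct for the
+  // mean-centering that is folded into `offset` below.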
+
+  private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = {
+    val weightsArray = weights.toArray.clone()
+    var sum = 0.0
+    var i = 0
+    while (i < weightsArray.length) {
+      if (featuresStd(i) != 0.0) {
+        weightsArray(i) /= featuresStd(i)
+        sum += weightsArray(i) * featuresMean(i)
+      } else {
+        weightsArray(i) = 0.0
+      }
+      i += 1
+    }
+    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+  }
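+  // `offset` folds the feature means and the scaled label mean into a single constant,
+  // so the mean-centering never needs to be materialized on (possibly sparse) vectors.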
+  private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
+
+  private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+  /**
+   * Add a new training instance to this LeastSquaresAggregator, and update the loss and
+   * gradient of the objective function.
+   *
+   * @param label The label for this data point.
+   * @param data The features for one data point in dense/sparse vector format to be added
+   *             into this aggregator.
+   * @return This LeastSquaresAggregator object.
+   */
+  def add(label: Double, data: Vector): this.type = {
+    require(dim == data.size, s"Dimensions mismatch when adding new sample." +
+      s" Expecting $dim but got ${data.size}.")
+
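+    // Residual of this sample in the scaled space:
+    //   diff = x . effectiveWeights - label / labelStd + offset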
+    val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset
+
+    if (diff != 0) {
+      val localGradientSumArray = gradientSumArray
+      data.foreachActive { (index, value) =>
+        if (featuresStd(index) != 0.0 && value != 0.0) {
+          localGradientSumArray(index) += diff * value / featuresStd(index)
+        }
+      }
+      lossSum += diff * diff / 2.0
+      diffSum += diff
+    }
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another LeastSquaresAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in-place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other LeastSquaresAggregator to be merged.
+   * @return This LeastSquaresAggregator object.
+   */
+  def merge(other: LeastSquaresAggregator): this.type = {
+    require(dim == other.dim, s"Dimensions mismatch when merging with another" +
+      s" LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
+
+    if (other.totalCnt != 0) {
+      totalCnt += other.totalCnt
+      lossSum += other.lossSum
+      diffSum += other.diffSum
+
+      var i = 0
+      val localThisGradientSumArray = this.gradientSumArray
+      val localOtherGradientSumArray = other.gradientSumArray
+      while (i < dim) {
+        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+        i += 1
+      }
+    }
+    this
+  }
+
+  def count: Long = totalCnt
+
+  def loss: Double = lossSum / totalCnt
+
+  def gradient: Vector = {
+    val result = Vectors.dense(gradientSumArray.clone())
+
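+    // Subtract diffSum * effectiveWeights(i) * featuresMean(i) from each component to
+    // account for the centering constant folded into `diff`, then average by the count.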
+    val correction = {
+      val temp = effectiveWeightsArray.clone()
+      var i = 0
+      while (i < temp.length) {
+        temp(i) *= featuresMean(i)
+        i += 1
+      }
+      Vectors.dense(temp)
+    }
+
+    axpy(-diffSum, correction, result)
+    scal(1.0 / totalCnt, result)
+    result
+  }
+}
+
+/**
+ * LeastSquaresCostFun implements Breeze's DiffFunction[T] for the least-squares cost.
+ * It returns the loss and gradient with L2 regularization at a particular point (weights).
+ * It's used in Breeze's convex optimization routines.
+ */
+private class LeastSquaresCostFun(
+    data: RDD[(Double, Vector)],
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double],
+    effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+
+  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val w = Vectors.fromBreeze(weights)
+
+    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+      labelMean, featuresStd, featuresMean))(
+      seqOp = (c, v) => (c, v) match {
+        case (aggregator, (label, features)) => aggregator.add(label, features)
+      },
+      combOp = (c1, c2) => (c1, c2) match {
+        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+      })
+
+    // regVal is the L2 regularization term: 0.5 * effectiveL2regParam * ||w||^2
+    val norm = brzNorm(weights, 2.0)
+    val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+    val loss = leastSquaresAggregator.loss + regVal
+    val gradient = leastSquaresAggregator.gradient
+    axpy(effectiveL2regParam, w, gradient)
+
+    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+  }
+}
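
For reference, a minimal usage sketch of the API added in this patch. The `training`
DataFrame (with "label" and "features" columns) is a hypothetical placeholder; the
setters and `fit` are the standard spark.ml estimator API:

    import org.apache.spark.ml.regression.LinearRegression

    val lr = new LinearRegression()
      .setMaxIter(100)          // cap on the number of L-BFGS/OWLQN iterations
      .setRegParam(0.3)         // overall regularization strength
      .setElasticNetParam(0.8)  // penalty mix: 0.8 * L1 + 0.2 * L2
      .setTol(1e-6)             // convergence tolerance
    val model = lr.fit(training)
    println(s"weights: ${model.weights} intercept: ${model.intercept}")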