
package org.apache.spark.ml.regression

-import org.apache.spark.mllib.linalg.BLAS.dot
-import org.apache.spark.rdd.RDD
-
import scala.collection.mutable.ArrayBuffer

-import breeze.linalg.{norm => brzNorm, DenseVector => BDV, SparseVector => BSV}
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import breeze.optimize.{CachedDiffFunction, DiffFunction}

import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasTol}
import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.StatCounter

/**
 * Params for linear regression.
 */
private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasElasticNetParam with HasMaxIter with HasTol
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol

/**
 * :: AlphaComponent ::
@@ -57,7 +56,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
   * @group setParam
   */
  def setRegParam(value: Double): this.type = set(regParam, value)
-  setRegParam(0.0)
+  setDefault(regParam -> 0.0)

  /**
   * Set the ElasticNet mixing parameter.
@@ -67,15 +66,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
   * @group setParam
   */
  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
-  setElasticNetParam(0.0)
+  setDefault(elasticNetParam -> 0.0)

  /**
   * Set the maximal number of iterations.
   * Default is 100.
   * @group setParam
   */
  def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setMaxIter(100)
+  setDefault(maxIter -> 100)

  /**
   * Set the convergence tolerance of iterations.
@@ -84,7 +83,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
   * @group setParam
   */
  def setTol(value: Double): this.type = set(tol, value)
-  setTol(1E-9)
+  setDefault(tol -> 1E-9)

  override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
    // Extract columns from data. If dataset is persisted, do not persist instances.
@@ -96,38 +95,41 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
      instances.persist(StorageLevel.MEMORY_AND_DISK)
    }

-    // TODO: Benchmark if using two MultivariateOnlineSummarizer will be faster
-    // than appending the label into the vector.
-    val summary = instances.map { case (label: Double, features: Vector) =>
-      Vectors.fromBreeze(features.toBreeze match {
-        case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(label)))
-        case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(label), 1))
-        case v: Any =>
-          throw new IllegalArgumentException("Do not support vector type " + v.getClass)
-      })}.treeAggregate(new MultivariateOnlineSummarizer)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-
-    val numFeatures = summary.mean.size - 1
-    val yMean = summary.mean(numFeatures)
-    val yStd = math.sqrt(summary.variance(numFeatures))
-
+    val (summarizer, statCounter) = instances.treeAggregate(
+      (new MultivariateOnlineSummarizer, new StatCounter))({
+        case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
+          (label: Double, features: Vector)) =>
+          (summarizer.add(features), statCounter.merge(label))
+      }, {
+        case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
+          (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
+          (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
+      })
+
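
The pass above gathers the per-feature statistics and the label statistics in a single treeAggregate. A minimal sketch of the same pattern on toy data, not from the patch itself (assumes a SparkContext named sc):

  // Sketch: one pass that pairs a MultivariateOnlineSummarizer (per-feature mean/variance)
  // with a StatCounter (label mean/variance).
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
  import org.apache.spark.util.StatCounter

  val toy = sc.parallelize(Seq(
    (1.0, Vectors.dense(0.0, 2.0)),
    (3.0, Vectors.dense(1.0, 4.0))))

  val (summ, stat) = toy.treeAggregate((new MultivariateOnlineSummarizer, new StatCounter))(
    { case ((s, c), (label, features)) => (s.add(features), c.merge(label)) },
    { case ((s1, c1), (s2, c2)) => (s1.merge(s2), c1.merge(c2)) })

  // summ.mean / summ.variance give per-feature statistics; stat.mean / stat.variance the label's.
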
+    val numFeatures = summarizer.mean.size
+    val yMean = statCounter.mean
+    val yStd = math.sqrt(statCounter.variance)
+
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+    // Since we implicitly do the feature scaling when we compute the cost function
+    // to improve the convergence, the effective regParam will be changed.
    val effectiveRegParam = paramMap(regParam) / yStd
    val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
    val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam

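
To make the rescaling concrete: the objective is optimized against the label scaled by yStd, so the user-supplied regParam is divided by yStd and then split into L1 and L2 parts by elasticNetParam. A tiny numeric sketch with made-up values, not from the patch itself:

  val regParamValue = 0.3   // user-supplied regularization strength
  val alpha = 0.5           // elasticNetParam: 0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso)
  val yStdValue = 2.0       // label standard deviation from the StatCounter

  val effective = regParamValue / yStdValue          // 0.15
  val l1 = alpha * effective                         // 0.075, handed to OWLQN below
  val l2 = (1.0 - alpha) * effective                 // 0.075, added inside the cost function
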
    val costFun = new LeastSquaresCostFun(
      instances,
      yStd, yMean,
-      summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray,
-      summary.mean.toArray.slice(0, numFeatures).toArray,
+      featuresStd,
+      featuresMean,
      effectiveL2RegParam)

    val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
    } else {
      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
    }

    val initialWeights = Vectors.zeros(numFeatures)
@@ -142,20 +144,23 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
    }
    lossHistory.append(state.value)

+    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
+    // The weights are trained in the scaled space; we're converting them back to
+    // the original space.
    val weights = {
-      val rawWeights = state.x.toArray
-      val std = summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray
-      require(rawWeights.size == std.size)
-
+      val rawWeights = state.x.toArray.clone()
      var i = 0
-      while (i < rawWeights.size) {
-        rawWeights(i) = if (std(i) != 0.0) rawWeights(i) * yStd / std(i) else 0.0
+      while (i < rawWeights.length) {
+        rawWeights(i) = if (featuresStd(i) != 0.0) rawWeights(i) * yStd / featuresStd(i) else 0.0
        i += 1
      }
      Vectors.dense(rawWeights)
    }

-    val intercept = yMean - dot(weights, Vectors.dense(summary.mean.toArray.slice(0, numFeatures)))
+    // The intercept in R's GLMNET is computed using closed form after the coefficients are
+    // converged. See the following discussion for detail.
+    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))

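
Both pieces of the mapping back to the original space follow from the standardization: each scaled weight is multiplied by yStd / featuresStd(i) (and forced to 0.0 for constant features), and the intercept is the closed-form yMean - dot(weights, featuresMean). A small arithmetic sketch with made-up values, not from the patch itself:

  val wScaled = Array(0.5, -0.2)   // weights found by the optimizer in the scaled space
  val yStdValue = 2.0
  val yMeanValue = 10.0
  val stds = Array(4.0, 0.0)       // zero std = constant feature, its weight becomes 0.0
  val means = Array(1.0, 3.0)

  val wOriginal = wScaled.indices.map { i =>
    if (stds(i) != 0.0) wScaled(i) * yStdValue / stds(i) else 0.0
  }.toArray                        // Array(0.25, 0.0)

  val interceptValue =
    yMeanValue - wOriginal.zip(means).map { case (w, m) => w * m }.sum   // 10.0 - 0.25 = 9.75
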
    if (handlePersistence) {
      instances.unpersist()
@@ -179,7 +184,7 @@ class LinearRegressionModel private[ml] (
  with LinearRegressionParams {

  override protected def predict(features: Vector): Double = {
-    BLAS.dot(features, weights) + intercept
+    dot(features, weights) + intercept
  }

  override protected def copy(): LinearRegressionModel = {
@@ -204,7 +209,7 @@ private class LeastSquaresAggregator(
    val weightsArray = weights.toArray.clone()
    var sum = 0.0
    var i = 0
-    while (i < weights.size) {
+    while (i < weightsArray.length) {
      if (featuresStd(i) != 0.0) {
        weightsArray(i) /= featuresStd(i)
        sum += weightsArray(i) * featuresMean(i)
@@ -215,9 +220,9 @@ private class LeastSquaresAggregator(
    }
    (weightsArray, -sum, weightsArray.length)
  }
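
Folding the statistics into the weights this way lets the aggregator consume the raw (possibly sparse) features: for features with nonzero std, dot(x, effectiveWeights) + offset equals the dot product against standardized features, sum_i w_i * (x_i - mean_i) / std_i, without ever centering x. A quick check of that identity with made-up values, not from the patch itself:

  val w = Array(2.0, -1.0)
  val mu = Array(1.0, 3.0)
  val sigma = Array(2.0, 0.5)
  val x = Array(5.0, 4.0)

  val wEff = w.indices.map(i => w(i) / sigma(i))                          // rescaled weights
  val offsetTerm = -wEff.zip(mu).map { case (we, m) => we * m }.sum       // the -sum above

  val viaEffective = wEff.zip(x).map { case (we, xi) => we * xi }.sum + offsetTerm   // 2.0
  val viaStandardized = w.indices.map(i => w(i) * (x(i) - mu(i)) / sigma(i)).sum     // 2.0
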
-
  private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
-  private var gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+  private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim)

  /**
   * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
@@ -258,15 +263,16 @@ private class LeastSquaresAggregator(
   * @return This LeastSquaresAggregator object.
   */
  def merge(other: LeastSquaresAggregator): this.type = {
+    require(dim == other.dim, s"Dimensions mismatch when merging with another " +
+      s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
+
    if (this.totalCnt != 0 && other.totalCnt != 0) {
-      require(dim == other.dim, s"Dimensions mismatch when merging with another summarizer. " +
-        s"Expecting $dim but got ${other.dim}.")
      totalCnt += other.totalCnt
      lossSum += other.lossSum
      diffSum += other.diffSum

      var i = 0
-      val localThisGradientSumArray = gradientSumArray
+      val localThisGradientSumArray = this.gradientSumArray
      val localOtherGradientSumArray = other.gradientSumArray
      while (i < dim) {
        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
@@ -276,7 +282,7 @@ private class LeastSquaresAggregator(
      this.totalCnt = other.totalCnt
      this.lossSum = other.lossSum
      this.diffSum = other.diffSum
-      this.gradientSumArray = other.gradientSumArray.clone
+      System.arraycopy(other.gradientSumArray, 0, this.gradientSumArray, 0, dim)
    }
    this
  }
@@ -286,12 +292,12 @@ private class LeastSquaresAggregator(
  def loss: Double = lossSum / totalCnt

  def gradient: Vector = {
-    val result = Vectors.dense(gradientSumArray.clone)
+    val result = Vectors.dense(gradientSumArray.clone())

    val correction = {
-      val temp = effectiveWeightsArray.clone
+      val temp = effectiveWeightsArray.clone()
      var i = 0
-      while (i < temp.size) {
+      while (i < temp.length) {
        temp(i) *= featuresMean(i)
        i += 1
      }
@@ -324,22 +330,16 @@ private class LeastSquaresCostFun(
      case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
    })

-    /**
-     * regVal is sum of weight squares if it's L2 updater;
-     * for other updater, the same logic is followed.
-     */
+    // regVal is sum of weight squares for L2 regularization
    val norm = brzNorm(weights, 2.0)
    val regVal = 0.5 * effectiveL2regParam * norm * norm

    val loss = leastSquaresAggregator.loss + regVal
-    // The following gradientTotal is actually the regularization part of gradient.
-    // Will add the gradientSum computed from the data with weights in the next step.
+
    val gradientTotal = w.copy
    scal(effectiveL2regParam, gradientTotal)
-
-    // gradientTotal = gradient + gradientTotal
    axpy(1.0, leastSquaresAggregator.gradient, gradientTotal)

    (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
  }
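
The L2 term thus adds 0.5 * effectiveL2regParam * ||w||^2 to the loss and effectiveL2regParam * w to the gradient (via scal), and axpy then adds the data gradient from the aggregator on top. The same assembly with plain arrays, as a sketch rather than the patch's implementation:

  // Sketch: combine a data loss/gradient with the L2 penalty, mirroring
  // scal(lambda2, g); axpy(1.0, dataGradient, g) above.
  def withL2(dataLoss: Double, dataGradient: Array[Double],
      weights: Array[Double], lambda2: Double): (Double, Array[Double]) = {
    val sqNorm = weights.map(wi => wi * wi).sum
    val loss = dataLoss + 0.5 * lambda2 * sqNorm
    val gradient = weights.indices.map(i => lambda2 * weights(i) + dataGradient(i)).toArray
    (loss, gradient)
  }
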
-}
+}