Commit fcbaefe

DB Tsai authored and committed
Refactoring
1 parent 4eb078d commit fcbaefe

3 files changed: +82 -85 lines changed
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 2 additions & 1 deletion
@@ -47,6 +47,8 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
+      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+      ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))

     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -155,7 +157,6 @@ private[shared] object SharedParamsCodeGen {
         |
         |import org.apache.spark.annotation.DeveloperApi
         |import org.apache.spark.ml.param._
-        |import org.apache.spark.util.Utils
         |
         |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
         |
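Note on the change above: each ParamDesc entry is expanded by SharedParamsCodeGen into a small shared-param trait in the generated sharedParams.scala (the file the hunk's comment says not to edit by hand). The following is a rough, hypothetical sketch of what the generator might emit for the new "tol" entry, assuming the usual Params/DoubleParam API; the real generated code may differ in details.

import org.apache.spark.ml.param.{DoubleParam, Params}

// Hypothetical sketch only; the authoritative code is the generated sharedParams.scala.
private[ml] trait HasTol extends Params {

  /** Param for the convergence tolerance for iterative algorithms. */
  val tol: DoubleParam = new DoubleParam(this, "tol",
    "the convergence tolerance for iterative algorithms")

  /** @group getParam */
  def getTol: Double = getOrDefault(tol)
}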

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 58 additions & 58 deletions
@@ -17,30 +17,29 @@

 package org.apache.spark.ml.regression

-import org.apache.spark.mllib.linalg.BLAS.dot
-import org.apache.spark.rdd.RDD
-
 import scala.collection.mutable.ArrayBuffer

-import breeze.linalg.{norm => brzNorm, DenseVector => BDV, SparseVector => BSV}
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
 import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 import breeze.optimize.{CachedDiffFunction, DiffFunction}

 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasTol}
 import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.StatCounter

 /**
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasElasticNetParam with HasMaxIter with HasTol
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol

 /**
  * :: AlphaComponent ::
@@ -57,7 +56,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
    * @group setParam
    */
   def setRegParam(value: Double): this.type = set(regParam, value)
-  setRegParam(0.0)
+  setDefault(regParam -> 0.0)

   /**
    * Set the ElasticNet mixing parameter.
@@ -67,15 +66,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
    * @group setParam
    */
   def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
-  setElasticNetParam(0.0)
+  setDefault(elasticNetParam -> 0.0)

   /**
    * Set the maximal number of iterations.
    * Default is 100.
    * @group setParam
    */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setMaxIter(100)
+  setDefault(maxIter -> 100)

   /**
    * Set the convergence tolerance of iterations.
@@ -84,7 +83,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
    * @group setParam
    */
   def setTol(value: Double): this.type = set(tol, value)
-  setTol(1E-9)
+  setDefault(tol -> 1E-9)

   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist instances.
@@ -96,38 +95,41 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
       instances.persist(StorageLevel.MEMORY_AND_DISK)
     }

-    // TODO: Benchmark if using two MultivariateOnlineSummarizer will be faster
-    // than appending the label into the vector.
-    val summary = instances.map { case (label: Double, features: Vector) =>
-      Vectors.fromBreeze(features.toBreeze match {
-        case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(label)))
-        case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(label), 1))
-        case v: Any =>
-          throw new IllegalArgumentException("Do not support vector type " + v.getClass)
-      })}.treeAggregate(new MultivariateOnlineSummarizer)(
-        (aggregator, data) => aggregator.add(data),
-        (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-
-    val numFeatures = summary.mean.size - 1
-    val yMean = summary.mean(numFeatures)
-    val yStd = math.sqrt(summary.variance(numFeatures))
-
+    val (summarizer, statCounter) = instances.treeAggregate(
+      (new MultivariateOnlineSummarizer, new StatCounter))( {
+        case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
+          (label: Double, features: Vector)) =>
+          (summarizer.add(features), statCounter.merge(label))
+      }, {
+        case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
+          (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
+          (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
+      })
+
+    val numFeatures = summarizer.mean.size
+    val yMean = statCounter.mean
+    val yStd = math.sqrt(statCounter.variance)
+
+    val featuresMean = summarizer.mean.toArray
+    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+    // Since we implicitly do the feature scaling when we compute the cost function
+    // to improve the convergence, the effective regParam will be changed.
     val effectiveRegParam = paramMap(regParam) / yStd
     val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam

     val costFun = new LeastSquaresCostFun(
       instances,
       yStd, yMean,
-      summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray,
-      summary.mean.toArray.slice(0, numFeatures).toArray,
+      featuresStd,
+      featuresMean,
       effectiveL2RegParam)

     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]](
-        paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
     }

     val initialWeights = Vectors.zeros(numFeatures)
@@ -142,20 +144,23 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     }
     lossHistory.append(state.value)

+    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
+    // The weights are trained in the scaled space; we're converting them back to
+    // the original space.
     val weights = {
-      val rawWeights = state.x.toArray
-      val std = summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray
-      require(rawWeights.size == std.size)
-
+      val rawWeights = state.x.toArray.clone()
       var i = 0
-      while (i < rawWeights.size) {
-        rawWeights(i) = if (std(i) != 0.0) rawWeights(i) * yStd / std(i) else 0.0
+      while (i < rawWeights.length) {
+        rawWeights(i) = if (featuresStd(i) != 0.0) rawWeights(i) * yStd / featuresStd(i) else 0.0
         i += 1
       }
       Vectors.dense(rawWeights)
     }

-    val intercept = yMean - dot(weights, Vectors.dense(summary.mean.toArray.slice(0, numFeatures)))
+    // The intercept in R's GLMNET is computed using closed form after the coefficients are
+    // converged. See the following discussion for detail.
+    // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))

     if (handlePersistence) {
       instances.unpersist()
@@ -179,7 +184,7 @@ class LinearRegressionModel private[ml] (
     with LinearRegressionParams {

   override protected def predict(features: Vector): Double = {
-    BLAS.dot(features, weights) + intercept
+    dot(features, weights) + intercept
   }

   override protected def copy(): LinearRegressionModel = {
@@ -204,7 +209,7 @@ private class LeastSquaresAggregator(
     val weightsArray = weights.toArray.clone()
     var sum = 0.0
     var i = 0
-    while (i < weights.size) {
+    while (i < weightsArray.length) {
       if (featuresStd(i) != 0.0) {
         weightsArray(i) /= featuresStd(i)
         sum += weightsArray(i) * featuresMean(i)
@@ -215,9 +220,9 @@ private class LeastSquaresAggregator(
     }
     (weightsArray, -sum, weightsArray.length)
   }
-
   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
-  private var gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+  private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim)

   /**
    * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
@@ -258,15 +263,16 @@ private class LeastSquaresAggregator(
    * @return This LeastSquaresAggregator object.
    */
   def merge(other: LeastSquaresAggregator): this.type = {
+    require(dim == other.dim, s"Dimensions mismatch when merging with another " +
+      s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
+
     if (this.totalCnt != 0 && other.totalCnt != 0) {
-      require(dim == other.dim, s"Dimensions mismatch when merging with another summarizer. " +
-        s"Expecting $dim but got ${other.dim}.")
       totalCnt += other.totalCnt
       lossSum += other.lossSum
       diffSum += other.diffSum

       var i = 0
-      val localThisGradientSumArray = gradientSumArray
+      val localThisGradientSumArray = this.gradientSumArray
       val localOtherGradientSumArray = other.gradientSumArray
       while (i < dim) {
         localThisGradientSumArray(i) += localOtherGradientSumArray(i)
@@ -276,7 +282,7 @@ private class LeastSquaresAggregator(
       this.totalCnt = other.totalCnt
       this.lossSum = other.lossSum
       this.diffSum = other.diffSum
-      this.gradientSumArray = other.gradientSumArray.clone
+      System.arraycopy(other.gradientSumArray, 0, this.gradientSumArray, 0, dim)
     }
     this
   }
@@ -286,12 +292,12 @@ private class LeastSquaresAggregator(
   def loss: Double = lossSum / totalCnt

   def gradient: Vector = {
-    val result = Vectors.dense(gradientSumArray.clone)
+    val result = Vectors.dense(gradientSumArray.clone())

     val correction = {
-      val temp = effectiveWeightsArray.clone
+      val temp = effectiveWeightsArray.clone()
       var i = 0
-      while (i < temp.size) {
+      while (i < temp.length) {
         temp(i) *= featuresMean(i)
         i += 1
       }
@@ -324,22 +330,16 @@ private class LeastSquaresCostFun(
       case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     })

-    /**
-     * regVal is sum of weight squares if it's L2 updater;
-     * for other updater, the same logic is followed.
-     */
+    // regVal is sum of weight squares for L2 regularization
     val norm = brzNorm(weights, 2.0)
     val regVal = 0.5 * effectiveL2regParam * norm * norm

     val loss = leastSquaresAggregator.loss + regVal
-    // The following gradientTotal is actually the regularization part of gradient.
-    // Will add the gradientSum computed from the data with weights in the next step.
+
     val gradientTotal = w.copy
     scal(effectiveL2regParam, gradientTotal)
-
-    // gradientTotal = gradient + gradientTotal
     axpy(1.0, leastSquaresAggregator.gradient, gradientTotal)

     (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
   }
-}
+}
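The core of the LinearRegression changes above is a single treeAggregate pass that replaces the old trick of appending the label to the feature vector: a MultivariateOnlineSummarizer accumulates per-feature statistics while a StatCounter accumulates label statistics. Below is a minimal, self-contained sketch of that pattern on toy data; the object name, local master, and sample values are illustrative and not part of the commit.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.util.StatCounter

// Illustrative sketch only: mirrors the one-pass aggregation used in train() above.
object AggregationSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))

    // (label, features) pairs, the same shape as `instances` in train().
    val instances = sc.parallelize(Seq(
      (1.0, Vectors.dense(0.0, 2.0)),
      (3.0, Vectors.dense(1.0, 4.0)),
      (2.0, Vectors.dense(0.5, 3.0))))

    // One pass over the data: the summarizer tracks feature mean/variance,
    // the StatCounter tracks label mean/variance.
    val (summarizer, statCounter) = instances.treeAggregate(
      (new MultivariateOnlineSummarizer, new StatCounter))({
        case ((summ, stat), (label, features)) =>
          (summ.add(features), stat.merge(label))
      }, {
        case ((summ1, stat1), (summ2, stat2)) =>
          (summ1.merge(summ2), stat1.merge(stat2))
      })

    val yMean = statCounter.mean
    val yStd = math.sqrt(statCounter.variance)
    val featuresMean = summarizer.mean.toArray
    val featuresStd = summarizer.variance.toArray.map(math.sqrt)

    println(s"yMean=$yMean yStd=$yStd")
    println(s"featuresMean=${featuresMean.mkString(",")} featuresStd=${featuresStd.mkString(",")}")
    sc.stop()
  }
}

With featuresMean and featuresStd in hand, train() standardizes the features inside the cost function, rescales the converged weights back by yStd / featuresStd(i), and recovers the intercept in closed form as yMean - dot(weights, Vectors.dense(featuresMean)), matching the glmnet convention referenced in the diff.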

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 22 additions & 26 deletions
@@ -20,9 +20,9 @@ package org.apache.spark.ml.regression
 import org.scalatest.FunSuite

 import org.apache.spark.mllib.linalg.DenseVector
-import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext, DataFrame}

 class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {

@@ -73,12 +73,11 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
     assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

-    model.transform(dataset).select("features", "prediction").collect().map {instance =>
-      val features = instance(0).asInstanceOf[DenseVector].toArray
-      val prediction1 = instance(1).asInstanceOf[Double]
-      val prediction2 =
-        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
-      assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }

@@ -102,12 +101,11 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
     assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

-    model.transform(dataset).select("features", "prediction").collect().map {instance =>
-      val features = instance(0).asInstanceOf[DenseVector].toArray
-      val prediction1 = instance(1).asInstanceOf[Double]
-      val prediction2 =
-        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
-      assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }

@@ -131,12 +129,11 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
     assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

-    model.transform(dataset).select("features", "prediction").collect().map {instance =>
-      val features = instance(0).asInstanceOf[DenseVector].toArray
-      val prediction1 = instance(1).asInstanceOf[Double]
-      val prediction2 =
-        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
-      assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }

@@ -160,12 +157,11 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
     assert(model.weights(1) ~== weightsR(1) relTol 1E-3)

-    model.transform(dataset).select("features", "prediction").collect().map { instance =>
-      val features = instance(0).asInstanceOf[DenseVector].toArray
-      val prediction1 = instance(1).asInstanceOf[Double]
-      val prediction2 =
-        features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
-      assert(prediction1 ~== prediction2 relTol 1E-5)
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
     }
   }
 }
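The test updates above replace positional Row access with asInstanceOf casts by a Row pattern match, and switch map to foreach since the result is discarded. A tiny standalone sketch of the same extraction pattern, with made-up rows and no SparkContext (the object name and values are illustrative):

import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.sql.Row

// Illustrative sketch only: shows the Row(...) extractor used by the updated tests.
object RowMatchSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for the rows returned by select("features", "prediction").collect().
    val rows = Seq(
      Row(new DenseVector(Array(1.0, 2.0)), 3.5),
      Row(new DenseVector(Array(0.5, -1.0)), 0.25))

    rows.foreach {
      // Extractor binding replaces instance(0).asInstanceOf[DenseVector] and
      // instance(1).asInstanceOf[Double]; a schema mismatch surfaces as a MatchError.
      case Row(features: DenseVector, prediction: Double) =>
        println(s"features=$features prediction=$prediction")
    }
  }
}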
