Commit 15995c8

DB Tsai authored and mengxr committed
[SPARK-7222] [ML] Added mathematical derivation in comment and compressed the model, removed the correction terms in LinearRegression with ElasticNet
Added detailed mathematical derivation of how scaling and LeastSquaresAggregator work. Refactored the code so the model is compressed based on the storage. We may try compression based on the prediction time.

Also, I found that diffSum will always be zero mathematically, so no corrections are required.

Author: DB Tsai <[email protected]>

Closes #5767 from dbtsai/lir-doc and squashes the following commits:

5e346c9 [DB Tsai] refactoring
fc9f582 [DB Tsai] doc
58456d8 [DB Tsai] address feedback
69757b8 [DB Tsai] actually diffSum is mathematically zero! No correction is needed.
5929e49 [DB Tsai] typo
63f7d1e [DB Tsai] Added compression to the model based on storage
203a295 [DB Tsai] Add more documentation to LinearRegression in new ML framework.
1 parent 3a180c1 commit 15995c8

File tree

3 files changed: +101 -37 lines


mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 90 additions & 25 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.StatCounter
+import org.apache.spark.Logging

 /**
  * Params for linear regression.
@@ -48,7 +49,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams
  */
 @AlphaComponent
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
-  with LinearRegressionParams {
+  with LinearRegressionParams with Logging {

   /**
    * Set the regularization parameter.
@@ -110,6 +111,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)

+    // If the yStd is zero, then the intercept is yMean with zero weights;
+    // as a result, training is not needed.
+    if (yStd == 0.0) {
+      logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
+        s"and the intercept will be the mean of the label; as a result, training is not needed.")
+      if (handlePersistence) instances.unpersist()
+      return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean)
+    }
+
     val featuresMean = summarizer.mean.toArray
     val featuresStd = summarizer.variance.toArray.map(math.sqrt)

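As a sanity check on the new early return: when every label equals the same constant, the least-squares loss is already minimized by zero weights and an intercept equal to the label mean, so skipping training is safe. A minimal standalone sketch of that argument in plain Scala (toy data invented for illustration, no Spark APIs):

object ConstantLabelCheck {
  def main(args: Array[String]): Unit = {
    // Hypothetical toy data: three samples whose labels are all 5.0.
    val labels = Array(5.0, 5.0, 5.0)
    val yMean = labels.sum / labels.length
    val yStd = math.sqrt(labels.map(y => (y - yMean) * (y - yMean)).sum / labels.length)

    // With zero variance in the label, zero weights plus intercept = yMean
    // already give zero training loss, so no iterative optimization is required.
    if (yStd == 0.0) {
      val weights = Array.fill(4)(0.0) // pretend there are 4 features
      val intercept = yMean
      println(s"weights = ${weights.mkString(",")}, intercept = $intercept")
    }
  }
}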
@@ -141,7 +151,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     }
     lossHistory += state.value

-    // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector.
     // The weights are trained in the scaled space; we're converting them back to
     // the original space.
     val weights = {
@@ -158,11 +167,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
     val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    if (handlePersistence) instances.unpersist()

-    if (handlePersistence) {
-      instances.unpersist()
-    }
-    new LinearRegressionModel(this, paramMap, weights, intercept)
+    // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
+    new LinearRegressionModel(this, paramMap, weights.compressed, intercept)
   }
 }

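For reference, `compressed` here is the MLlib `Vector.compressed` helper used by this commit: it returns whichever of the dense or sparse representation of the vector needs less storage. A small hedged usage sketch, assuming `org.apache.spark.mllib.linalg` is on the classpath (values invented for illustration):

import org.apache.spark.mllib.linalg.Vectors

object CompressedSketch {
  def main(args: Array[String]): Unit = {
    // Mostly zeros: the sparse encoding stores only the non-zero entry,
    // so `compressed` should return a SparseVector here.
    val mostlyZeros = Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.1)
    println(mostlyZeros.compressed)

    // No zeros at all: the dense encoding is smaller, so it stays dense.
    val allNonZero = Vectors.dense(1.0, 2.0, 3.0, 4.0)
    println(allNonZero.compressed)
  }
}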
@@ -198,15 +206,84 @@ class LinearRegressionModel private[ml] (
  * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of
  * the corresponding joint dataset.
  *
-
- * * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
- * This is correct for the averaged least squares loss function (mean squared error)
- *              L = 1/2n ||A weights-y||^2
- * See also the documentation for the precise formulation.
+ * For improving the convergence rate during the optimization process, and also preventing against
+ * features with very large variances exerting an overly large influence during model training,
+ * packages like R's GLMNET perform scaling to unit variance and remove the mean to reduce
+ * the condition number, and then train the model in the scaled space but return the weights in
+ * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache
+ * the standardized dataset since it will create a lot of overhead. As a result, we perform the
+ * scaling implicitly when we compute the objective function. The following is the mathematical
+ * derivation.
+ *
+ * Note that we don't handle the intercept by adding a bias term here, because the intercept
+ * can be computed in closed form after the coefficients have converged.
+ * See this discussion for detail.
+ * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ *
+ * The objective function in the scaled space is given by
+ * {{{
+ *   L = 1/2n ||\sum_i w_i (x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
+ * }}}
+ * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
+ * \bar{y} is the mean of the label, and \hat{y} is the standard deviation of the label.
+ *
+ * This can be rewritten as
+ * {{{
+ *   L = 1/2n ||\sum_i (w_i / \hat{x_i}) x_i - \sum_i (w_i / \hat{x_i}) \bar{x_i} - y / \hat{y}
+ *         + \bar{y} / \hat{y}||^2
+ *     = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
+ * }}}
+ * where w_i^\prime is the effective weight defined by w_i / \hat{x_i}, offset is
+ * {{{
+ *   - \sum_i (w_i / \hat{x_i}) \bar{x_i} + \bar{y} / \hat{y},
+ * }}}
+ * and diff is
+ * {{{
+ *   \sum_i w_i^\prime x_i - y / \hat{y} + offset.
+ * }}}
 *
- * @param weights weights/coefficients corresponding to features
+ * Note that the effective weights and offset don't depend on the training dataset,
+ * so they can be precomputed.
 *
- * @param updater Updater to be used to update weights after every iteration.
+ * Now, the first derivative of the objective function in the scaled space is
+ * {{{
+ *   \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i}
+ * }}}
+ * However, (x_i - \bar{x_i}) will densify the computation, so it's not
+ * an ideal formula when the training dataset is in sparse format.
+ *
+ * This can be addressed by adding back the dense \bar{x_i} / \hat{x_i} terms
+ * at the end, keeping the sum of diff. The first derivative of the total
+ * objective function over all the samples is
+ * {{{
+ *   \frac{\partial L}{\partial w_i} =
+ *       1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i}
+ *     = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i})
+ *     = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i)
+ * }}}
+ * where correction_i = - diffSum \bar{x_i} / \hat{x_i}.
+ *
+ * Simple math shows that diffSum is actually zero, so we don't even
+ * need to add the correction terms at the end. From the definition of diff,
+ * {{{
+ *   diffSum = \sum_j (\sum_i w_i (x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y})
+ *           = N * (\sum_i w_i (\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y})
+ *           = 0
+ * }}}
+ *
+ * As a result, the first derivative of the total objective function only depends on
+ * the training dataset, which can be easily computed in a distributed fashion, and is
+ * sparse-format friendly:
+ * {{{
+ *   \frac{\partial L}{\partial w_i} = 1/N \sum_j diff_j x_{ij} / \hat{x_i}
+ * }}}
+ *
+ * @param weights The weights/coefficients corresponding to the features.
+ * @param labelStd The standard deviation value of the label.
+ * @param labelMean The mean value of the label.
+ * @param featuresStd The standard deviation values of the features.
+ * @param featuresMean The mean values of the features.
  */
 private class LeastSquaresAggregator(
     weights: Vector,
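To make the derivation concrete, here is a hedged sketch in plain Scala of the precomputation it describes: the effective weights w_i / \hat{x_i} and the offset depend only on the summary statistics, never on individual samples, so they can be computed once up front. The names below are invented for illustration and are not the actual LeastSquaresAggregator code:

object ScalingSketch {
  // Sketch only: mirrors the formulas in the comment above, not the real aggregator.
  def precompute(
      weights: Array[Double],
      featuresMean: Array[Double],
      featuresStd: Array[Double],
      labelMean: Double,
      labelStd: Double): (Array[Double], Double) = {
    // w_i' = w_i / \hat{x_i}; a feature with zero variance contributes nothing.
    val effectiveWeights = weights.indices.map { i =>
      if (featuresStd(i) != 0.0) weights(i) / featuresStd(i) else 0.0
    }.toArray
    // offset = - \sum_i w_i' \bar{x_i} + \bar{y} / \hat{y}
    val offset = labelMean / labelStd -
      effectiveWeights.zip(featuresMean).map { case (w, m) => w * m }.sum
    (effectiveWeights, offset)
  }

  // For one sample, diff = \sum_i w_i' x_i - y / \hat{y} + offset.
  def diff(effectiveWeights: Array[Double], offset: Double, labelStd: Double,
      features: Array[Double], label: Double): Double = {
    effectiveWeights.zip(features).map { case (w, x) => w * x }.sum - label / labelStd + offset
  }
}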
@@ -302,18 +379,6 @@ private class LeastSquaresAggregator(

   def gradient: Vector = {
     val result = Vectors.dense(gradientSumArray.clone())
-
-    val correction = {
-      val temp = effectiveWeightsArray.clone()
-      var i = 0
-      while (i < temp.length) {
-        temp(i) *= featuresMean(i)
-        i += 1
-      }
-      Vectors.dense(temp)
-    }
-
-    axpy(-diffSum, correction, result)
     scal(1.0 / totalCnt, result)
     result
   }

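The correction removed above was compensating for diffSum, which the derivation shows is identically zero. A standalone numerical check in plain Scala (toy data invented for illustration) that computes diffSum directly:

object DiffSumCheck {
  def main(args: Array[String]): Unit = {
    // Toy dataset: 4 samples, 2 features, arbitrary weights.
    val xs = Array(Array(1.0, 10.0), Array(2.0, 14.0), Array(3.0, 9.0), Array(4.0, 7.0))
    val ys = Array(3.0, 5.0, 4.0, 6.0)
    val w  = Array(0.7, -0.2)

    def mean(a: Array[Double]) = a.sum / a.length
    def std(a: Array[Double]) = {
      val m = mean(a); math.sqrt(a.map(v => (v - m) * (v - m)).sum / a.length)
    }

    val xMean = Array(mean(xs.map(_(0))), mean(xs.map(_(1))))
    val xStd  = Array(std(xs.map(_(0))), std(xs.map(_(1))))
    val (yMean, yStd) = (mean(ys), std(ys))

    // diff_j = \sum_i w_i (x_{ij} - xMean_i) / xStd_i - (y_j - yMean) / yStd
    val diffSum = xs.zip(ys).map { case (x, y) =>
      w.indices.map(i => w(i) * (x(i) - xMean(i)) / xStd(i)).sum - (y - yMean) / yStd
    }.sum

    println(f"diffSum = $diffSum%.12f") // ~0 up to floating-point rounding
  }
}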
mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       throw new SparkException("Input validation failed.")
     }

-    /*
+    /**
      * Scaling columns to unit variance as a heuristic to reduce the condition number:
      *
      * During the optimization process, the convergence (rate) depends on the condition number of

mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala

Lines changed: 10 additions & 11 deletions
@@ -103,17 +103,16 @@ object LinearDataGenerator {

     val rnd = new Random(seed)
     val x = Array.fill[Array[Double]](nPoints)(
-      Array.fill[Double](weights.length)(rnd.nextDouble))
-
-    x.map(vector => {
-      // This doesn't work if `vector` is a sparse vector.
-      val vectorArray = vector.toArray
-      var i = 0
-      while (i < vectorArray.size) {
-        vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
-        i += 1
-      }
-    })
+      Array.fill[Double](weights.length)(rnd.nextDouble()))
+
+    x.foreach {
+      case v =>
+        var i = 0
+        while (i < v.length) {
+          v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+          i += 1
+        }
+    }

     val y = x.map { xi =>
       blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
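The rewritten loop relies on the fact that a Uniform(0, 1) draw has mean 0.5 and variance 1/12, so (u - 0.5) * sqrt(12 * v) + m has mean m and variance v. A quick empirical check in plain Scala (targets chosen arbitrarily for illustration):

import scala.util.Random

object UniformTransformCheck {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(42)
    val (targetMean, targetVariance) = (3.0, 4.0) // arbitrary targets
    val n = 1000000

    // Same transform as LinearDataGenerator, applied to raw Uniform(0, 1) draws.
    val samples =
      Array.fill(n)((rnd.nextDouble() - 0.5) * math.sqrt(12.0 * targetVariance) + targetMean)

    val mean = samples.sum / n
    val variance = samples.map(s => (s - mean) * (s - mean)).sum / n
    println(f"mean = $mean%.3f (target $targetMean), variance = $variance%.3f (target $targetVariance)")
  }
}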
