@@ -267,10 +267,20 @@ class LinearRegressionModel private[ml] (
267267 * }}},
268268 * where correction_i = - diffSum \bar{x_i}) / \hat{x_i}
269269 *
270- * As a result, the first term in the first derivative of the total objective function
271- * depends on training dataset, which can be computed in distributed fashion, and is sparse
272- * format friendly. We only have to loop through the whole gradientSum vector one time
273- * in the end for adding the correction terms back (in the driver, not in the executors).
270+ * A simple math can show that diffSum is actually zero, so we don't even
271+ * need to add the correction terms in the end. From the definition of diff,
272+ * {{{
273+ * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y})
274+ * = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y_j} - \bar{y}) / \hat{y})
275+ * = 0
276+ * }}}
277+ *
278+ * As a result, the first derivative of the total objective function only depends on
279+ * the training dataset, which can be easily computed in distributed fashion, and is
280+ * sparse format friendly.
281+ * {{{
282+ * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i})
283+ * }}},
274284 *
275285 * @param weights The weights/coefficients corresponding to the features.
276286 * @param labelStd The standard deviation value of the label.
@@ -371,17 +381,7 @@ private class LeastSquaresAggregator(
371381 def loss : Double = lossSum / totalCnt
372382
373383 def gradient : Vector = {
374- val resultArray = gradientSumArray.clone()
375-
376- // Adding the correction terms back to gradientSum;
377- // see the mathematical derivation for detail.
378- Vectors .dense(featuresMean).foreachActive { (index, value) =>
379- if (featuresStd(index) != 0.0 && value != 0.0 ) {
380- resultArray(index) -= diffSum * value / featuresStd(index)
381- }
382- }
383-
384- val result = Vectors .dense(resultArray)
384+ val result = Vectors .dense(gradientSumArray.clone())
385385 scal(1.0 / totalCnt, result)
386386 result
387387 }
0 commit comments