Skip to content

Commit 2178b63

Browse files
author
DB Tsai
committed
add comments
1 parent 9988ca8 commit 2178b63

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,23 @@ class LinearRegressionModel private[ml] (
191191
}
192192
}
193193

194+
/**
195+
* LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
196+
 * as used in linear regression for samples in sparse or dense vectors in an online fashion.
197+
*
198+
 * Two LeastSquaresAggregators can be merged together to obtain a summary of the loss and gradient of
199+
* the corresponding joint dataset.
200+
*
201+
202+
 * Compute gradient and loss for a Least-squared loss function, as used in linear regression.
203+
* This is correct for the averaged least squares loss function (mean squared error)
204+
 * L = 1/(2n) * ||A weights - y||^2
205+
* See also the documentation for the precise formulation.
206+
*
207+
* @param weights weights/coefficients corresponding to features
208+
*
209+
* @param updater Updater to be used to update weights after every iteration.
210+
*/
194211
private class LeastSquaresAggregator(
195212
weights: Vector,
196213
labelStd: Double,
@@ -302,6 +319,11 @@ private class LeastSquaresAggregator(
302319
}
303320
}
304321

322+
/**
323+
* LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost.
324+
* It returns the loss and gradient with L2 regularization at a particular point (weights).
325+
* It's used in Breeze's convex optimization routines.
326+
*/
305327
private class LeastSquaresCostFun(
306328
data: RDD[(Double, Vector)],
307329
labelStd: Double,
@@ -322,7 +344,7 @@ private class LeastSquaresCostFun(
322344
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
323345
})
324346

325-
// regVal is sum of weight squares for L2 regularization
347+
// regVal is the sum of weight squares for L2 regularization
326348
val norm = brzNorm(weights, 2.0)
327349
val regVal = 0.5 * effectiveL2regParam * norm * norm
328350

0 commit comments

Comments
 (0)