Skip to content

Commit 771baf3

Browse files
committed
chol doc update
1 parent ca9ad9d commit 771baf3

File tree

2 files changed

+27
-6
lines changed
  • mllib/src

2 files changed

+27
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,23 @@ private object ALS extends Logging {
210210
private case class Rating(user: Int, item: Int, rating: Float)
211211

212212
/** Cholesky solver for least square problems. */
213-
private[recommendation] class CholeskySolver(val k: Int) {
214-
215-
val upper = "U"
216-
val info = new intW(0)
213+
private[recommendation] class CholeskySolver {
217214

215+
private val upper = "U"
216+
private val info = new intW(0)
217+
218+
/**
219+
* Solves a least squares problem with L2 regularization:
220+
*
221+
* min norm(A x - b)^2^ + lambda * n * norm(x)^2^
222+
*
223+
* @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
224+
* @param lambda regularization constant, which will be scaled by n
225+
* @return the solution x
226+
*/
218227
def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
228+
val k = ne.k
229+
// Add scaled lambda to the diagonals of AtA.
219230
val scaledlambda = lambda * ne.n
220231
var i = 0
221232
var j = 2
@@ -241,9 +252,13 @@ private object ALS extends Logging {
241252
/** Representing a normal equation (ALS' subproblem). */
242253
private[recommendation] class NormalEquation(val k: Int) extends Serializable {
243254

255+
/** Number of entries in the upper triangular part of a k-by-k matrix. */
244256
val triK = k * (k + 1) / 2
257+
/** A^T^ * A */
245258
val ata = new Array[Double](triK)
259+
/** A^T^ * b */
246260
val atb = new Array[Double](k)
261+
/** Number of observations. */
247262
var n = 0
248263

249264
private val da = new Array[Double](k)
@@ -257,6 +272,7 @@ private object ALS extends Logging {
257272
}
258273
}
259274

275+
/** Adds an observation. */
260276
def add(a: Array[Float], b: Float): this.type = {
261277
require(a.size == k)
262278
copyToDouble(a)
@@ -266,6 +282,9 @@ private object ALS extends Logging {
266282
this
267283
}
268284

285+
/**
286+
* Adds an observation with implicit feedback. Note that this does not increment the counter.
287+
*/
269288
def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
270289
require(a.size == k)
271290
val confidence = 1.0 + alpha * math.abs(b)
@@ -277,6 +296,7 @@ private object ALS extends Logging {
277296
this
278297
}
279298

299+
/** Merges another normal equation object. */
280300
def merge(other: NormalEquation): this.type = {
281301
require(other.k == k)
282302
blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
@@ -285,6 +305,7 @@ private object ALS extends Logging {
285305
this
286306
}
287307

308+
/** Resets everything to zero, which should be called after each solve. */
288309
def reset(): Unit = {
289310
javaUtil.Arrays.fill(ata, 0.0)
290311
javaUtil.Arrays.fill(atb, 0.0)
@@ -749,7 +770,7 @@ private object ALS extends Logging {
749770
val dstFactors = new Array[Array[Float]](dstIds.size)
750771
var j = 0
751772
val ls = new NormalEquation(k)
752-
val solver = new CholeskySolver(k)
773+
val solver = new CholeskySolver
753774
while (j < dstIds.size) {
754775
ls.reset()
755776
if (implicitPrefs) {

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
122122
val ne1 = new NormalEquation(k)
123123
.merge(ne0)
124124

125-
val chol = new CholeskySolver(k)
125+
val chol = new CholeskySolver
126126
val x0 = chol.solve(ne0, 0.0).map(_.toDouble)
127127
// NumPy code that computes the expected solution:
128128
// A = np.matrix("1 2; 1 3; 1 4")

0 commit comments

Comments
 (0)