@@ -210,12 +210,23 @@ private object ALS extends Logging {
210210 private case class Rating (user : Int , item : Int , rating : Float )
211211
212212 /** Cholesky solver for least square problems. */
213- private [recommendation] class CholeskySolver (val k : Int ) {
214-
215- val upper = " U"
216- val info = new intW(0 )
213+ private [recommendation] class CholeskySolver {
217214
215+ private val upper = " U"
216+ private val info = new intW(0 )
217+
218+ /**
219+ * Solves a least squares problem with L2 regularization:
220+ *
221+ * min norm(A x - b)^2^ + lambda * n * norm(x)^2^
222+ *
223+ * @param ne a [[NormalEquation ]] instance that contains AtA, Atb, and n (number of instances)
224+ * @param lambda regularization constant, which will be scaled by n
225+ * @return the solution x
226+ */
218227 def solve (ne : NormalEquation , lambda : Double ): Array [Float ] = {
228+ val k = ne.k
229+ // Add scaled lambda to the diagonals of AtA.
219230 val scaledlambda = lambda * ne.n
220231 var i = 0
221232 var j = 2
@@ -241,9 +252,13 @@ private object ALS extends Logging {
241252 /** Representing a normal equation (ALS' subproblem). */
242253 private [recommendation] class NormalEquation (val k : Int ) extends Serializable {
243254
255+ /** Number of entries in the upper triangular part of a k-by-k matrix. */
244256 val triK = k * (k + 1 ) / 2
257+ /** A^T^ * A */
245258 val ata = new Array [Double ](triK)
259+ /** A^T^ * b */
246260 val atb = new Array [Double ](k)
261+ /** Number of observations. */
247262 var n = 0
248263
249264 private val da = new Array [Double ](k)
@@ -257,6 +272,7 @@ private object ALS extends Logging {
257272 }
258273 }
259274
275+ /** Adds an observation. */
260276 def add (a : Array [Float ], b : Float ): this .type = {
261277 require(a.size == k)
262278 copyToDouble(a)
@@ -266,6 +282,9 @@ private object ALS extends Logging {
266282 this
267283 }
268284
285+ /**
286+ * Adds an observation with implicit feedback. Note that this does not increment the counter.
287+ */
269288 def addImplicit (a : Array [Float ], b : Float , alpha : Double ): this .type = {
270289 require(a.size == k)
271290 val confidence = 1.0 + alpha * math.abs(b)
@@ -277,6 +296,7 @@ private object ALS extends Logging {
277296 this
278297 }
279298
299+ /** Merges another normal equation object. */
280300 def merge (other : NormalEquation ): this .type = {
281301 require(other.k == k)
282302 blas.daxpy(ata.size, 1.0 , other.ata, 1 , ata, 1 )
@@ -285,6 +305,7 @@ private object ALS extends Logging {
285305 this
286306 }
287307
308+ /** Resets everything to zero, which should be called after each solve. */
288309 def reset (): Unit = {
289310 javaUtil.Arrays .fill(ata, 0.0 )
290311 javaUtil.Arrays .fill(atb, 0.0 )
@@ -749,7 +770,7 @@ private object ALS extends Logging {
749770 val dstFactors = new Array [Array [Float ]](dstIds.size)
750771 var j = 0
751772 val ls = new NormalEquation (k)
752- val solver = new CholeskySolver (k)
773+ val solver = new CholeskySolver
753774 while (j < dstIds.size) {
754775 ls.reset()
755776 if (implicitPrefs) {
0 commit comments