Skip to content

Commit bc04fa2

Browse files
committed
[SPARK-6642][MLLIB] use 1.2 lambda scaling and remove addImplicit from NormalEquation
This PR changes lambda scaling from number of users/items to number of explicit ratings. The latter is the behavior in 1.2. Slight refactor of NormalEquation to make it independent of ALS models. srowen codexiang Author: Xiangrui Meng <[email protected]> Closes #5314 from mengxr/SPARK-6642 and squashes the following commits: dc655a1 [Xiangrui Meng] relax python tests f410df2 [Xiangrui Meng] use 1.2 scaling and remove addImplicit from NormalEquation (cherry picked from commit ccafd75) Signed-off-by: Xiangrui Meng <[email protected]> Conflicts: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
1 parent 1c31ebd commit bc04fa2

File tree

3 files changed

+60
-84
lines changed

3 files changed

+60
-84
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ object ALS extends Logging {
321321

322322
/** Trait for least squares solvers applied to the normal equation. */
323323
private[recommendation] trait LeastSquaresNESolver extends Serializable {
324-
/** Solves a least squares problem (possibly with other constraints). */
324+
/** Solves a least squares problem with regularization (possibly with other constraints). */
325325
def solve(ne: NormalEquation, lambda: Double): Array[Float]
326326
}
327327

@@ -333,20 +333,19 @@ object ALS extends Logging {
333333
/**
334334
* Solves a least squares problem with L2 regularization:
335335
*
336-
* min norm(A x - b)^2^ + lambda * n * norm(x)^2^
336+
* min norm(A x - b)^2^ + lambda * norm(x)^2^
337337
*
338338
* @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
339-
* @param lambda regularization constant, which will be scaled by n
339+
* @param lambda regularization constant
340340
* @return the solution x
341341
*/
342342
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
343343
val k = ne.k
344344
// Add scaled lambda to the diagonals of AtA.
345-
val scaledlambda = lambda * ne.n
346345
var i = 0
347346
var j = 2
348347
while (i < ne.triK) {
349-
ne.ata(i) += scaledlambda
348+
ne.ata(i) += lambda
350349
i += j
351350
j += 1
352351
}
@@ -392,7 +391,7 @@ object ALS extends Logging {
392391
override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
393392
val rank = ne.k
394393
initialize(rank)
395-
fillAtA(ne.ata, lambda * ne.n)
394+
fillAtA(ne.ata, lambda)
396395
val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace)
397396
ne.reset()
398397
x.map(x => x.toFloat)
@@ -422,7 +421,15 @@ object ALS extends Logging {
422421
}
423422
}
424423

425-
/** Representing a normal equation (ALS' subproblem). */
424+
/**
425+
* Representing a normal equation to solve the following weighted least squares problem:
426+
*
427+
* minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
428+
*
429+
* Its normal equation is given by
430+
*
431+
* \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
432+
*/
426433
private[recommendation] class NormalEquation(val k: Int) extends Serializable {
427434

428435
/** Number of entries in the upper triangular part of a k-by-k matrix. */
@@ -431,8 +438,6 @@ object ALS extends Logging {
431438
val ata = new Array[Double](triK)
432439
/** A^T^ * b */
433440
val atb = new Array[Double](k)
434-
/** Number of observations. */
435-
var n = 0
436441

437442
private val da = new Array[Double](k)
438443
private val upper = "U"
@@ -446,28 +451,13 @@ object ALS extends Logging {
446451
}
447452

448453
/** Adds an observation. */
449-
def add(a: Array[Float], b: Float): this.type = {
450-
require(a.length == k)
451-
copyToDouble(a)
452-
blas.dspr(upper, k, 1.0, da, 1, ata)
453-
blas.daxpy(k, b.toDouble, da, 1, atb, 1)
454-
n += 1
455-
this
456-
}
457-
458-
/**
459-
* Adds an observation with implicit feedback. Note that this does not increment the counter.
460-
*/
461-
def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
454+
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
455+
require(c >= 0.0)
462456
require(a.length == k)
463-
// Extension to the original paper to handle b < 0. confidence is a function of |b| instead
464-
// so that it is never negative.
465-
val confidence = 1.0 + alpha * math.abs(b)
466457
copyToDouble(a)
467-
blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
468-
// For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
469-
if (b > 0) {
470-
blas.daxpy(k, confidence, da, 1, atb, 1)
458+
blas.dspr(upper, k, c, da, 1, ata)
459+
if (b != 0.0) {
460+
blas.daxpy(k, c * b, da, 1, atb, 1)
471461
}
472462
this
473463
}
@@ -477,15 +467,13 @@ object ALS extends Logging {
477467
require(other.k == k)
478468
blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
479469
blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
480-
n += other.n
481470
this
482471
}
483472

484473
/** Resets everything to zero, which should be called after each solve. */
485474
def reset(): Unit = {
486475
ju.Arrays.fill(ata, 0.0)
487476
ju.Arrays.fill(atb, 0.0)
488-
n = 0
489477
}
490478
}
491479

@@ -1116,20 +1104,31 @@ object ALS extends Logging {
11161104
ls.merge(YtY.get)
11171105
}
11181106
var i = srcPtrs(j)
1107+
var numExplicits = 0
11191108
while (i < srcPtrs(j + 1)) {
11201109
val encoded = srcEncodedIndices(i)
11211110
val blockId = srcEncoder.blockId(encoded)
11221111
val localIndex = srcEncoder.localIndex(encoded)
11231112
val srcFactor = sortedSrcFactors(blockId)(localIndex)
11241113
val rating = ratings(i)
11251114
if (implicitPrefs) {
1126-
ls.addImplicit(srcFactor, rating, alpha)
1115+
// Extension to the original paper to handle b < 0. confidence is a function of |b|
1116+
// instead so that it is never negative. c1 is confidence - 1.0.
1117+
val c1 = alpha * math.abs(rating)
1118+
// For rating <= 0, the corresponding preference is 0. So the term below is only added
1119+
// for rating > 0. Because YtY is already added, we need to adjust the scaling here.
1120+
if (rating > 0) {
1121+
numExplicits += 1
1122+
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
1123+
}
11271124
} else {
11281125
ls.add(srcFactor, rating)
1126+
numExplicits += 1
11291127
}
11301128
i += 1
11311129
}
1132-
dstFactors(j) = solver.solve(ls, regParam)
1130+
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
1131+
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
11331132
j += 1
11341133
}
11351134
dstFactors
@@ -1143,7 +1142,7 @@ object ALS extends Logging {
11431142
private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
11441143
factorBlocks.values.aggregate(new NormalEquation(rank))(
11451144
seqOp = (ne, factors) => {
1146-
factors.foreach(ne.add(_, 0.0f))
1145+
factors.foreach(ne.add(_, 0.0))
11471146
ne
11481147
},
11491148
combOp = (ne1, ne2) => ne1.merge(ne2))

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -68,81 +68,59 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
6868
}
6969
}
7070

71-
test("normal equation construction with explict feedback") {
71+
test("normal equation construction") {
7272
val k = 2
7373
val ne0 = new NormalEquation(k)
74-
.add(Array(1.0f, 2.0f), 3.0f)
75-
.add(Array(4.0f, 5.0f), 6.0f)
74+
.add(Array(1.0f, 2.0f), 3.0)
75+
.add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted
7676
assert(ne0.k === k)
7777
assert(ne0.triK === k * (k + 1) / 2)
78-
assert(ne0.n === 2)
7978
// NumPy code that computes the expected values:
8079
// A = np.matrix("1 2; 4 5")
8180
// b = np.matrix("3; 6")
82-
// ata = A.transpose() * A
83-
// atb = A.transpose() * b
84-
assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
85-
assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
81+
// C = np.matrix(np.diag([1, 2]))
82+
// ata = A.transpose() * C * A
83+
// atb = A.transpose() * C * b
84+
assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8)
85+
assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8)
8686

8787
val ne1 = new NormalEquation(2)
88-
.add(Array(7.0f, 8.0f), 9.0f)
88+
.add(Array(7.0f, 8.0f), 9.0)
8989
ne0.merge(ne1)
90-
assert(ne0.n === 3)
9190
// NumPy code that computes the expected values:
9291
// A = np.matrix("1 2; 4 5; 7 8")
9392
// b = np.matrix("3; 6; 9")
94-
// ata = A.transpose() * A
95-
// atb = A.transpose() * b
96-
assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
97-
assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
93+
// C = np.matrix(np.diag([1, 2, 1]))
94+
// ata = A.transpose() * C * A
95+
// atb = A.transpose() * C * b
96+
assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8)
97+
assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8)
9898

9999
intercept[IllegalArgumentException] {
100-
ne0.add(Array(1.0f), 2.0f)
100+
ne0.add(Array(1.0f), 2.0)
101101
}
102102
intercept[IllegalArgumentException] {
103-
ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
103+
ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0)
104+
}
105+
intercept[IllegalArgumentException] {
106+
ne0.add(Array(1.0f, 2.0f), 0.0, -1.0)
104107
}
105108
intercept[IllegalArgumentException] {
106109
val ne2 = new NormalEquation(3)
107110
ne0.merge(ne2)
108111
}
109112

110113
ne0.reset()
111-
assert(ne0.n === 0)
112114
assert(ne0.ata.forall(_ == 0.0))
113115
assert(ne0.atb.forall(_ == 0.0))
114116
}
115117

116-
test("normal equation construction with implicit feedback") {
117-
val k = 2
118-
val alpha = 0.5
119-
val ne0 = new NormalEquation(k)
120-
.addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
121-
.addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
122-
.addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
123-
assert(ne0.k === k)
124-
assert(ne0.triK === k * (k + 1) / 2)
125-
assert(ne0.n === 0) // addImplicit doesn't increase the count.
126-
// NumPy code that computes the expected values:
127-
// alpha = 0.5
128-
// A = np.matrix("-5 -4; -2 -1; 1 2")
129-
// b = np.matrix("-3; 0; 3")
130-
// b1 = b > 0
131-
// c = 1.0 + alpha * np.abs(b)
132-
// C = np.diag(c.A1)
133-
// I = np.eye(3)
134-
// ata = A.transpose() * (C - I) * A
135-
// atb = A.transpose() * C * b1
136-
assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
137-
assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
138-
}
139-
140118
test("CholeskySolver") {
141119
val k = 2
142120
val ne0 = new NormalEquation(k)
143-
.add(Array(1.0f, 2.0f), 4.0f)
144-
.add(Array(1.0f, 3.0f), 9.0f)
145-
.add(Array(1.0f, 4.0f), 16.0f)
121+
.add(Array(1.0f, 2.0f), 4.0)
122+
.add(Array(1.0f, 3.0f), 9.0)
123+
.add(Array(1.0f, 4.0f), 16.0)
146124
val ne1 = new NormalEquation(k)
147125
.merge(ne0)
148126

@@ -154,13 +132,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
154132
// x0 = np.linalg.lstsq(A, b)[0]
155133
assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
156134

157-
assert(ne0.n === 0)
158135
assert(ne0.ata.forall(_ == 0.0))
159136
assert(ne0.atb.forall(_ == 0.0))
160137

161-
val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
138+
val x1 = chol.solve(ne1, 1.5).map(_.toDouble)
162139
// NumPy code that computes the expected solution, where lambda is scaled by n:
163-
// x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
140+
// x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b)
164141
assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
165142
}
166143

python/pyspark/mllib/recommendation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
5252
>>> ratings = sc.parallelize([r1, r2, r3])
5353
>>> model = ALS.trainImplicit(ratings, 1, seed=10)
5454
>>> model.predict(2, 2)
55-
0.43...
55+
0.4...
5656
5757
>>> testset = sc.parallelize([(1, 2), (1, 1)])
5858
>>> model = ALS.train(ratings, 2, seed=0)
@@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
8282
8383
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
8484
>>> model.predict(2,2)
85-
0.43...
85+
0.4...
8686
8787
>>> import os, tempfile
8888
>>> path = tempfile.mkdtemp()
8989
>>> model.save(sc, path)
9090
>>> sameModel = MatrixFactorizationModel.load(sc, path)
9191
>>> sameModel.predict(2,2)
92-
0.43...
92+
0.4...
9393
>>> sameModel.predictAll(testset).collect()
9494
[Rating(...
9595
>>> try:

0 commit comments

Comments
 (0)