Commit 9a5d482

tmyklebu authored and mengxr committed
[SPARK-1553] Alternating nonnegative least-squares
This pull request includes a nonnegative least-squares solver (NNLS) tailored to the kinds of small-scale problems that come up when training matrix factorisation models by alternating nonnegative least-squares (ANNLS). The method used for the NNLS subproblems is based on the classical method of projected gradients, with one modification: if the set of active constraints has not changed since the last iteration, a conjugate gradient step is considered and possibly rejected in favour of the gradient; this improves convergence once the optimal face has been located. The NNLS solver is in `org.apache.spark.mllib.optimization.NNLS`.

Author: Tor Myklebust <[email protected]>

Closes apache#460 from tmyklebu/annls and squashes the following commits:

79bc4b5 [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark into annls
199b0bc [Tor Myklebust] Make the ctor private again and use the builder pattern.
7fbabf1 [Tor Myklebust] Cleanup matrix math in NNLSSuite.
65ef7f2 [Tor Myklebust] Make ALS's ctor public and remove a couple of "convenience" wrappers.
2d4f3cb [Tor Myklebust] Cleanup.
0cb4481 [Tor Myklebust] Drop the iteration limit from 40k to max(400,20n).
e2a01d1 [Tor Myklebust] Create a workspace object for NNLS to cut down on memory allocations.
b285106 [Tor Myklebust] Clean up NNLS test cases.
9c820b6 [Tor Myklebust] Tweak variable names.
8a1a436 [Tor Myklebust] Describe the problem and add a reference to Polyak's paper.
5345402 [Tor Myklebust] Style fixes that got eaten.
ac673bd [Tor Myklebust] More safeguards against numerical ridiculousness.
c288b6a [Tor Myklebust] Finish moving the NNLS solver.
9a82fa6 [Tor Myklebust] Fix scalastyle moanings.
33bf4f2 [Tor Myklebust] Fix missing space.
89ea0a8 [Tor Myklebust] Hack ALSSuite to support NNLS testing.
f5dbf4d [Tor Myklebust] Teach ALS how to use the NNLS solver.
6cb563c [Tor Myklebust] Tests for the nonnegative least squares solver.
a68ac10 [Tor Myklebust] A nonnegative least-squares solver.
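For orientation, here is a minimal sketch of how the option added in this commit is meant to be used from the ALS builder API (assuming an existing SparkContext `sc`; the example ratings are made up for illustration):

import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Hypothetical tiny ratings RDD; any (user, product, rating) triples work.
val ratings = sc.parallelize(Seq(
  Rating(0, 0, 1.0), Rating(0, 1, 2.0), Rating(1, 1, 3.0)))

// Builder-style configuration; setNonnegative(true) routes every
// alternating least-squares subproblem through the new NNLS solver.
val model = new ALS()
  .setRank(2)
  .setIterations(10)
  .setLambda(0.01)
  .setNonnegative(true)
  .run(ratings)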
1 parent 9535f40 commit 9a5d482

File tree

4 files changed: +300 -14 lines

mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala

Lines changed: 169 additions & 0 deletions (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.optimization

import org.jblas.{DoubleMatrix, SimpleBlas}

import org.apache.spark.annotation.DeveloperApi

/**
 * Object used to solve nonnegative least squares problems using a modified
 * projected gradient method.
 */
private[mllib] object NNLS {
  class Workspace(val n: Int) {
    val scratch = new DoubleMatrix(n, 1)
    val grad = new DoubleMatrix(n, 1)
    val x = new DoubleMatrix(n, 1)
    val dir = new DoubleMatrix(n, 1)
    val lastDir = new DoubleMatrix(n, 1)
    val res = new DoubleMatrix(n, 1)

    def wipe() {
      scratch.fill(0.0)
      grad.fill(0.0)
      x.fill(0.0)
      dir.fill(0.0)
      lastDir.fill(0.0)
      res.fill(0.0)
    }
  }

  def createWorkspace(n: Int): Workspace = {
    new Workspace(n)
  }

  /**
   * Solve a least squares problem, possibly with nonnegativity constraints, by a modified
   * projected gradient method. That is, find x minimising ||Ax - b||_2 given A^T A and A^T b.
   *
   * We solve the problem
   *   min_x      1/2 x^T ata x - x^T atb
   *   subject to x >= 0
   *
   * The method used is similar to one described by Polyak (B. T. Polyak, The conjugate gradient
   * method in extremal problems, Zh. Vychisl. Mat. Mat. Fiz. 9(4)(1969), pp. 94-112) for bound-
   * constrained nonlinear programming. Polyak unconditionally uses a conjugate gradient
   * direction, however, while this method only uses a conjugate gradient direction if the last
   * iteration did not cause a previously-inactive constraint to become active.
   */
  def solve(ata: DoubleMatrix, atb: DoubleMatrix, ws: Workspace): Array[Double] = {
    ws.wipe()

    val n = atb.rows
    val scratch = ws.scratch

    // find the optimal unconstrained step
    def steplen(dir: DoubleMatrix, res: DoubleMatrix): Double = {
      val top = SimpleBlas.dot(dir, res)
      SimpleBlas.gemv(1.0, ata, dir, 0.0, scratch)
      // Push the denominator upward very slightly to avoid infinities and silliness
      top / (SimpleBlas.dot(scratch, dir) + 1e-20)
    }

    // stopping condition
    def stop(step: Double, ndir: Double, nx: Double): Boolean = {
        ((step.isNaN) // NaN
      || (step < 1e-6) // too small or negative
      || (step > 1e40) // too large; almost certainly numerical problems
      || (ndir < 1e-12 * nx) // gradient relatively too small
      || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk
      )
    }

    val grad = ws.grad
    val x = ws.x
    val dir = ws.dir
    val lastDir = ws.lastDir
    val res = ws.res
    val iterMax = Math.max(400, 20 * n)
    var lastNorm = 0.0
    var iterno = 0
    var lastWall = 0 // Last iteration when we hit a bound constraint.
    var i = 0
    while (iterno < iterMax) {
      // find the residual
      SimpleBlas.gemv(1.0, ata, x, 0.0, res)
      SimpleBlas.axpy(-1.0, atb, res)
      SimpleBlas.copy(res, grad)

      // project the gradient
      i = 0
      while (i < n) {
        if (grad.data(i) > 0.0 && x.data(i) == 0.0) {
          grad.data(i) = 0.0
        }
        i = i + 1
      }
      val ngrad = SimpleBlas.dot(grad, grad)

      SimpleBlas.copy(grad, dir)

      // use a CG direction under certain conditions
      var step = steplen(grad, res)
      var ndir = 0.0
      val nx = SimpleBlas.dot(x, x)
      if (iterno > lastWall + 1) {
        val alpha = ngrad / lastNorm
        SimpleBlas.axpy(alpha, lastDir, dir)
        val dstep = steplen(dir, res)
        ndir = SimpleBlas.dot(dir, dir)
        if (stop(dstep, ndir, nx)) {
          // reject the CG step if it could lead to premature termination
          SimpleBlas.copy(grad, dir)
          ndir = SimpleBlas.dot(dir, dir)
        } else {
          step = dstep
        }
      } else {
        ndir = SimpleBlas.dot(dir, dir)
      }

      // terminate?
      if (stop(step, ndir, nx)) {
        return x.data.clone
      }

      // don't run through the walls
      i = 0
      while (i < n) {
        if (step * dir.data(i) > x.data(i)) {
          step = x.data(i) / dir.data(i)
        }
        i = i + 1
      }

      // take the step
      i = 0
      while (i < n) {
        if (step * dir.data(i) > x.data(i) * (1 - 1e-14)) {
          x.data(i) = 0
          lastWall = iterno
        } else {
          x.data(i) -= step * dir.data(i)
        }
        i = i + 1
      }

      iterno = iterno + 1
      SimpleBlas.copy(dir, lastDir)
      lastNorm = ngrad
    }
    x.data.clone
  }
}
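To make the contract of `solve` concrete, here is a tiny worked example (a sketch; it only compiles inside the org.apache.spark.mllib package since `NNLS` is private[mllib], and the numbers are chosen purely for illustration). With ata the 2x2 identity and atb = (1, -1), the unconstrained minimiser is (1, -1), so the nonnegativity constraint is active on the second coordinate and the solver should return approximately (1, 0):

import org.jblas.DoubleMatrix

val ata = new DoubleMatrix(Array(Array(1.0, 0.0), Array(0.0, 1.0)))  // A^T A = I
val atb = new DoubleMatrix(Array(1.0, -1.0))                         // A^T b
val ws = NNLS.createWorkspace(2)
val x = NNLS.solve(ata, atb, ws)  // approximately Array(1.0, 0.0)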

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 30 additions & 3 deletions
@@ -32,6 +32,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
+import org.apache.spark.mllib.optimization.NNLS

 /**
  * Out-link information for a user or product block. This includes the original user/product IDs
@@ -156,6 +157,18 @@ class ALS private (
     this
   }

+  /** If true, do alternating nonnegative least squares. */
+  private var nonnegative = false
+
+  /**
+   * Set whether the least-squares problems solved at each iteration should have
+   * nonnegativity constraints.
+   */
+  def setNonnegative(b: Boolean): ALS = {
+    this.nonnegative = b
+    this
+  }
+
   /**
    * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
    * Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -505,6 +518,8 @@ class ALS private (
       }
     }

+    val ws = if (nonnegative) NNLS.createWorkspace(rank) else null
+
     // Solve the least-squares problem for each user and return the new feature vectors
     Array.range(0, numUsers).map { index =>
       // Compute the full XtX matrix from the lower-triangular part we got above
@@ -517,13 +532,26 @@
       }
       // Solve the resulting matrix, which is symmetric and positive-definite
       if (implicitPrefs) {
-        Solve.solvePositive(fullXtX.addi(YtY.get.value), userXy(index)).data
+        solveLeastSquares(fullXtX.addi(YtY.get.value), userXy(index), ws)
       } else {
-        Solve.solvePositive(fullXtX, userXy(index)).data
+        solveLeastSquares(fullXtX, userXy(index), ws)
       }
     }
   }

+  /**
+   * Given A^T A and A^T b, find the x minimising ||Ax - b||_2, possibly subject
+   * to nonnegativity constraints if `nonnegative` is true.
+   */
+  def solveLeastSquares(ata: DoubleMatrix, atb: DoubleMatrix,
+      ws: NNLS.Workspace): Array[Double] = {
+    if (!nonnegative) {
+      Solve.solvePositive(ata, atb).data
+    } else {
+      NNLS.solve(ata, atb, ws)
+    }
+  }
+
   /**
    * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
    * matrix that it represents, storing it into destMatrix.
@@ -550,7 +578,6 @@ class ALS private (
  * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
  */
 object ALS {
-
  /**
   * Train a matrix factorization model given an RDD of ratings given by users to some products,
   * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
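One design note on the change above: `createWorkspace(rank)` is called once per block rather than once per solve, so the six rank-length buffers inside `Workspace` are reused across all of the block's per-user solves (see commit e2a01d1 in the squash list). A hedged sketch of that pattern, where `problems` is a hypothetical stand-in for the per-user (fullXtX, userXy(index)) pairs that updateBlock builds:

import org.jblas.DoubleMatrix
import org.apache.spark.mllib.optimization.NNLS

// NNLS.solve wipes the workspace itself, so one Workspace can be reused
// safely across consecutive solves of the same size.
def solveAll(problems: Seq[(DoubleMatrix, DoubleMatrix)], rank: Int): Seq[Array[Double]] = {
  val ws = NNLS.createWorkspace(rank)  // allocate the rank-length buffers once
  problems.map { case (ata, atb) => NNLS.solve(ata, atb, ws) }
}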
mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala

Lines changed: 80 additions & 0 deletions (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.optimization

import scala.util.Random

import org.scalatest.FunSuite

import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}

class NNLSSuite extends FunSuite {
  /** Generate an NNLS problem whose optimal solution is the all-ones vector. */
  def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
    val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
    val b = A.mmul(DoubleMatrix.ones(n, 1))

    val ata = A.transpose.mmul(A)
    val atb = A.transpose.mmul(b)

    (ata, atb)
  }

  test("NNLS: exact solution cases") {
    val n = 20
    val rand = new Random(12346)
    val ws = NNLS.createWorkspace(n)
    var numSolved = 0

    // About 15% of random 20x20 [-1,1]-matrices have a singular value less than 1e-3. NNLS
    // can legitimately fail to solve these anywhere close to exactly. So we grab a considerable
    // sample of these matrices and make sure that we solved a substantial fraction of them.

    for (k <- 0 until 100) {
      val (ata, atb) = genOnesData(n, rand)
      val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
      assert(x.length === n)
      val answer = DoubleMatrix.ones(n, 1)
      SimpleBlas.axpy(-1.0, answer, x)
      val solved = (x.norm2 < 1e-2) && (x.normmax < 1e-3)
      if (solved) numSolved = numSolved + 1
    }

    assert(numSolved > 50)
  }

  test("NNLS: nonnegativity constraint active") {
    val n = 5
    val ata = new DoubleMatrix(Array(
      Array( 4.377, -3.531, -1.306, -0.139,  3.418),
      Array(-3.531,  4.344,  0.934,  0.305, -2.140),
      Array(-1.306,  0.934,  2.644, -0.203, -0.170),
      Array(-0.139,  0.305, -0.203,  5.883,  1.428),
      Array( 3.418, -2.140, -0.170,  1.428,  4.684)))
    val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636))

    val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)

    val ws = NNLS.createWorkspace(n)
    val x = NNLS.solve(ata, atb, ws)
    for (i <- 0 until n) {
      assert(Math.abs(x(i) - goodx(i)) < 1e-3)
      assert(x(i) >= 0)
    }
  }
}
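Beyond comparing against a known answer, one could also verify the KKT conditions of the returned point directly: for min_x 1/2 x^T ata x - x^T atb with x >= 0, the gradient g = ata x - atb must satisfy g_i >= 0 wherever x_i = 0 and g_i approximately 0 wherever x_i > 0. A hedged sketch of such a check, not part of this commit (the helper name and tolerance are made up):

import org.jblas.DoubleMatrix

// Hypothetical helper: check feasibility and complementary slackness of a
// candidate NNLS solution within tolerance eps.
def satisfiesKkt(ata: DoubleMatrix, atb: DoubleMatrix, x: Array[Double], eps: Double = 1e-6): Boolean = {
  val xs = new DoubleMatrix(x)
  val grad = ata.mmul(xs).sub(atb)  // gradient of 1/2 x^T ata x - x^T atb
  (0 until x.length).forall { i =>
    x(i) >= 0 &&                                    // primal feasibility
    grad.get(i) >= -eps &&                          // dual feasibility
    (x(i) <= eps || math.abs(grad.get(i)) <= eps)   // complementary slackness
  }
}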

mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala

Lines changed: 21 additions & 11 deletions
@@ -48,12 +48,18 @@ object ALSSuite {
       features: Int,
       samplingRate: Double,
       implicitPrefs: Boolean = false,
-      negativeWeights: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
+      negativeWeights: Boolean = false,
+      negativeFactors: Boolean = true): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
     val rand = new Random(42)

     // Create a random matrix with uniform values from -1 to 1
-    def randomMatrix(m: Int, n: Int) =
-      new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
+    def randomMatrix(m: Int, n: Int) = {
+      if (negativeFactors) {
+        new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
+      } else {
+        new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble()): _*)
+      }
+    }

     val userMatrix = randomMatrix(users, features)
     val productMatrix = randomMatrix(features, products)
@@ -146,6 +152,10 @@ class ALSSuite extends FunSuite with LocalSparkContext {
     }
   }

+  test("NNALS, rank 2") {
+    testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false)
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *
@@ -159,19 +169,19 @@
    * @param bulkPredict flag to test bulk prediction
    * @param negativeWeights whether the generated data can contain negative values
    * @param numBlocks number of blocks to partition users and products into
+   * @param negativeFactors whether the generated user/product factors can have negative entries
    */
   def testALS(users: Int, products: Int, features: Int, iterations: Int,
     samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
-    bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1)
+    bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1,
+    negativeFactors: Boolean = true)
   {
     val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
-      features, samplingRate, implicitPrefs, negativeWeights)
-    val model = implicitPrefs match {
-      case false => ALS.train(sc.parallelize(sampledRatings), features, iterations, 0.01,
-        numBlocks, 0L)
-      case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations, 0.01,
-        numBlocks, 1.0, 0L)
-    }
+      features, samplingRate, implicitPrefs, negativeWeights, negativeFactors)
+
+    val model = (new ALS().setBlocks(numBlocks).setRank(features).setIterations(iterations)
+      .setAlpha(1.0).setImplicitPrefs(implicitPrefs).setLambda(0.01).setSeed(0L)
+      .setNonnegative(!negativeFactors).run(sc.parallelize(sampledRatings)))

     val predictedU = new DoubleMatrix(users, features)
     for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
