[SPARK-1553] Alternating nonnegative least-squares #460
Closed
19 commits
a68ac10 A nonnegative least-squares solver. (tmyklebu)
6cb563c Tests for the nonnegative least squares solver. (tmyklebu)
f5dbf4d Teach ALS how to use the NNLS solver. (tmyklebu)
89ea0a8 Hack ALSSuite to support NNLS testing. (tmyklebu)
33bf4f2 Fix missing space. (tmyklebu)
9a82fa6 Fix scalastyle moanings. (tmyklebu)
c288b6a Finish moving the NNLS solver. (tmyklebu)
ac673bd More safeguards against numerical ridiculousness. (tmyklebu)
5345402 Style fixes that got eaten. (tmyklebu)
8a1a436 Describe the problem and add a reference to Polyak's paper. (tmyklebu)
9c820b6 Tweak variable names. (tmyklebu)
b285106 Clean up NNLS test cases. (tmyklebu)
e2a01d1 Create a workspace object for NNLS to cut down on memory allocations. (tmyklebu)
0cb4481 Drop the iteration limit from 40k to max(400,20n). (tmyklebu)
2d4f3cb Cleanup. (tmyklebu)
65ef7f2 Make ALS's ctor public and remove a couple of "convenience" wrappers. (tmyklebu)
7fbabf1 Cleanup matrix math in NNLSSuite. (tmyklebu)
199b0bc Make the ctor private again and use the builder pattern. (tmyklebu)
79bc4b5 Merge branch 'master' of https://github.com/apache/spark into annls (tmyklebu)
169 changes: 169 additions & 0 deletions
mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.optimization | ||
|
|
||
| import org.jblas.{DoubleMatrix, SimpleBlas} | ||
|
|
||
| import org.apache.spark.annotation.DeveloperApi | ||
|
|
||
| /** | ||
| * Object used to solve nonnegative least squares problems using a modified | ||
| * projected gradient method. | ||
| */ | ||
| private[mllib] object NNLS { | ||
| class Workspace(val n: Int) { | ||
| val scratch = new DoubleMatrix(n, 1) | ||
| val grad = new DoubleMatrix(n, 1) | ||
| val x = new DoubleMatrix(n, 1) | ||
| val dir = new DoubleMatrix(n, 1) | ||
| val lastDir = new DoubleMatrix(n, 1) | ||
| val res = new DoubleMatrix(n, 1) | ||
|
|
||
| def wipe(): Unit = { | ||
| scratch.fill(0.0) | ||
| grad.fill(0.0) | ||
| x.fill(0.0) | ||
| dir.fill(0.0) | ||
| lastDir.fill(0.0) | ||
| res.fill(0.0) | ||
| } | ||
| } | ||
|
|
||
| def createWorkspace(n: Int): Workspace = { | ||
| new Workspace(n) | ||
| } | ||
|
|
||
| /** | ||
| * Solve a least squares problem, possibly with nonnegativity constraints, by a modified | ||
| * projected gradient method. That is, find x minimising ||Ax - b||_2 given A^T A and A^T b. | ||
| * | ||
| * We solve the problem | ||
| * min_x 1/2 x^T ata x - x^T atb | ||
| * subject to x >= 0 | ||
| * | ||
| * The method used is similar to one described by Polyak (B. T. Polyak, The conjugate gradient | ||
| * method in extremal problems, Zh. Vychisl. Mat. Mat. Fiz. 9(4)(1969), pp. 94-112) for bound- | ||
| * constrained nonlinear programming. Polyak unconditionally uses a conjugate gradient | ||
| * direction, however, while this method only uses a conjugate gradient direction if the last | ||
| * iteration did not cause a previously-inactive constraint to become active. | ||
| */ | ||
| def solve(ata: DoubleMatrix, atb: DoubleMatrix, ws: Workspace): Array[Double] = { | ||
| ws.wipe() | ||
|
|
||
| val n = atb.rows | ||
| val scratch = ws.scratch | ||
|
|
||
| // find the optimal unconstrained step | ||
| def steplen(dir: DoubleMatrix, res: DoubleMatrix): Double = { | ||
| val top = SimpleBlas.dot(dir, res) | ||
| SimpleBlas.gemv(1.0, ata, dir, 0.0, scratch) | ||
| // Push the denominator upward very slightly to avoid infinities and silliness | ||
| top / (SimpleBlas.dot(scratch, dir) + 1e-20) | ||
| } | ||
|
|
||
| // stopping condition | ||
| def stop(step: Double, ndir: Double, nx: Double): Boolean = { | ||
| ((step.isNaN) // NaN | ||
| || (step < 1e-6) // too small or negative | ||
| || (step > 1e40) // too large; almost certainly numerical problems | ||
| || (ndir < 1e-12 * nx) // gradient relatively too small | ||
| || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk | ||
| ) | ||
| } | ||
|
|
||
| val grad = ws.grad | ||
| val x = ws.x | ||
| val dir = ws.dir | ||
| val lastDir = ws.lastDir | ||
| val res = ws.res | ||
| val iterMax = Math.max(400, 20 * n) | ||
| var lastNorm = 0.0 | ||
| var iterno = 0 | ||
| var lastWall = 0 // Last iteration when we hit a bound constraint. | ||
| var i = 0 | ||
| while (iterno < iterMax) { | ||
| // find the residual | ||
| SimpleBlas.gemv(1.0, ata, x, 0.0, res) | ||
| SimpleBlas.axpy(-1.0, atb, res) | ||
| SimpleBlas.copy(res, grad) | ||
|
|
||
| // project the gradient | ||
| i = 0 | ||
| while (i < n) { | ||
| if (grad.data(i) > 0.0 && x.data(i) == 0.0) { | ||
| grad.data(i) = 0.0 | ||
| } | ||
| i = i + 1 | ||
| } | ||
| val ngrad = SimpleBlas.dot(grad, grad) | ||
|
|
||
| SimpleBlas.copy(grad, dir) | ||
|
|
||
| // use a CG direction under certain conditions | ||
| var step = steplen(grad, res) | ||
| var ndir = 0.0 | ||
| val nx = SimpleBlas.dot(x, x) | ||
| if (iterno > lastWall + 1) { | ||
| val alpha = ngrad / lastNorm | ||
| SimpleBlas.axpy(alpha, lastDir, dir) | ||
| val dstep = steplen(dir, res) | ||
| ndir = SimpleBlas.dot(dir, dir) | ||
| if (stop(dstep, ndir, nx)) { | ||
| // reject the CG step if it could lead to premature termination | ||
| SimpleBlas.copy(grad, dir) | ||
| ndir = SimpleBlas.dot(dir, dir) | ||
| } else { | ||
| step = dstep | ||
| } | ||
| } else { | ||
| ndir = SimpleBlas.dot(dir, dir) | ||
| } | ||
|
|
||
| // terminate? | ||
| if (stop(step, ndir, nx)) { | ||
| return x.data.clone | ||
| } | ||
|
|
||
| // don't run through the walls | ||
| i = 0 | ||
| while (i < n) { | ||
| if (step * dir.data(i) > x.data(i)) { | ||
| step = x.data(i) / dir.data(i) | ||
| } | ||
| i = i + 1 | ||
| } | ||
|
|
||
| // take the step | ||
| i = 0 | ||
| while (i < n) { | ||
| if (step * dir.data(i) > x.data(i) * (1 - 1e-14)) { | ||
| x.data(i) = 0 | ||
| lastWall = iterno | ||
| } else { | ||
| x.data(i) -= step * dir.data(i) | ||
| } | ||
| i = i + 1 | ||
| } | ||
|
|
||
| iterno = iterno + 1 | ||
| SimpleBlas.copy(dir, lastDir) | ||
| lastNorm = ngrad | ||
| } | ||
| x.data.clone | ||
| } | ||
| } |
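
The core of `solve` above (gradient, projection onto the active bounds, exact step length, and clipping at the walls) can be sketched without jblas. The following is a simplified illustration, not the code in this PR: it always uses the projected steepest-descent direction and never the conjugate-gradient direction, it works on plain arrays, and the names `ProjectedGradientSketch`, `matVec`, `dot` and the `maxIter` default are hypothetical. The step length is the exact minimiser of the quadratic along `-dir`, namely `dir . grad / (dir . ata dir)`, which is the formula `steplen` uses.

```scala
object ProjectedGradientSketch {
  /** Multiply a dense square matrix (stored as an array of rows) by a vector. */
  def matVec(m: Array[Array[Double]], v: Array[Double]): Array[Double] =
    m.map(row => row.zip(v).map { case (a, b) => a * b }.sum)

  /** Inner product of two vectors of equal length. */
  def dot(u: Array[Double], v: Array[Double]): Double =
    u.zip(v).map { case (a, b) => a * b }.sum

  /**
   * Minimise 1/2 x^T ata x - x^T atb subject to x >= 0 by projected
   * steepest descent with an exact line search. Assumption: ata is
   * symmetric positive semidefinite, as A^T A always is.
   */
  def solve(ata: Array[Array[Double]], atb: Array[Double], maxIter: Int = 1000): Array[Double] = {
    val n = atb.length
    val x = Array.fill(n)(0.0)
    var iter = 0
    var done = false
    while (iter < maxIter && !done) {
      // Gradient of the objective: ata * x - atb.
      val grad = matVec(ata, x).zip(atb).map { case (a, b) => a - b }
      // Projected gradient: drop components that would push a zero coordinate negative.
      val dir = Array.tabulate(n)(i => if (grad(i) > 0.0 && x(i) == 0.0) 0.0 else grad(i))
      val ndir = dot(dir, dir)
      // Exact minimiser along -dir; the tiny constant guards against division by zero.
      var step = dot(dir, grad) / (dot(matVec(ata, dir), dir) + 1e-20)
      if (ndir < 1e-24 || step.isNaN || step <= 0.0) {
        done = true
      } else {
        // Shorten the step so that no coordinate crosses zero.
        for (i <- 0 until n) {
          if (step * dir(i) > x(i)) step = x(i) / dir(i)
        }
        // Take the step, snapping coordinates that reach a wall to exactly zero.
        for (i <- 0 until n) {
          if (step * dir(i) > x(i) * (1 - 1e-14)) x(i) = 0.0
          else x(i) -= step * dir(i)
        }
        iter += 1
      }
    }
    x
  }
}
```

For `ata = diag(2, 1)` and `atb = (2, -1)` the unconstrained minimiser is `(1, -1)`, so the constraint on the second coordinate is active and the sketch should return approximately `(1, 0)`. Plain steepest descent can converge slowly on ill-conditioned problems, which is one reason a conjugate-gradient direction is attractive.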
80 changes: 80 additions & 0 deletions
mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.optimization | ||
|
|
||
| import scala.util.Random | ||
|
|
||
| import org.scalatest.FunSuite | ||
|
|
||
| import org.jblas.{DoubleMatrix, SimpleBlas} | ||
|
|
||
| class NNLSSuite extends FunSuite { | ||
| /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ | ||
| def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { | ||
| val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) | ||
| val b = A.mmul(DoubleMatrix.ones(n, 1)) | ||
|
|
||
| val ata = A.transpose.mmul(A) | ||
| val atb = A.transpose.mmul(b) | ||
|
|
||
| (ata, atb) | ||
| } | ||
|
|
||
| test("NNLS: exact solution cases") { | ||
| val n = 20 | ||
| val rand = new Random(12346) | ||
| val ws = NNLS.createWorkspace(n) | ||
| var numSolved = 0 | ||
|
|
||
| // About 15% of random 20x20 [-1,1]-matrices have a singular value less than 1e-3. NNLS | ||
| // can legitimately fail to solve these anywhere close to exactly. So we grab a considerable | ||
| // sample of these matrices and make sure that we solved a substantial fraction of them. | ||
|
|
||
| for (k <- 0 until 100) { | ||
| val (ata, atb) = genOnesData(n, rand) | ||
| val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) | ||
| assert(x.length === n) | ||
| val answer = DoubleMatrix.ones(n, 1) | ||
| SimpleBlas.axpy(-1.0, answer, x) | ||
| val solved = (x.norm2 < 1e-2) && (x.normmax < 1e-3) | ||
| if (solved) numSolved = numSolved + 1 | ||
| } | ||
|
|
||
| assert(numSolved > 50) | ||
| } | ||
|
|
||
| test("NNLS: nonnegativity constraint active") { | ||
| val n = 5 | ||
| val ata = new DoubleMatrix(Array( | ||
| Array( 4.377, -3.531, -1.306, -0.139, 3.418), | ||
| Array(-3.531, 4.344, 0.934, 0.305, -2.140), | ||
| Array(-1.306, 0.934, 2.644, -0.203, -0.170), | ||
| Array(-0.139, 0.305, -0.203, 5.883, 1.428), | ||
| Array( 3.418, -2.140, -0.170, 1.428, 4.684))) | ||
| val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) | ||
|
|
||
| val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) | ||
|
|
||
| val ws = NNLS.createWorkspace(n) | ||
| val x = NNLS.solve(ata, atb, ws) | ||
| for (i <- 0 until n) { | ||
| assert(Math.abs(x(i) - goodx(i)) < 1e-3) | ||
| assert(x(i) >= 0) | ||
| } | ||
| } | ||
| } | ||
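
The second test compares the solver output against a precomputed `goodx`. A complementary check, sketched here as an assumption rather than as part of this PR, is to verify the optimality (KKT) conditions of the constrained problem directly: every coordinate is nonnegative, the gradient `ata * x - atb` vanishes wherever `x(i) > 0`, and the gradient is nonnegative wherever `x(i) == 0`. The object name `KktCheckSketch`, its helpers, and the tolerance `1e-4` are hypothetical; the data is copied from the test above.

```scala
object KktCheckSketch {
  // Data copied from the "nonnegativity constraint active" test case.
  val ata: Array[Array[Double]] = Array(
    Array( 4.377, -3.531, -1.306, -0.139,  3.418),
    Array(-3.531,  4.344,  0.934,  0.305, -2.140),
    Array(-1.306,  0.934,  2.644, -0.203, -0.170),
    Array(-0.139,  0.305, -0.203,  5.883,  1.428),
    Array( 3.418, -2.140, -0.170,  1.428,  4.684))
  val atb: Array[Double] = Array(-1.632, 2.115, 1.094, -1.025, -0.636)
  val goodx: Array[Double] = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)

  /** Gradient of 1/2 x^T ata x - x^T atb, i.e. ata * x - atb. */
  def gradient(ata: Array[Array[Double]], atb: Array[Double], x: Array[Double]): Array[Double] =
    ata.zip(atb).map { case (row, b) =>
      row.zip(x).map { case (a, xi) => a * xi }.sum - b
    }

  /** True if x satisfies the KKT conditions of the NNLS problem up to tol. */
  def satisfiesKkt(
      ata: Array[Array[Double]],
      atb: Array[Double],
      x: Array[Double],
      tol: Double): Boolean = {
    val g = gradient(ata, atb, x)
    x.indices.forall { i =>
      val feasible = x(i) >= 0.0
      // At a (numerically) zero coordinate the gradient may be positive;
      // at a strictly positive coordinate it must vanish.
      val stationary = if (x(i) <= tol) g(i) >= -tol else math.abs(g(i)) <= tol
      feasible && stationary
    }
  }

  def main(args: Array[String]): Unit = {
    println(satisfiesKkt(ata, atb, goodx, 1e-4))
    println(satisfiesKkt(ata, atb, Array.fill(5)(0.0), 1e-4))
  }
}
```

With this data `goodx` passes the check at `tol = 1e-4`, while the all-zeros vector fails it because the gradient is negative at coordinates that sit on the bound. A check of this kind does not depend on a reference solution, so it could also be applied to randomly generated problems.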
Also need to assert `x(i) >= 0`.