
Commit 78d740a

sethah authored and yanboliang committed
[SPARK-17748][ML] One pass solver for Weighted Least Squares with ElasticNet
## What changes were proposed in this pull request?

1. Make a pluggable solver interface for `WeightedLeastSquares`.
2. Add a `QuasiNewton` solver to handle elastic net regularization for `WeightedLeastSquares`.
3. Add method `BLAS.dspmv`, used by the QN solver.
4. Add a mechanism for WLS to handle singular covariance matrices by falling back to the QN solver when Cholesky fails.

## How was this patch tested?

Unit tests - see below.

## Design choices

**Pluggable Normal Solver**

Before, the `WeightedLeastSquares` package always used the Cholesky decomposition solver to compute the solution to the normal equations. Now, we specify the solver as a constructor argument to the `WeightedLeastSquares`. We introduce a new trait:

````scala
private[ml] sealed trait NormalEquationSolver {

  def solve(
      bBar: Double,
      bbBar: Double,
      abBar: DenseVector,
      aaBar: DenseVector,
      aBar: DenseVector): NormalEquationSolution
}
````

We extend this trait for different variants of normal equation solvers. In the future, we can easily add others (like QR) using this interface.

**Always train in the standardized space**

The normal solver did not previously standardize the data, but this patch changes it so that we always solve the normal equations in the standardized space. We convert back to the original space in the same way that is done for distributed L-BFGS/OWL-QN. We add test cases for zero variance features/labels.

**Use L-BFGS locally to solve normal equations for singular matrix**

When linear regression with the normal solver is called for a singular matrix, we initially try to solve with Cholesky. We use the output of `lapack.dppsv` to determine if the matrix is singular. If it is, we fall back to using L-BFGS locally to solve the normal equations. We add test cases for this as well. (An illustrative sketch of this fallback appears after the commit message below.)

## Test cases

I found it helpful to enumerate some of the test cases; hopefully this makes review easier.

**WeightedLeastSquares**

1. Constant columns - Cholesky solver fails with no regularization, Auto solver falls back to QN, and QN trains successfully.
2. Collinear features - Cholesky solver fails with no regularization, Auto solver falls back to QN, and QN trains successfully.
3. Label is constant zero - no training is performed regardless of intercept. Coefficients are zero and intercept is zero.
4. Label is constant - if fitIntercept, then no training is performed and intercept equals label mean. If not fitIntercept, then we train and return an answer that matches R's lm package.
5. Test with L1 - go through various combinations of L1/L2, standardization, fitIntercept and verify that output matches glmnet.
6. Initial intercept - verify that setting the initial intercept to the label mean is correct by training a model with strong L1 regularization so that all coefficients are zero and the intercept converges to the label mean.
7. Test diagInvAtWA - since we are standardizing features now during training, we should test that the inverse is computed to match R.

**LinearRegression**

1. For all existing L1 test cases, test the "normal" solver too.
2. Check that using the normal solver now handles singular matrices.
3. Check that using the normal solver with L1 produces an objective history in the model summary, but does not produce the inverse of AtA.

**BLAS**

1. Test new method `dspmv`.

## Performance Testing

This patch will speed up linear regression with L1/elastic net penalties when the feature size is < 4096. I have not conducted performance tests at scale; I have only observed locally that there is a speed improvement. We should decide if this PR needs to be blocked before performance testing is conducted.

Author: sethah <[email protected]>

Closes #15394 from sethah/SPARK-17748.
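To make the Cholesky-to-QN fallback concrete, here is a minimal sketch under the definitions this patch introduces (`CholeskySolver`, `QuasiNewtonSolver`, `SingularMatrixException`). It is a simplification, not the actual `WeightedLeastSquares` code: the real Auto path inspects the `lapack.dppsv` status inside `WeightedLeastSquares`, and the helper name `chooseAndSolve` is invented here for illustration.

````scala
// Simplified sketch of Auto solver selection: try the analytic Cholesky
// solve first; if LAPACK reports a singular A^T W A (surfaced as a
// SingularMatrixException), fall back to the quasi-Newton solver.
def chooseAndSolve(
    bBar: Double, bbBar: Double,
    abBar: DenseVector, aaBar: DenseVector, aBar: DenseVector,
    fitIntercept: Boolean, maxIter: Int, tol: Double): NormalEquationSolution = {
  try {
    // Pass copies: the Cholesky routines factor/solve in place.
    new CholeskySolver().solve(bBar, bbBar, abBar.copy, aaBar.copy, aBar)
  } catch {
    case _: SingularMatrixException =>
      // No L1 term in this sketch, so plain L-BFGS is used rather than OWL-QN.
      new QuasiNewtonSolver(fitIntercept, maxIter, tol, l1RegFunc = None)
        .solve(bBar, bbBar, abBar, aaBar, aBar)
  }
}
````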
1 parent 483c37c commit 78d740a

File tree

11 files changed: +1057 −308 lines

mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala

Lines changed: 18 additions & 0 deletions
````diff
@@ -243,6 +243,24 @@ private[spark] object BLAS extends Serializable {
     spr(alpha, v, U.values)
   }
 
+  /**
+   * y := alpha*A*x + beta*y
+   *
+   * @param n The order of the n by n matrix A.
+   * @param A The upper triangular part of A in a [[DenseVector]] (column major).
+   * @param x The [[DenseVector]] transformed by A.
+   * @param y The [[DenseVector]] to be modified in place.
+   */
+  def dspmv(
+      n: Int,
+      alpha: Double,
+      A: DenseVector,
+      x: DenseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+    f2jBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
+  }
+
   /**
    * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
    *
````
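A note on the expected storage format: DSPMV takes the symmetric matrix A in BLAS packed-upper, column-major form, where column j of the upper triangle is stored contiguously and element A(i, j) for i <= j lands at index j*(j+1)/2 + i. A minimal sketch of a packing helper under that assumption (the helper `packUpper` is illustrative, not part of this patch):

````scala
// Hypothetical helper (not in this patch): pack the upper triangle of a
// symmetric n x n matrix, given as rows, into the column-major packed
// layout dspmv expects: A(i, j) for i <= j goes to index j*(j+1)/2 + i.
def packUpper(m: Array[Array[Double]]): Array[Double] = {
  val n = m.length
  val packed = new Array[Double](n * (n + 1) / 2)
  var j = 0
  while (j < n) {
    var i = 0
    while (i <= j) {
      packed(j * (j + 1) / 2 + i) = m(i)(j)
      i += 1
    }
    j += 1
  }
  packed
}
````

Applied to the 4 x 4 matrix used in the test below, this yields exactly the packed array `Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0)`.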

mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala

Lines changed: 45 additions & 0 deletions
````diff
@@ -422,4 +422,49 @@ class BLASSuite extends SparkMLFunSuite {
     assert(dATT.multiply(sx) ~== expected absTol 1e-15)
     assert(sATT.multiply(sx) ~== expected absTol 1e-15)
   }
+
+  test("spmv") {
+    /*
+      A = [[3.0, -2.0, 2.0, -4.0],
+           [-2.0, -8.0, 4.0, 7.0],
+           [2.0, 4.0, -3.0, -3.0],
+           [-4.0, 7.0, -3.0, 0.0]]
+      x = [5.0, 2.0, -1.0, -9.0]
+      Ax = [45.0, -93.0, 48.0, -3.0]
+     */
+    val A = new DenseVector(Array(3.0, -2.0, -8.0, 2.0, 4.0, -3.0, -4.0, 7.0, -3.0, 0.0))
+    val x = new DenseVector(Array(5.0, 2.0, -1.0, -9.0))
+    val n = 4
+
+    val y1 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0))
+    val y2 = y1.copy
+    val y3 = y1.copy
+    val y4 = y1.copy
+    val y5 = y1.copy
+    val y6 = y1.copy
+    val y7 = y1.copy
+
+    val expected1 = new DenseVector(Array(42.0, -87.0, 40.0, -6.0))
+    val expected2 = new DenseVector(Array(19.5, -40.5, 16.0, -4.5))
+    val expected3 = new DenseVector(Array(-25.5, 52.5, -32.0, -1.5))
+    val expected4 = new DenseVector(Array(-3.0, 6.0, -8.0, -3.0))
+    val expected5 = new DenseVector(Array(43.5, -90.0, 44.0, -4.5))
+    val expected6 = new DenseVector(Array(46.5, -96.0, 52.0, -1.5))
+    val expected7 = new DenseVector(Array(45.0, -93.0, 48.0, -3.0))
+
+    dspmv(n, 1.0, A, x, 1.0, y1)
+    dspmv(n, 0.5, A, x, 1.0, y2)
+    dspmv(n, -0.5, A, x, 1.0, y3)
+    dspmv(n, 0.0, A, x, 1.0, y4)
+    dspmv(n, 1.0, A, x, 0.5, y5)
+    dspmv(n, 1.0, A, x, -0.5, y6)
+    dspmv(n, 1.0, A, x, 0.0, y7)
+    assert(y1 ~== expected1 absTol 1e-8)
+    assert(y2 ~== expected2 absTol 1e-8)
+    assert(y3 ~== expected3 absTol 1e-8)
+    assert(y4 ~== expected4 absTol 1e-8)
+    assert(y5 ~== expected5 absTol 1e-8)
+    assert(y6 ~== expected6 absTol 1e-8)
+    assert(y7 ~== expected7 absTol 1e-8)
+  }
 }
````
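As a quick sanity check on the expected vectors: with Ax = [45.0, -93.0, 48.0, -3.0] and the initial y = [-3.0, 6.0, -8.0, -3.0], the alpha = 0.5, beta = 1.0 case gives 0.5 * Ax + y = [19.5, -40.5, 16.0, -4.5], which is `expected2`; the alpha = 0.0, beta = 1.0 case leaves y unchanged, which is `expected4`.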

mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala

Lines changed: 2 additions & 2 deletions
````diff
@@ -81,8 +81,8 @@ private[ml] class IterativelyReweightedLeastSquares(
     }
 
     // Estimate new model
-    model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false,
-      standardizeLabel = false).fit(newInstances)
+    model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
+      standardizeFeatures = false, standardizeLabel = false).fit(newInstances)
 
     // Check convergence
     val oldCoefficients = oldModel.coefficients
````
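Design note: passing `elasticNetParam = 0.0` keeps IRLS on a pure L2 penalty, so each reweighted estimation step behaves as before this patch; the change only tracks the new `WeightedLeastSquares` constructor signature.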
mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala

Lines changed: 163 additions & 0 deletions

New file:

````scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml.optim

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import scala.collection.mutable

import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors}
import org.apache.spark.mllib.linalg.CholeskyDecomposition

/**
 * A class to hold the solution to the normal equations A^T^ W A x = A^T^ W b.
 *
 * @param coefficients The least squares coefficients. The last element in the coefficients
 *                     is the intercept when bias is added to A.
 * @param aaInv An option containing the upper triangular part of (A^T^ W A)^-1^, in column major
 *              format. None when an optimization program is used to solve the normal equations.
 * @param objectiveHistory Option containing the objective history when an optimization program is
 *                         used to solve the normal equations. None when an analytic solver is used.
 */
private[ml] class NormalEquationSolution(
    val coefficients: Array[Double],
    val aaInv: Option[Array[Double]],
    val objectiveHistory: Option[Array[Double]])

/**
 * Interface for classes that solve the normal equations locally.
 */
private[ml] sealed trait NormalEquationSolver {

  /** Solve the normal equations from summary statistics. */
  def solve(
      bBar: Double,
      bbBar: Double,
      abBar: DenseVector,
      aaBar: DenseVector,
      aBar: DenseVector): NormalEquationSolution
}

/**
 * A class that solves the normal equations directly, using Cholesky decomposition.
 */
private[ml] class CholeskySolver extends NormalEquationSolver {

  def solve(
      bBar: Double,
      bbBar: Double,
      abBar: DenseVector,
      aaBar: DenseVector,
      aBar: DenseVector): NormalEquationSolution = {
    val k = abBar.size
    val x = CholeskyDecomposition.solve(aaBar.values, abBar.values)
    val aaInv = CholeskyDecomposition.inverse(aaBar.values, k)

    new NormalEquationSolution(x, Some(aaInv), None)
  }
}

/**
 * A class for solving the normal equations using Quasi-Newton optimization methods.
 */
private[ml] class QuasiNewtonSolver(
    fitIntercept: Boolean,
    maxIter: Int,
    tol: Double,
    l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver {

  def solve(
      bBar: Double,
      bbBar: Double,
      abBar: DenseVector,
      aaBar: DenseVector,
      aBar: DenseVector): NormalEquationSolution = {
    val numFeatures = aBar.size
    val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
    val initialCoefficientsWithIntercept = new Array[Double](numFeaturesPlusIntercept)
    if (fitIntercept) {
      initialCoefficientsWithIntercept(numFeaturesPlusIntercept - 1) = bBar
    }

    val costFun =
      new NormalEquationCostFun(bBar, bbBar, abBar, aaBar, aBar, fitIntercept, numFeatures)
    val optimizer = l1RegFunc.map { func =>
      new BreezeOWLQN[Int, BDV[Double]](maxIter, 10, func, tol)
    }.getOrElse(new BreezeLBFGS[BDV[Double]](maxIter, 10, tol))

    val states = optimizer.iterations(new CachedDiffFunction(costFun),
      new BDV[Double](initialCoefficientsWithIntercept))

    val arrayBuilder = mutable.ArrayBuilder.make[Double]
    var state: optimizer.State = null
    while (states.hasNext) {
      state = states.next()
      arrayBuilder += state.adjustedValue
    }
    val x = state.x.toArray.clone()
    new NormalEquationSolution(x, None, Some(arrayBuilder.result()))
  }

  /**
   * NormalEquationCostFun implements Breeze's DiffFunction[T] for the normal equation.
   * It returns the loss and gradient with L2 regularization at a particular point (coefficients).
   * It's used in Breeze's convex optimization routines.
   */
  private class NormalEquationCostFun(
      bBar: Double,
      bbBar: Double,
      ab: DenseVector,
      aa: DenseVector,
      aBar: DenseVector,
      fitIntercept: Boolean,
      numFeatures: Int) extends DiffFunction[BDV[Double]] {

    private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures

    override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
      val coef = Vectors.fromBreeze(coefficients).toDense
      if (fitIntercept) {
        var j = 0
        var dotProd = 0.0
        val coefValues = coef.values
        val aBarValues = aBar.values
        while (j < numFeatures) {
          dotProd += coefValues(j) * aBarValues(j)
          j += 1
        }
        coefValues(numFeatures) = bBar - dotProd
      }
      val aax = new DenseVector(new Array[Double](numFeaturesPlusIntercept))
      BLAS.dspmv(numFeaturesPlusIntercept, 1.0, aa, coef, 1.0, aax)
      // loss = 1/2 (b^T W b - 2 x^T A^T W b + x^T A^T W A x)
      val loss = 0.5 * bbBar - BLAS.dot(ab, coef) + 0.5 * BLAS.dot(coef, aax)
      // gradient = A^T W A x - A^T W b
      BLAS.axpy(-1.0, ab, aax)
      (loss, aax.asBreeze.toDenseVector)
    }
  }
}

/**
 * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible
 * (singular).
 */
class SingularMatrixException(message: String, cause: Throwable)
  extends IllegalArgumentException(message, cause) {

  def this(message: String) = this(message, null)
}
````
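To make the interface concrete, here is a minimal, hypothetical usage sketch; the toy statistics are invented, and since the solvers are `private[ml]` this only compiles from code inside the `org.apache.spark.ml` package. With A^T W A equal to the 2 x 2 identity, stored packed upper-triangular as [a11, a12, a22], and A^T W b = [2, 3], the Cholesky path returns coefficients [2, 3]:

````scala
import org.apache.spark.ml.linalg.DenseVector

// Toy normal-equation statistics, invented for illustration.
// aaBar packs the upper triangle of A^T W A = I (2 x 2): [1.0, 0.0, 1.0].
val abBar = new DenseVector(Array(2.0, 3.0))
val aaBar = new DenseVector(Array(1.0, 0.0, 1.0))
val aBar = new DenseVector(Array(0.0, 0.0))

// The Cholesky path does not use bBar or bbBar; note that the
// underlying LAPACK routines work in place on the array arguments.
val solution = new CholeskySolver().solve(
  bBar = 0.0, bbBar = 0.0, abBar = abBar, aaBar = aaBar, aBar = aBar)
// solution.coefficients == Array(2.0, 3.0)
// solution.aaInv is defined; solution.objectiveHistory is None
````

For the quasi-Newton path, the same statistics define the objective f(x) = 1/2 b^T W b - x^T A^T W b + 1/2 x^T A^T W A x, whose gradient A^T W A x - A^T W b is the normal-equation residual; this is exactly what `NormalEquationCostFun.calculate` computes via `BLAS.dspmv` and `BLAS.axpy`. Note also that with `fitIntercept` the intercept is not a free optimization variable: it is recomputed at each evaluation as bBar minus the dot product of the current coefficients with aBar.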
