Skip to content

Commit 21bbf94

Browse files
committed
[SPARK-17748][ML] Minor cleanups to one-pass linear regression with elastic net
## What changes were proposed in this pull request? * Made SingularMatrixException private ml * WeightedLeastSquares: Changed to allow tol >= 0 instead of only tol > 0 ## How was this patch tested? existing tests Author: Joseph K. Bradley <[email protected]> Closes #15779 from jkbradley/wls-cleanups. (cherry picked from commit 26e1c53) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent 876eee2 commit 21bbf94

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
*/
1717
package org.apache.spark.ml.optim
1818

19+
import scala.collection.mutable
20+
1921
import breeze.linalg.{DenseVector => BDV}
2022
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
21-
import scala.collection.mutable
2223

2324
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors}
2425
import org.apache.spark.mllib.linalg.CholeskyDecomposition
@@ -57,7 +58,7 @@ private[ml] sealed trait NormalEquationSolver {
5758
*/
5859
private[ml] class CholeskySolver extends NormalEquationSolver {
5960

60-
def solve(
61+
override def solve(
6162
bBar: Double,
6263
bbBar: Double,
6364
abBar: DenseVector,
@@ -80,7 +81,7 @@ private[ml] class QuasiNewtonSolver(
8081
tol: Double,
8182
l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver {
8283

83-
def solve(
84+
override def solve(
8485
bBar: Double,
8586
bbBar: Double,
8687
abBar: DenseVector,
@@ -156,7 +157,7 @@ private[ml] class QuasiNewtonSolver(
156157
* Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible
157158
* (singular).
158159
*/
159-
class SingularMatrixException(message: String, cause: Throwable)
160+
private[spark] class SingularMatrixException(message: String, cause: Throwable)
160161
extends IllegalArgumentException(message, cause) {
161162

162163
def this(message: String) = this(message, null)

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ private[ml] class WeightedLeastSquaresModel(
4747
* formulation:
4848
*
4949
* min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
50-
* + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^
50+
* + lambda / delta (1/2 (1 - alpha) sum,,j,, (sigma,,j,, x,,j,,)^2^
5151
* + alpha sum,,j,, abs(sigma,,j,, x,,j,,)),
5252
*
5353
* where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter,
@@ -91,7 +91,7 @@ private[ml] class WeightedLeastSquares(
9191
require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0,
9292
s"elasticNetParam must be in [0, 1]: $elasticNetParam")
9393
require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter")
94-
require(tol > 0, s"tol must be greater than zero: $tol")
94+
require(tol >= 0.0, s"tol must be >= 0, but was set to $tol")
9595

9696
/**
9797
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
3131
import org.apache.spark.ml.feature.Instance
3232
import org.apache.spark.ml.linalg.{Vector, Vectors}
3333
import org.apache.spark.ml.linalg.BLAS._
34-
import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares}
34+
import org.apache.spark.ml.optim.WeightedLeastSquares
3535
import org.apache.spark.ml.PredictorParams
3636
import org.apache.spark.ml.param.ParamMap
3737
import org.apache.spark.ml.param.shared._
@@ -160,11 +160,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
160160
/**
161161
* Set the solver algorithm used for optimization.
162162
* In case of linear regression, this can be "l-bfgs", "normal" and "auto".
163-
* "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
164-
* optimization method. "normal" denotes using Normal Equation as an analytical
165-
* solution to the linear regression problem.
166-
* The default value is "auto" which means that the solver algorithm is
167-
* selected automatically.
163+
* - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
164+
* optimization method.
165+
* - "normal" denotes using Normal Equation as an analytical solution to the linear regression
166+
* problem. This solver is limited to [[LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER]].
167+
* - "auto" (default) means that the solver algorithm is selected automatically.
168+
* The Normal Equations solver will be used when possible, but this will automatically fall
169+
* back to iterative optimization methods when needed.
168170
*
169171
* @group setParam
170172
*/
@@ -404,6 +406,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
404406

405407
@Since("1.6.0")
406408
override def load(path: String): LinearRegression = super.load(path)
409+
410+
/**
411+
* When using [[LinearRegression.solver]] == "normal", the solver must limit the number of
412+
* features to at most this number. The entire covariance matrix X^T^X will be collected
413+
* to the driver. This limit helps prevent memory overflow errors.
414+
*/
415+
@Since("2.1.0")
416+
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
407417
}
408418

409419
/**

0 commit comments

Comments
 (0)