Skip to content

Commit 4512e2a

Browse files
yanboliangdbtsai
authored andcommitted
[SPARK-20047][ML] Constrained Logistic Regression
## What changes were proposed in this pull request? MLlib ```LogisticRegression``` should support bound constrained optimization (only for L2 regularization). Users can add bound constraints to coefficients to make the solver produce solution in the specified range. Under the hood, we call Breeze [```L-BFGS-B```](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGSB.scala) as the solver for bound constrained optimization. But in the current breeze implementation, there are some bugs in L-BFGS-B, and scalanlp/breeze#633 fixed them. We need to upgrade dependent breeze later, and currently we use the workaround L-BFGS-B in this PR temporary for reviewing. ## How was this patch tested? Unit tests. Author: Yanbo Liang <[email protected]> Closes #17715 from yanboliang/spark-20047. (cherry picked from commit 606432a) Signed-off-by: DB Tsai <[email protected]>
1 parent c29c6de commit 4512e2a

File tree

2 files changed

+682
-7
lines changed

2 files changed

+682
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 218 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.util.Locale
2222
import scala.collection.mutable
2323

2424
import breeze.linalg.{DenseVector => BDV}
25-
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
25+
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
2626
import org.apache.hadoop.fs.Path
2727

2828
import org.apache.spark.SparkException
@@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
178178
}
179179
}
180180

181+
/**
182+
* The lower bounds on coefficients if fitting under bound constrained optimization.
183+
* The bound matrix must be compatible with the shape (1, number of features) for binomial
184+
* regression, or (number of classes, number of features) for multinomial regression.
185+
* Otherwise, it throws exception.
186+
*
187+
* @group param
188+
*/
189+
@Since("2.2.0")
190+
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
191+
"The lower bounds on coefficients if fitting under bound constrained optimization.")
192+
193+
/** @group getParam */
194+
@Since("2.2.0")
195+
def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients)
196+
197+
/**
198+
* The upper bounds on coefficients if fitting under bound constrained optimization.
199+
* The bound matrix must be compatible with the shape (1, number of features) for binomial
200+
* regression, or (number of classes, number of features) for multinomial regression.
201+
* Otherwise, it throws exception.
202+
*
203+
* @group param
204+
*/
205+
@Since("2.2.0")
206+
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
207+
"The upper bounds on coefficients if fitting under bound constrained optimization.")
208+
209+
/** @group getParam */
210+
@Since("2.2.0")
211+
def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients)
212+
213+
/**
214+
* The lower bounds on intercepts if fitting under bound constrained optimization.
215+
* The bounds vector size must be equal with 1 for binomial regression, or the number
216+
* of classes for multinomial regression. Otherwise, it throws exception.
217+
*
218+
* @group param
219+
*/
220+
@Since("2.2.0")
221+
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
222+
"The lower bounds on intercepts if fitting under bound constrained optimization.")
223+
224+
/** @group getParam */
225+
@Since("2.2.0")
226+
def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts)
227+
228+
/**
229+
* The upper bounds on intercepts if fitting under bound constrained optimization.
230+
* The bound vector size must be equal with 1 for binomial regression, or the number
231+
* of classes for multinomial regression. Otherwise, it throws exception.
232+
*
233+
* @group param
234+
*/
235+
@Since("2.2.0")
236+
val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
237+
"The upper bounds on intercepts if fitting under bound constrained optimization.")
238+
239+
/** @group getParam */
240+
@Since("2.2.0")
241+
def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts)
242+
243+
protected def usingBoundConstrainedOptimization: Boolean = {
244+
isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
245+
isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
246+
}
247+
181248
override protected def validateAndTransformSchema(
182249
schema: StructType,
183250
fitting: Boolean,
184251
featuresDataType: DataType): StructType = {
185252
checkThresholdConsistency()
253+
if (usingBoundConstrainedOptimization) {
254+
require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
255+
s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
256+
}
257+
if (!$(fitIntercept)) {
258+
require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
259+
"Pls don't set bounds on intercepts if fitting without intercept.")
260+
}
186261
super.validateAndTransformSchema(schema, fitting, featuresDataType)
187262
}
188263
}
@@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") (
217292
* For alpha in (0,1), the penalty is a combination of L1 and L2.
218293
* Default is 0.0 which is an L2 penalty.
219294
*
295+
* Note: Fitting under bound constrained optimization only supports L2 regularization,
296+
* so throws exception if this param is non-zero value.
297+
*
220298
* @group setParam
221299
*/
222300
@Since("1.4.0")
@@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") (
312390
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
313391
setDefault(aggregationDepth -> 2)
314392

393+
/**
394+
* Set the lower bounds on coefficients if fitting under bound constrained optimization.
395+
*
396+
* @group setParam
397+
*/
398+
@Since("2.2.0")
399+
def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value)
400+
401+
/**
402+
* Set the upper bounds on coefficients if fitting under bound constrained optimization.
403+
*
404+
* @group setParam
405+
*/
406+
@Since("2.2.0")
407+
def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value)
408+
409+
/**
410+
* Set the lower bounds on intercepts if fitting under bound constrained optimization.
411+
*
412+
* @group setParam
413+
*/
414+
@Since("2.2.0")
415+
def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value)
416+
417+
/**
418+
* Set the upper bounds on intercepts if fitting under bound constrained optimization.
419+
*
420+
* @group setParam
421+
*/
422+
@Since("2.2.0")
423+
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
424+
425+
private def assertBoundConstrainedOptimizationParamsValid(
426+
numCoefficientSets: Int,
427+
numFeatures: Int): Unit = {
428+
if (isSet(lowerBoundsOnCoefficients)) {
429+
require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
430+
$(lowerBoundsOnCoefficients).numCols == numFeatures)
431+
}
432+
if (isSet(upperBoundsOnCoefficients)) {
433+
require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
434+
$(upperBoundsOnCoefficients).numCols == numFeatures)
435+
}
436+
if (isSet(lowerBoundsOnIntercepts)) {
437+
require($(lowerBoundsOnIntercepts).size == numCoefficientSets)
438+
}
439+
if (isSet(upperBoundsOnIntercepts)) {
440+
require($(upperBoundsOnIntercepts).size == numCoefficientSets)
441+
}
442+
if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
443+
require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
444+
.forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " +
445+
"less than or equal to upperBoundsOnCoefficients, but found: " +
446+
s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " +
447+
s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.")
448+
}
449+
if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
450+
require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
451+
.forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " +
452+
"less than or equal to upperBoundsOnIntercepts, but found: " +
453+
s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " +
454+
s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.")
455+
}
456+
}
457+
315458
private var optInitialModel: Option[LogisticRegressionModel] = None
316459

317460
private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
@@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") (
378521
}
379522
val numCoefficientSets = if (isMultinomial) numClasses else 1
380523

524+
// Check params interaction is valid if fitting under bound constrained optimization.
525+
if (usingBoundConstrainedOptimization) {
526+
assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures)
527+
}
528+
381529
if (isDefined(thresholds)) {
382530
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
383531
".train() called with non-matching numClasses and thresholds.length." +
@@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") (
397545

398546
val isConstantLabel = histogram.count(_ != 0.0) == 1
399547

400-
if ($(fitIntercept) && isConstantLabel) {
548+
if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
401549
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
402550
s"will be zeros. Training is not needed.")
403551
val constantLabelIndex = Vectors.dense(histogram).argmax
@@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") (
434582
$(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
435583
$(aggregationDepth))
436584

585+
val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets
586+
587+
val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = {
588+
if (usingBoundConstrainedOptimization) {
589+
val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity)
590+
val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity)
591+
val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients)
592+
val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients)
593+
val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts)
594+
val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts)
595+
596+
var i = 0
597+
while (i < numCoeffsPlusIntercepts) {
598+
val coefficientSetIndex = i % numCoefficientSets
599+
val featureIndex = i / numCoefficientSets
600+
if (featureIndex < numFeatures) {
601+
if (isSetLowerBoundsOnCoefficients) {
602+
lowerBounds(i) = $(lowerBoundsOnCoefficients)(
603+
coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
604+
}
605+
if (isSetUpperBoundsOnCoefficients) {
606+
upperBounds(i) = $(upperBoundsOnCoefficients)(
607+
coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
608+
}
609+
} else {
610+
if (isSetLowerBoundsOnIntercepts) {
611+
lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex)
612+
}
613+
if (isSetUpperBoundsOnIntercepts) {
614+
upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex)
615+
}
616+
}
617+
i += 1
618+
}
619+
(lowerBounds, upperBounds)
620+
} else {
621+
(null, null)
622+
}
623+
}
624+
437625
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
438-
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
626+
if (lowerBounds != null && upperBounds != null) {
627+
new BreezeLBFGSB(
628+
BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol))
629+
} else {
630+
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
631+
}
439632
} else {
440633
val standardizationParam = $(standardization)
441634
def regParamL1Fun = (index: Int) => {
@@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") (
546739
math.log(histogram(1) / histogram(0)))
547740
}
548741

742+
if (usingBoundConstrainedOptimization) {
743+
// Make sure all initial values locate in the corresponding bound.
744+
var i = 0
745+
while (i < numCoeffsPlusIntercepts) {
746+
val coefficientSetIndex = i % numCoefficientSets
747+
val featureIndex = i / numCoefficientSets
748+
if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i))
749+
{
750+
initialCoefWithInterceptMatrix.update(
751+
coefficientSetIndex, featureIndex, lowerBounds(i))
752+
} else if (
753+
initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i))
754+
{
755+
initialCoefWithInterceptMatrix.update(
756+
coefficientSetIndex, featureIndex, upperBounds(i))
757+
}
758+
i += 1
759+
}
760+
}
761+
549762
val states = optimizer.iterations(new CachedDiffFunction(costFun),
550763
new BDV[Double](initialCoefWithInterceptMatrix.toArray))
551764

@@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") (
599812
if (isIntercept) interceptVec.toArray(classIndex) = value
600813
}
601814

602-
if ($(regParam) == 0.0 && isMultinomial) {
815+
if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) {
603816
/*
604817
When no regularization is applied, the multinomial coefficients lack identifiability
605818
because we do not use a pivot class. We can add any constant value to the coefficients
@@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") (
620833
}
621834

622835
// center the intercepts when using multinomial algorithm
623-
if ($(fitIntercept) && isMultinomial) {
836+
if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) {
624837
val interceptArray = interceptVec.toArray
625838
val interceptMean = interceptArray.sum / interceptArray.length
626839
(0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }

0 commit comments

Comments
 (0)