Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
Expand Down Expand Up @@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
}

/**
* The lower bounds on coefficients if fitting under bound constrained optimization.
* The bound matrix must be compatible with the shape (1, number of features) for binomial
* regression, or (number of classes, number of features) for multinomial regression.
* Otherwise, an exception is thrown.
* Default is none.
*
* @group expertParam
*/
@Since("2.2.0")
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
"The lower bounds on coefficients if fitting under bound constrained optimization.")

/** @group expertGetParam */
@Since("2.2.0")
def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients)

/**
* The upper bounds on coefficients if fitting under bound constrained optimization.
* The bound matrix must be compatible with the shape (1, number of features) for binomial
* regression, or (number of classes, number of features) for multinomial regression.
* Otherwise, an exception is thrown.
* Default is none.
*
* @group expertParam
*/
@Since("2.2.0")
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
"The upper bounds on coefficients if fitting under bound constrained optimization.")

/** @group expertGetParam */
@Since("2.2.0")
def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients)

/**
* The lower bounds on intercepts if fitting under bound constrained optimization.
* The bounds vector size must be equal to 1 for binomial regression, or the number
* of classes for multinomial regression. Otherwise, an exception is thrown.
* Default is none.
*
* @group expertParam
*/
@Since("2.2.0")
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
"The lower bounds on intercepts if fitting under bound constrained optimization.")

/** @group expertGetParam */
@Since("2.2.0")
def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts)

/**
* The upper bounds on intercepts if fitting under bound constrained optimization.
* The bounds vector size must be equal to 1 for binomial regression, or the number
* of classes for multinomial regression. Otherwise, an exception is thrown.
* Default is none.
*
* @group expertParam
*/
@Since("2.2.0")
val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
"The upper bounds on intercepts if fitting under bound constrained optimization.")

/** @group expertGetParam */
@Since("2.2.0")
def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts)

/**
 * Whether bound constrained optimization was requested, i.e. whether at least one of the
 * four bound params (lower/upper bounds on coefficients or on intercepts) has been set.
 */
protected def usingBoundConstrainedOptimization: Boolean = {
  val boundParams = Seq[Param[_]](
    lowerBoundsOnCoefficients,
    upperBoundsOnCoefficients,
    lowerBoundsOnIntercepts,
    upperBoundsOnIntercepts)
  boundParams.exists(p => isSet(p))
}

/**
 * Checks that the configured params are mutually consistent, then delegates the actual
 * schema validation to the parent implementation.
 *
 * Checks performed here, in order:
 *  - threshold consistency (via `checkThresholdConsistency()`);
 *  - if any bound param is set, `elasticNetParam` must be 0.0, because bound constrained
 *    optimization only supports L2 regularization;
 *  - if `fitIntercept` is false, no bounds on intercepts may be set.
 *
 * @param schema input schema
 * @param fitting forwarded unchanged to the parent implementation
 * @param featuresDataType forwarded unchanged to the parent implementation
 * @return the schema returned by the parent implementation
 * @throws IllegalArgumentException if one of the `require` checks above fails
 */
override protected def validateAndTransformSchema(
    schema: StructType,
    fitting: Boolean,
    featuresDataType: DataType): StructType = {
  checkThresholdConsistency()
  if (usingBoundConstrainedOptimization) {
    require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
      s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
  }
  if (!$(fitIntercept)) {
    // Intercept bounds are meaningless when no intercept is fitted, so reject them early.
    require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
      "Please do not set bounds on intercepts if fitting without intercept.")
  }
  super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}
Expand Down Expand Up @@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") (
* For alpha in (0,1), the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
*
* Note: Fitting under bound constrained optimization only supports L2 regularization,
* so an exception is thrown if this param is set to a non-zero value.
*
* @group setParam
*/
@Since("1.4.0")
Expand Down Expand Up @@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") (
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/**
* Set the lower bounds on coefficients if fitting under bound constrained optimization.
* Default is none.
*
* @group expertSetParam
*/
@Since("2.2.0")
def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value)

/**
* Set the upper bounds on coefficients if fitting under bound constrained optimization.
* Default is none.
*
* @group expertSetParam
*/
@Since("2.2.0")
def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value)

/**
* Set the lower bounds on intercepts if fitting under bound constrained optimization.
* Default is none.
*
* @group expertSetParam
*/
@Since("2.2.0")
def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value)

/**
* Set the upper bounds on intercepts if fitting under bound constrained optimization.
* Default is none.
*
* @group expertSetParam
*/
@Since("2.2.0")
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

/**
 * Validates the bound params that have been set against the problem dimensions.
 *
 * For every bound param that is set this checks its shape (coefficients) or size
 * (intercepts); when both the lower and the upper bound of a kind are set it also checks
 * that every lower bound is less than or equal to the matching upper bound.
 *
 * @param numCoefficientSets expected number of rows of the coefficient bound matrices and
 *                           expected size of the intercept bound vectors
 * @param numFeatures expected number of columns of the coefficient bound matrices
 * @throws IllegalArgumentException if any check fails; the message names the offending param
 *                                  and the value that was found
 */
private def assertBoundConstrainedOptimizationParamsValid(
    numCoefficientSets: Int,
    numFeatures: Int): Unit = {
  if (isSet(lowerBoundsOnCoefficients)) {
    require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
      $(lowerBoundsOnCoefficients).numCols == numFeatures,
      "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " +
        "for binomial regression, or (number of classes, number of features) for multinomial " +
        "regression, but found: " +
        s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).")
  }
  if (isSet(upperBoundsOnCoefficients)) {
    require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
      $(upperBoundsOnCoefficients).numCols == numFeatures,
      "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " +
        "for binomial regression, or (number of classes, number of features) for multinomial " +
        "regression, but found: " +
        s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).")
  }
  if (isSet(lowerBoundsOnIntercepts)) {
    require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
      "lowerBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " +
      s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.")
  }
  if (isSet(upperBoundsOnIntercepts)) {
    require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
      "upperBoundsOnIntercepts must be equal to 1 for binomial regression, or the number of " +
      s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.")
  }
  // The element-wise ordering checks run after the shape checks above, so by the time two
  // bounds are zipped together each of them has already been validated against the same
  // expected dimensions.
  if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
    require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
      .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " +
      "less than or equal to upperBoundsOnCoefficients, but found: " +
      s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " +
      s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.")
  }
  if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
    require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
      .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " +
      "less than or equal to upperBoundsOnIntercepts, but found: " +
      s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " +
      s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.")
  }
}

private var optInitialModel: Option[LogisticRegressionModel] = None

private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
Expand Down Expand Up @@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") (
}
val numCoefficientSets = if (isMultinomial) numClasses else 1

// Check params interaction is valid if fitting under bound constrained optimization.
if (usingBoundConstrainedOptimization) {
assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures)
}

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
Expand All @@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") (

val isConstantLabel = histogram.count(_ != 0.0) == 1

if ($(fitIntercept) && isConstantLabel) {
if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
s"will be zeros. Training is not needed.")
val constantLabelIndex = Vectors.dense(histogram).argmax
Expand Down Expand Up @@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") (
$(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
$(aggregationDepth))

val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets

// Flattened lower/upper bounds, one entry per coefficient or intercept, or (null, null)
// when bound constrained optimization is not in use. Flat position `k` maps to coefficient
// set `k % numCoefficientSets` and feature `k / numCoefficientSets`; feature indices at or
// beyond `numFeatures` denote the intercept slot.
val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = {
  if (!usingBoundConstrainedOptimization) {
    (null, null)
  } else {
    // Anything the user left unset stays unbounded.
    val lower = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity)
    val upper = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity)
    val hasLowerCoefBounds = isSet(lowerBoundsOnCoefficients)
    val hasUpperCoefBounds = isSet(upperBoundsOnCoefficients)
    val hasLowerInterceptBounds = isSet(lowerBoundsOnIntercepts)
    val hasUpperInterceptBounds = isSet(upperBoundsOnIntercepts)

    for (flatIndex <- 0 until numCoeffsPlusIntercepts) {
      val setIdx = flatIndex % numCoefficientSets
      val featIdx = flatIndex / numCoefficientSets
      val isCoefficientSlot = featIdx < numFeatures
      if (isCoefficientSlot) {
        // Coefficient bounds are multiplied by the feature's std.
        // NOTE(review): presumably because the optimizer works on scaled coefficients;
        // confirm against the cost function.
        if (hasLowerCoefBounds) {
          lower(flatIndex) =
            $(lowerBoundsOnCoefficients)(setIdx, featIdx) * featuresStd(featIdx)
        }
        if (hasUpperCoefBounds) {
          upper(flatIndex) =
            $(upperBoundsOnCoefficients)(setIdx, featIdx) * featuresStd(featIdx)
        }
      } else {
        // Intercept bounds are used as given, without scaling.
        if (hasLowerInterceptBounds) {
          lower(flatIndex) = $(lowerBoundsOnIntercepts)(setIdx)
        }
        if (hasUpperInterceptBounds) {
          upper(flatIndex) = $(upperBoundsOnIntercepts)(setIdx)
        }
      }
    }
    (lower, upper)
  }
}

val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
if (lowerBounds != null && upperBounds != null) {
new BreezeLBFGSB(
BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol))
} else {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
}
} else {
val standardizationParam = $(standardization)
def regParamL1Fun = (index: Int) => {
Expand Down Expand Up @@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") (
math.log(histogram(1) / histogram(0)))
}

if (usingBoundConstrainedOptimization) {
  // Clamp every initial value into its [lowerBounds, upperBounds] interval so that the
  // starting point handed to the optimizer lies within the bounds.
  for (flatIndex <- 0 until numCoeffsPlusIntercepts) {
    val setIdx = flatIndex % numCoefficientSets
    val featIdx = flatIndex / numCoefficientSets
    val current = initialCoefWithInterceptMatrix(setIdx, featIdx)
    val lower = lowerBounds(flatIndex)
    val upper = upperBounds(flatIndex)
    if (current < lower) {
      initialCoefWithInterceptMatrix.update(setIdx, featIdx, lower)
    } else if (current > upper) {
      initialCoefWithInterceptMatrix.update(setIdx, featIdx, upper)
    }
  }
}

val states = optimizer.iterations(new CachedDiffFunction(costFun),
new BDV[Double](initialCoefWithInterceptMatrix.toArray))

Expand Down Expand Up @@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") (
if (isIntercept) interceptVec.toArray(classIndex) = value
}

if ($(regParam) == 0.0 && isMultinomial) {
if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) {
/*
When no regularization is applied, the multinomial coefficients lack identifiability
because we do not use a pivot class. We can add any constant value to the coefficients
Expand All @@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") (
}

// center the intercepts when using multinomial algorithm
if ($(fitIntercept) && isMultinomial) {
if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) {
val interceptArray = interceptVec.toArray
val interceptMean = interceptArray.sum / interceptArray.length
(0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }
Expand Down
Loading