-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-20047][ML] Constrained Logistic Regression #17715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
aea265c
405dffb
37e0ada
2cab4e5
0e866e9
92dfa15
aa7242c
e3ea117
1091fb1
e708e0d
4d51663
43192a4
96fcec4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ import java.util.Locale | |
| import scala.collection.mutable | ||
|
|
||
| import breeze.linalg.{DenseVector => BDV} | ||
| import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} | ||
| import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} | ||
| import org.apache.hadoop.fs.Path | ||
|
|
||
| import org.apache.spark.SparkException | ||
|
|
@@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * The lower bounds on coefficients if fitting under bound constrained optimization. | ||
| * The bound matrix must be compatible with the shape (1, number of features) for binomial | ||
| * regression, or (number of classes, number of features) for multinomial regression. | ||
| * Otherwise, it throws exception. | ||
| * | ||
| * @group param | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd recommend that bound-constrained optimization be put under expertParams. What do you think?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree to put them under expertParams. Thanks. |
||
| */ | ||
| @Since("2.2.0") | ||
| val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", | ||
| "The lower bounds on coefficients if fitting under bound constrained optimization.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") | ||
| def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) | ||
|
|
||
| /** | ||
| * The upper bounds on coefficients if fitting under bound constrained optimization. | ||
| * The bound matrix must be compatible with the shape (1, number of features) for binomial | ||
| * regression, or (number of classes, number of features) for multinomial regression. | ||
| * Otherwise, it throws exception. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.2.0") | ||
| val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", | ||
| "The upper bounds on coefficients if fitting under bound constrained optimization.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") | ||
| def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) | ||
|
|
||
| /** | ||
| * The lower bounds on intercepts if fitting under bound constrained optimization. | ||
| * The bounds vector size must be equal with 1 for binomial regression, or the number | ||
| * of classes for multinomial regression. Otherwise, it throws exception. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.2.0") | ||
| val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", | ||
| "The lower bounds on intercepts if fitting under bound constrained optimization.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") | ||
| def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) | ||
|
|
||
| /** | ||
| * The upper bounds on intercepts if fitting under bound constrained optimization. | ||
| * The bound vector size must be equal with 1 for binomial regression, or the number | ||
| * of classes for multinomial regression. Otherwise, it throws exception. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.2.0") | ||
| val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", | ||
| "The upper bounds on intercepts if fitting under bound constrained optimization.") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") | ||
| def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) | ||
|
|
||
| protected def usingBoundConstrainedOptimization: Boolean = { | ||
| isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) || | ||
| isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts) | ||
| } | ||
|
|
||
| override protected def validateAndTransformSchema( | ||
| schema: StructType, | ||
| fitting: Boolean, | ||
| featuresDataType: DataType): StructType = { | ||
| checkThresholdConsistency() | ||
| if (usingBoundConstrainedOptimization) { | ||
| require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " + | ||
| s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.") | ||
| } | ||
| if (!$(fitIntercept)) { | ||
| require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), | ||
| "Pls don't set bounds on intercepts if fitting without intercept.") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Pls don't" --> "Please do not" |
||
| } | ||
| super.validateAndTransformSchema(schema, fitting, featuresDataType) | ||
| } | ||
| } | ||
|
|
@@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") ( | |
| * For alpha in (0,1), the penalty is a combination of L1 and L2. | ||
| * Default is 0.0 which is an L2 penalty. | ||
| * | ||
| * Note: Fitting under bound constrained optimization only supports L2 regularization, | ||
| * so throws exception if this param is non-zero value. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("1.4.0") | ||
|
|
@@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") ( | |
| def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) | ||
| setDefault(aggregationDepth -> 2) | ||
|
|
||
| /** | ||
| * Set the lower bounds on coefficients if fitting under bound constrained optimization. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.2.0") | ||
| def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) | ||
|
|
||
| /** | ||
| * Set the upper bounds on coefficients if fitting under bound constrained optimization. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.2.0") | ||
| def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) | ||
|
|
||
| /** | ||
| * Set the lower bounds on intercepts if fitting under bound constrained optimization. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.2.0") | ||
| def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) | ||
|
|
||
| /** | ||
| * Set the upper bounds on intercepts if fitting under bound constrained optimization. | ||
| * | ||
| * @group setParam | ||
| */ | ||
| @Since("2.2.0") | ||
| def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) | ||
|
|
||
| private def assertBoundConstrainedOptimizationParamsValid( | ||
| numCoefficientSets: Int, | ||
| numFeatures: Int): Unit = { | ||
| if (isSet(lowerBoundsOnCoefficients)) { | ||
| require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These require() statements should have error messages so users know what went wrong. |
||
| $(lowerBoundsOnCoefficients).numCols == numFeatures) | ||
| } | ||
| if (isSet(upperBoundsOnCoefficients)) { | ||
| require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && | ||
| $(upperBoundsOnCoefficients).numCols == numFeatures) | ||
| } | ||
| if (isSet(lowerBoundsOnIntercepts)) { | ||
| require($(lowerBoundsOnIntercepts).size == numCoefficientSets) | ||
| } | ||
| if (isSet(upperBoundsOnIntercepts)) { | ||
| require($(upperBoundsOnIntercepts).size == numCoefficientSets) | ||
| } | ||
| if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { | ||
| require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) | ||
| .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " + | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. always => always be |
||
| "less than or equal to upperBoundsOnCoefficients, but found: " + | ||
| s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + | ||
| s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") | ||
| } | ||
| if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { | ||
| require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) | ||
| .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " + | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
| "less than or equal to upperBoundsOnIntercepts, but found: " + | ||
| s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + | ||
| s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") | ||
| } | ||
| } | ||
|
|
||
| private var optInitialModel: Option[LogisticRegressionModel] = None | ||
|
|
||
| private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { | ||
|
|
@@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") ( | |
| } | ||
| val numCoefficientSets = if (isMultinomial) numClasses else 1 | ||
|
|
||
| // Check params interaction is valid if fitting under bound constrained optimization. | ||
| if (usingBoundConstrainedOptimization) { | ||
| assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures) | ||
| } | ||
|
|
||
| if (isDefined(thresholds)) { | ||
| require($(thresholds).length == numClasses, this.getClass.getSimpleName + | ||
| ".train() called with non-matching numClasses and thresholds.length." + | ||
|
|
@@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") ( | |
|
|
||
| val isConstantLabel = histogram.count(_ != 0.0) == 1 | ||
|
|
||
| if ($(fitIntercept) && isConstantLabel) { | ||
| if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { | ||
| logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + | ||
| s"will be zeros. Training is not needed.") | ||
| val constantLabelIndex = Vectors.dense(histogram).argmax | ||
|
|
@@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") ( | |
| $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial, | ||
| $(aggregationDepth)) | ||
|
|
||
| val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets | ||
|
|
||
| val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = { | ||
| if (usingBoundConstrainedOptimization) { | ||
| val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity) | ||
| val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity) | ||
| val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients) | ||
| val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients) | ||
| val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts) | ||
| val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts) | ||
|
|
||
| var i = 0 | ||
| while (i < numCoeffsPlusIntercepts) { | ||
| val coefficientSetIndex = i % numCoefficientSets | ||
| val featureIndex = i / numCoefficientSets | ||
| if (featureIndex < numFeatures) { | ||
| if (isSetLowerBoundsOnCoefficients) { | ||
| lowerBounds(i) = $(lowerBoundsOnCoefficients)( | ||
| coefficientSetIndex, featureIndex) * featuresStd(featureIndex) | ||
| } | ||
| if (isSetUpperBoundsOnCoefficients) { | ||
| upperBounds(i) = $(upperBoundsOnCoefficients)( | ||
| coefficientSetIndex, featureIndex) * featuresStd(featureIndex) | ||
| } | ||
| } else { | ||
| if (isSetLowerBoundsOnIntercepts) { | ||
| lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex) | ||
| } | ||
| if (isSetUpperBoundsOnIntercepts) { | ||
| upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex) | ||
| } | ||
| } | ||
| i += 1 | ||
| } | ||
| (lowerBounds, upperBounds) | ||
| } else { | ||
| (null, null) | ||
| } | ||
| } | ||
|
|
||
| val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { | ||
| new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) | ||
| if (lowerBounds != null && upperBounds != null) { | ||
| new BreezeLBFGSB( | ||
| BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol)) | ||
| } else { | ||
| new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) | ||
| } | ||
| } else { | ||
| val standardizationParam = $(standardization) | ||
| def regParamL1Fun = (index: Int) => { | ||
|
|
@@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") ( | |
| math.log(histogram(1) / histogram(0))) | ||
| } | ||
|
|
||
| if (usingBoundConstrainedOptimization) { | ||
| // Make sure all initial values locate in the corresponding bound. | ||
| var i = 0 | ||
| while (i < numCoeffsPlusIntercepts) { | ||
| val coefficientSetIndex = i % numCoefficientSets | ||
| val featureIndex = i / numCoefficientSets | ||
| if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) | ||
| { | ||
| initialCoefWithInterceptMatrix.update( | ||
| coefficientSetIndex, featureIndex, lowerBounds(i)) | ||
| } else if ( | ||
| initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) | ||
| { | ||
| initialCoefWithInterceptMatrix.update( | ||
| coefficientSetIndex, featureIndex, upperBounds(i)) | ||
| } | ||
| i += 1 | ||
| } | ||
| } | ||
|
|
||
| val states = optimizer.iterations(new CachedDiffFunction(costFun), | ||
| new BDV[Double](initialCoefWithInterceptMatrix.toArray)) | ||
|
|
||
|
|
@@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") ( | |
| if (isIntercept) interceptVec.toArray(classIndex) = value | ||
| } | ||
|
|
||
| if ($(regParam) == 0.0 && isMultinomial) { | ||
| if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) { | ||
| /* | ||
| When no regularization is applied, the multinomial coefficients lack identifiability | ||
| because we do not use a pivot class. We can add any constant value to the coefficients | ||
|
|
@@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") ( | |
| } | ||
|
|
||
| // center the intercepts when using multinomial algorithm | ||
| if ($(fitIntercept) && isMultinomial) { | ||
| if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) { | ||
| val interceptArray = interceptVec.toArray | ||
| val interceptMean = interceptArray.sum / interceptArray.length | ||
| (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should state the default value is none
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for the other new bound Params