@@ -22,7 +22,7 @@ import java.util.Locale
2222import scala .collection .mutable
2323
2424import breeze .linalg .{DenseVector => BDV }
25- import breeze .optimize .{CachedDiffFunction , DiffFunction , LBFGS => BreezeLBFGS , OWLQN => BreezeOWLQN }
25+ import breeze .optimize .{CachedDiffFunction , DiffFunction , LBFGS => BreezeLBFGS , LBFGSB => BreezeLBFGSB , OWLQN => BreezeOWLQN }
2626import org .apache .hadoop .fs .Path
2727
2828import org .apache .spark .SparkException
@@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
178178 }
179179 }
180180
181+ /**
182+ * The lower bounds on coefficients when fitting under bound constrained optimization.
183+ * The bound matrix must be compatible with the shape (1, number of features) for binomial
184+ * regression, or (number of classes, number of features) for multinomial regression.
185+ * Otherwise, an exception is thrown.
 * When this param is set, `usingBoundConstrainedOptimization` is true and
 * `validateAndTransformSchema` requires `elasticNetParam` to be 0.0.
186+ *
187+ * @group param
188+ */
189+ @ Since (" 2.2.0" )
190+ val lowerBoundsOnCoefficients : Param [Matrix ] = new Param (this , " lowerBoundsOnCoefficients" ,
191+ " The lower bounds on coefficients if fitting under bound constrained optimization." )
192+
193+ /** @group getParam */
194+ @ Since (" 2.2.0" )
195+ def getLowerBoundsOnCoefficients : Matrix = $(lowerBoundsOnCoefficients)
196+
197+ /**
198+ * The upper bounds on coefficients when fitting under bound constrained optimization.
199+ * The bound matrix must be compatible with the shape (1, number of features) for binomial
200+ * regression, or (number of classes, number of features) for multinomial regression.
201+ * Otherwise, an exception is thrown.
 * When this param is set, `usingBoundConstrainedOptimization` is true and
 * `validateAndTransformSchema` requires `elasticNetParam` to be 0.0.
202+ *
203+ * @group param
204+ */
205+ @ Since (" 2.2.0" )
206+ val upperBoundsOnCoefficients : Param [Matrix ] = new Param (this , " upperBoundsOnCoefficients" ,
207+ " The upper bounds on coefficients if fitting under bound constrained optimization." )
208+
209+ /** @group getParam */
210+ @ Since (" 2.2.0" )
211+ def getUpperBoundsOnCoefficients : Matrix = $(upperBoundsOnCoefficients)
212+
213+ /**
214+ * The lower bounds on intercepts when fitting under bound constrained optimization.
215+ * The bounds vector size must be equal to 1 for binomial regression, or the number
216+ * of classes for multinomial regression. Otherwise, an exception is thrown.
 * `validateAndTransformSchema` rejects this param being set when `fitIntercept` is false.
217+ *
218+ * @group param
219+ */
220+ @ Since (" 2.2.0" )
221+ val lowerBoundsOnIntercepts : Param [Vector ] = new Param (this , " lowerBoundsOnIntercepts" ,
222+ " The lower bounds on intercepts if fitting under bound constrained optimization." )
223+
224+ /** @group getParam */
225+ @ Since (" 2.2.0" )
226+ def getLowerBoundsOnIntercepts : Vector = $(lowerBoundsOnIntercepts)
227+
228+ /**
229+ * The upper bounds on intercepts when fitting under bound constrained optimization.
230+ * The bounds vector size must be equal to 1 for binomial regression, or the number
231+ * of classes for multinomial regression. Otherwise, an exception is thrown.
 * `validateAndTransformSchema` rejects this param being set when `fitIntercept` is false.
232+ *
233+ * @group param
234+ */
235+ @ Since (" 2.2.0" )
236+ val upperBoundsOnIntercepts : Param [Vector ] = new Param (this , " upperBoundsOnIntercepts" ,
237+ " The upper bounds on intercepts if fitting under bound constrained optimization." )
238+
239+ /** @group getParam */
240+ @ Since (" 2.2.0" )
241+ def getUpperBoundsOnIntercepts : Vector = $(upperBoundsOnIntercepts)
242+
/**
 * Whether fitting runs under bound constrained optimization, i.e. whether at least one of
 * the four bound params (lower/upper bounds on coefficients or intercepts) is set.
 */
protected def usingBoundConstrainedOptimization: Boolean = {
  // Checked in declaration order; `exists` stops at the first param that is set.
  val boundParams: Seq[Param[_]] = Seq(
    lowerBoundsOnCoefficients,
    upperBoundsOnCoefficients,
    lowerBoundsOnIntercepts,
    upperBoundsOnIntercepts)
  boundParams.exists(boundParam => isSet(boundParam))
}
247+
/**
 * Checks param consistency, then delegates schema validation to the parent implementation.
 *
 * Besides the threshold consistency check, two bound-related constraints are enforced:
 * bound constrained optimization requires `elasticNetParam` to be exactly 0.0, and bounds
 * on intercepts may not be set when `fitIntercept` is false. Both use `require`.
 */
181248 override protected def validateAndTransformSchema (
182249 schema : StructType ,
183250 fitting : Boolean ,
184251 featuresDataType : DataType ): StructType = {
185252 checkThresholdConsistency()
// Bound constrained fitting only supports L2 regularization.
253+ if (usingBoundConstrainedOptimization) {
254+ require($(elasticNetParam) == 0.0 , " Fitting under bound constrained optimization only " +
255+ s " supports L2 regularization, but got elasticNetParam = $getElasticNetParam. " )
256+ }
// Without an intercept term there is nothing for intercept bounds to constrain.
// NOTE(review): the message below uses the informal "Pls"; consider "Please".
257+ if (! $(fitIntercept)) {
258+ require(! isSet(lowerBoundsOnIntercepts) && ! isSet(upperBoundsOnIntercepts),
259+ " Pls don't set bounds on intercepts if fitting without intercept." )
260+ }
186261 super .validateAndTransformSchema(schema, fitting, featuresDataType)
187262 }
188263}
@@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") (
217292 * For alpha in (0,1), the penalty is a combination of L1 and L2.
218293 * Default is 0.0 which is an L2 penalty.
219294 *
295+ * Note: Fitting under bound constrained optimization only supports L2 regularization,
296+ * so an exception is thrown if this param is set to a non-zero value.
297+ *
220298 * @group setParam
221299 */
222300 @ Since (" 1.4.0" )
@@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") (
312390 def setAggregationDepth (value : Int ): this .type = set(aggregationDepth, value)
313391 setDefault(aggregationDepth -> 2 )
314392
393+ /**
394+ * Set the lower bounds on coefficients if fitting under bound constrained optimization.
 * `assertBoundConstrainedOptimizationParamsValid` requires the matrix to have
 * `numCoefficientSets` rows and `numFeatures` columns, and every entry to be less than or
 * equal to the matching entry of `upperBoundsOnCoefficients` when that param is also set.
395+ *
396+ * @group setParam
397+ */
398+ @ Since (" 2.2.0" )
399+ def setLowerBoundsOnCoefficients (value : Matrix ): this .type = set(lowerBoundsOnCoefficients, value)
400+
401+ /**
402+ * Set the upper bounds on coefficients if fitting under bound constrained optimization.
 * `assertBoundConstrainedOptimizationParamsValid` requires the matrix to have
 * `numCoefficientSets` rows and `numFeatures` columns, and every entry to be greater than or
 * equal to the matching entry of `lowerBoundsOnCoefficients` when that param is also set.
403+ *
404+ * @group setParam
405+ */
406+ @ Since (" 2.2.0" )
407+ def setUpperBoundsOnCoefficients (value : Matrix ): this .type = set(upperBoundsOnCoefficients, value)
408+
409+ /**
410+ * Set the lower bounds on intercepts if fitting under bound constrained optimization.
 * `assertBoundConstrainedOptimizationParamsValid` requires the vector size to equal
 * `numCoefficientSets`, and every entry to be less than or equal to the matching entry of
 * `upperBoundsOnIntercepts` when that param is also set.
411+ *
412+ * @group setParam
413+ */
414+ @ Since (" 2.2.0" )
415+ def setLowerBoundsOnIntercepts (value : Vector ): this .type = set(lowerBoundsOnIntercepts, value)
416+
417+ /**
418+ * Set the upper bounds on intercepts if fitting under bound constrained optimization.
 * `assertBoundConstrainedOptimizationParamsValid` requires the vector size to equal
 * `numCoefficientSets`, and every entry to be greater than or equal to the matching entry of
 * `lowerBoundsOnIntercepts` when that param is also set.
419+ *
420+ * @group setParam
421+ */
422+ @ Since (" 2.2.0" )
423+ def setUpperBoundsOnIntercepts (value : Vector ): this .type = set(upperBoundsOnIntercepts, value)
424+
/**
 * Validates the bound constrained optimization params against the problem dimensions.
 *
 * Only params that have been set are checked. Shape checks run before the element-wise
 * lower/upper comparisons, so those comparisons always zip arrays of equal length.
 *
 * @param numCoefficientSets expected number of rows of each coefficient bound matrix and
 *                           expected size of each intercept bound vector
 * @param numFeatures expected number of columns of each coefficient bound matrix
 * @throws IllegalArgumentException (raised by `require`) if a shape does not match, or if
 *                                  a lower bound exceeds its corresponding upper bound
 */
private def assertBoundConstrainedOptimizationParamsValid(
    numCoefficientSets: Int,
    numFeatures: Int): Unit = {
  if (isSet(lowerBoundsOnCoefficients)) {
    require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
      $(lowerBoundsOnCoefficients).numCols == numFeatures,
      "The shape of lowerBoundsOnCoefficients must be compatible with (1, number of " +
        "features) for binomial regression, or (number of classes, number of features) " +
        s"for multinomial regression; expected ($numCoefficientSets, $numFeatures), " +
        s"but found: (${$(lowerBoundsOnCoefficients).numRows}, " +
        s"${$(lowerBoundsOnCoefficients).numCols}).")
  }
  if (isSet(upperBoundsOnCoefficients)) {
    require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
      $(upperBoundsOnCoefficients).numCols == numFeatures,
      "The shape of upperBoundsOnCoefficients must be compatible with (1, number of " +
        "features) for binomial regression, or (number of classes, number of features) " +
        s"for multinomial regression; expected ($numCoefficientSets, $numFeatures), " +
        s"but found: (${$(upperBoundsOnCoefficients).numRows}, " +
        s"${$(upperBoundsOnCoefficients).numCols}).")
  }
  if (isSet(lowerBoundsOnIntercepts)) {
    require($(lowerBoundsOnIntercepts).size == numCoefficientSets,
      "The size of lowerBoundsOnIntercepts must be equal to 1 for binomial regression, or " +
        s"the number of classes for multinomial regression; expected $numCoefficientSets, " +
        s"but found: ${$(lowerBoundsOnIntercepts).size}.")
  }
  if (isSet(upperBoundsOnIntercepts)) {
    require($(upperBoundsOnIntercepts).size == numCoefficientSets,
      "The size of upperBoundsOnIntercepts must be equal to 1 for binomial regression, or " +
        s"the number of classes for multinomial regression; expected $numCoefficientSets, " +
        s"but found: ${$(upperBoundsOnIntercepts).size}.")
  }
  // Both matrices passed the shape check above, so their flattened arrays line up 1:1.
  if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
    require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
      .forall { case (lower, upper) => lower <= upper },
      "LowerBoundsOnCoefficients should always be " +
        "less than or equal to upperBoundsOnCoefficients, but found: " +
        s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " +
        s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.")
  }
  // Likewise both vectors were verified to have size numCoefficientSets.
  if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
    require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
      .forall { case (lower, upper) => lower <= upper },
      "LowerBoundsOnIntercepts should always be " +
        "less than or equal to upperBoundsOnIntercepts, but found: " +
        s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " +
        s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.")
  }
}
457+
315458 private var optInitialModel : Option [LogisticRegressionModel ] = None
316459
317460 private [spark] def setInitialModel (model : LogisticRegressionModel ): this .type = {
@@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") (
378521 }
379522 val numCoefficientSets = if (isMultinomial) numClasses else 1
380523
524+ // Check params interaction is valid if fitting under bound constrained optimization.
525+ if (usingBoundConstrainedOptimization) {
526+ assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures)
527+ }
528+
381529 if (isDefined(thresholds)) {
382530 require($(thresholds).length == numClasses, this .getClass.getSimpleName +
383531 " .train() called with non-matching numClasses and thresholds.length." +
@@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") (
397545
398546 val isConstantLabel = histogram.count(_ != 0.0 ) == 1
399547
400- if ($(fitIntercept) && isConstantLabel) {
548+ if ($(fitIntercept) && isConstantLabel && ! usingBoundConstrainedOptimization ) {
401549 logWarning(s " All labels are the same value and fitIntercept=true, so the coefficients " +
402550 s " will be zeros. Training is not needed. " )
403551 val constantLabelIndex = Vectors .dense(histogram).argmax
@@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") (
434582 $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
435583 $(aggregationDepth))
436584
585+ val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets
586+
587+ val (lowerBounds, upperBounds): (Array [Double ], Array [Double ]) = {
588+ if (usingBoundConstrainedOptimization) {
589+ val lowerBounds = Array .fill[Double ](numCoeffsPlusIntercepts)(Double .NegativeInfinity )
590+ val upperBounds = Array .fill[Double ](numCoeffsPlusIntercepts)(Double .PositiveInfinity )
591+ val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients)
592+ val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients)
593+ val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts)
594+ val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts)
595+
596+ var i = 0
597+ while (i < numCoeffsPlusIntercepts) {
598+ val coefficientSetIndex = i % numCoefficientSets
599+ val featureIndex = i / numCoefficientSets
600+ if (featureIndex < numFeatures) {
601+ if (isSetLowerBoundsOnCoefficients) {
602+ lowerBounds(i) = $(lowerBoundsOnCoefficients)(
603+ coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
604+ }
605+ if (isSetUpperBoundsOnCoefficients) {
606+ upperBounds(i) = $(upperBoundsOnCoefficients)(
607+ coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
608+ }
609+ } else {
610+ if (isSetLowerBoundsOnIntercepts) {
611+ lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex)
612+ }
613+ if (isSetUpperBoundsOnIntercepts) {
614+ upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex)
615+ }
616+ }
617+ i += 1
618+ }
619+ (lowerBounds, upperBounds)
620+ } else {
621+ (null , null )
622+ }
623+ }
624+
437625 val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0 ) {
438- new BreezeLBFGS [BDV [Double ]]($(maxIter), 10 , $(tol))
626+ if (lowerBounds != null && upperBounds != null ) {
627+ new BreezeLBFGSB (
628+ BDV [Double ](lowerBounds), BDV [Double ](upperBounds), $(maxIter), 10 , $(tol))
629+ } else {
630+ new BreezeLBFGS [BDV [Double ]]($(maxIter), 10 , $(tol))
631+ }
439632 } else {
440633 val standardizationParam = $(standardization)
441634 def regParamL1Fun = (index : Int ) => {
@@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") (
546739 math.log(histogram(1 ) / histogram(0 )))
547740 }
548741
742+ if (usingBoundConstrainedOptimization) {
743+ // Make sure all initial values locate in the corresponding bound.
744+ var i = 0
745+ while (i < numCoeffsPlusIntercepts) {
746+ val coefficientSetIndex = i % numCoefficientSets
747+ val featureIndex = i / numCoefficientSets
748+ if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i))
749+ {
750+ initialCoefWithInterceptMatrix.update(
751+ coefficientSetIndex, featureIndex, lowerBounds(i))
752+ } else if (
753+ initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i))
754+ {
755+ initialCoefWithInterceptMatrix.update(
756+ coefficientSetIndex, featureIndex, upperBounds(i))
757+ }
758+ i += 1
759+ }
760+ }
761+
549762 val states = optimizer.iterations(new CachedDiffFunction (costFun),
550763 new BDV [Double ](initialCoefWithInterceptMatrix.toArray))
551764
@@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") (
599812 if (isIntercept) interceptVec.toArray(classIndex) = value
600813 }
601814
602- if ($(regParam) == 0.0 && isMultinomial) {
815+ if ($(regParam) == 0.0 && isMultinomial && ! usingBoundConstrainedOptimization ) {
603816 /*
604817 When no regularization is applied, the multinomial coefficients lack identifiability
605818 because we do not use a pivot class. We can add any constant value to the coefficients
@@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") (
620833 }
621834
622835 // center the intercepts when using multinomial algorithm
623- if ($(fitIntercept) && isMultinomial) {
836+ if ($(fitIntercept) && isMultinomial && ! usingBoundConstrainedOptimization ) {
624837 val interceptArray = interceptVec.toArray
625838 val interceptMean = interceptArray.sum / interceptArray.length
626839 (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }
0 commit comments