Skip to content

Commit c5dceb8

Browse files
committed
[SPARK-20047][FOLLOWUP][ML] Constrained Logistic Regression follow up
## What changes were proposed in this pull request? Address some minor comments for #17715: * Put bound-constrained optimization params under expertParams. * Update some docs. ## How was this patch tested? Existing tests. Author: Yanbo Liang <[email protected]> Closes #17829 from yanboliang/spark-20047-followup.
1 parent 57b6470 commit c5dceb8

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
183183
* The bound matrix must be compatible with the shape (1, number of features) for binomial
184184
* regression, or (number of classes, number of features) for multinomial regression.
185185
* Otherwise, it throws exception.
186+
* Default is none.
186187
*
187-
* @group param
188+
* @group expertParam
188189
*/
189190
@Since("2.2.0")
190191
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
191192
"The lower bounds on coefficients if fitting under bound constrained optimization.")
192193

193-
/** @group getParam */
194+
/** @group expertGetParam */
194195
@Since("2.2.0")
195196
def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients)
196197

@@ -199,44 +200,47 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
199200
* The bound matrix must be compatible with the shape (1, number of features) for binomial
200201
* regression, or (number of classes, number of features) for multinomial regression.
201202
* Otherwise, it throws exception.
203+
* Default is none.
202204
*
203-
* @group param
205+
* @group expertParam
204206
*/
205207
@Since("2.2.0")
206208
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
207209
"The upper bounds on coefficients if fitting under bound constrained optimization.")
208210

209-
/** @group getParam */
211+
/** @group expertGetParam */
210212
@Since("2.2.0")
211213
def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients)
212214

213215
/**
214216
* The lower bounds on intercepts if fitting under bound constrained optimization.
215217
* The bounds vector size must be equal with 1 for binomial regression, or the number
216218
* of classes for multinomial regression. Otherwise, it throws exception.
219+
* Default is none.
217220
*
218-
* @group param
221+
* @group expertParam
219222
*/
220223
@Since("2.2.0")
221224
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
222225
"The lower bounds on intercepts if fitting under bound constrained optimization.")
223226

224-
/** @group getParam */
227+
/** @group expertGetParam */
225228
@Since("2.2.0")
226229
def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts)
227230

228231
/**
229232
* The upper bounds on intercepts if fitting under bound constrained optimization.
230233
* The bound vector size must be equal with 1 for binomial regression, or the number
231234
* of classes for multinomial regression. Otherwise, it throws exception.
235+
* Default is none.
232236
*
233-
* @group param
237+
* @group expertParam
234238
*/
235239
@Since("2.2.0")
236240
val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
237241
"The upper bounds on intercepts if fitting under bound constrained optimization.")
238242

239-
/** @group getParam */
243+
/** @group expertGetParam */
240244
@Since("2.2.0")
241245
def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts)
242246

@@ -256,7 +260,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
256260
}
257261
if (!$(fitIntercept)) {
258262
require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
259-
"Pls don't set bounds on intercepts if fitting without intercept.")
263+
"Please don't set bounds on intercepts if fitting without intercept.")
260264
}
261265
super.validateAndTransformSchema(schema, fitting, featuresDataType)
262266
}
@@ -393,31 +397,31 @@ class LogisticRegression @Since("1.2.0") (
393397
/**
394398
* Set the lower bounds on coefficients if fitting under bound constrained optimization.
395399
*
396-
* @group setParam
400+
* @group expertSetParam
397401
*/
398402
@Since("2.2.0")
399403
def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value)
400404

401405
/**
402406
* Set the upper bounds on coefficients if fitting under bound constrained optimization.
403407
*
404-
* @group setParam
408+
* @group expertSetParam
405409
*/
406410
@Since("2.2.0")
407411
def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value)
408412

409413
/**
410414
* Set the lower bounds on intercepts if fitting under bound constrained optimization.
411415
*
412-
* @group setParam
416+
* @group expertSetParam
413417
*/
414418
@Since("2.2.0")
415419
def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value)
416420

417421
/**
418422
* Set the upper bounds on intercepts if fitting under bound constrained optimization.
419423
*
420-
* @group setParam
424+
* @group expertSetParam
421425
*/
422426
@Since("2.2.0")
423427
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
@@ -427,28 +431,40 @@ class LogisticRegression @Since("1.2.0") (
427431
numFeatures: Int): Unit = {
428432
if (isSet(lowerBoundsOnCoefficients)) {
429433
require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
430-
$(lowerBoundsOnCoefficients).numCols == numFeatures)
434+
$(lowerBoundsOnCoefficients).numCols == numFeatures,
435+
"The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " +
436+
"for binomial regression, or (number of classes, number of features) for multinomial " +
437+
"regression, but found: " +
438+
s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).")
431439
}
432440
if (isSet(upperBoundsOnCoefficients)) {
433441
require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
434-
$(upperBoundsOnCoefficients).numCols == numFeatures)
442+
$(upperBoundsOnCoefficients).numCols == numFeatures,
443+
"The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " +
444+
"for binomial regression, or (number of classes, number of features) for multinomial " +
445+
"regression, but found: " +
446+
s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).")
435447
}
436448
if (isSet(lowerBoundsOnIntercepts)) {
437-
require($(lowerBoundsOnIntercepts).size == numCoefficientSets)
449+
require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
450+
"lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " +
451+
s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.")
438452
}
439453
if (isSet(upperBoundsOnIntercepts)) {
440-
require($(upperBoundsOnIntercepts).size == numCoefficientSets)
454+
require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
455+
"upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " +
456+
s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.")
441457
}
442458
if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
443459
require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
444-
.forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " +
460+
.forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " +
445461
"less than or equal to upperBoundsOnCoefficients, but found: " +
446462
s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " +
447463
s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.")
448464
}
449465
if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
450466
require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
451-
.forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " +
467+
.forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " +
452468
"less than or equal to upperBoundsOnIntercepts, but found: " +
453469
s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " +
454470
s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.")

0 commit comments

Comments
 (0)