Commit 248916f
[SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0
## What changes were proposed in this pull request?

Match ProbabilisticClassifier.thresholds requirements to the R randomForest cutoff: all values must be > 0, excepting that at most one value may be 0.

## How was this patch tested?

Jenkins tests plus new test cases.

Author: Sean Owen <[email protected]>

Closes #15149 from srowen/SPARK-17057.
1 parent f3fe554 commit 248916f
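For illustration only (not part of the commit), a minimal sketch of the user-facing effect of the tightened constraint; the estimator and values below are hypothetical, but the exception type matches the new test cases:

```scala
import org.apache.spark.ml.classification.LogisticRegression

// Hypothetical 3-class model with per-class thresholds.
val lr = new LogisticRegression()

// Accepted: all values >= 0, and at most one of them is 0.
lr.setThresholds(Array(0.0, 0.5, 0.5))

// Rejected after this change: two zero thresholds would make the
// prediction ambiguous, so validation throws IllegalArgumentException.
lr.setThresholds(Array(0.0, 0.0, 0.5))
```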

7 files changed (+52 additions, -29 deletions)


mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 3 additions & 2 deletions
@@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas

  /**
   * Set thresholds in multiclass (or binary) classification to adjust the probability of
-  * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+  * predicting each class. Array must have length equal to the number of classes, with values > 0,
+  * excepting that at most one value may be 0.
   * The class with largest value p/t is predicted, where p is the original probability of that
-  * class and t is the class' threshold.
+  * class and t is the class's threshold.
   *
   * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
   * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
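The documented rule is easiest to see with concrete numbers. A standalone sketch (hypothetical probabilities and thresholds, not Spark API code) that applies the p/t rule exactly as the Scaladoc above describes:

```scala
// Hypothetical two-class example of the p/t rule.
val p = Array(0.45, 0.55)   // predicted class probabilities
val t = Array(0.50, 0.80)   // user-supplied per-class thresholds

// Scale each probability by its class threshold and pick the largest.
val scaled = p.zip(t).map { case (pi, ti) => pi / ti }   // Array(0.9, 0.6875)
val prediction = scaled.indexOf(scaled.max)              // 0, even though p(1) > p(0)
```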

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 9 additions & 11 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[
     if (!isDefined(thresholds)) {
       probability.argmax
     } else {
-      val thresholds: Array[Double] = getThresholds
-      val probabilities = probability.toArray
+      val thresholds = getThresholds
       var argMax = 0
       var max = Double.NegativeInfinity
       var i = 0
       val probabilitySize = probability.size
       while (i < probabilitySize) {
-        if (thresholds(i) == 0.0) {
-          max = Double.PositiveInfinity
+        // Thresholds are all > 0, excepting that at most one may be 0.
+        // The single class whose threshold is 0, if any, will always be predicted
+        // ('scaled' = +Infinity). However in the case that this class also has
+        // 0 probability, the class will not be selected ('scaled' is NaN).
+        val scaled = probability(i) / thresholds(i)
+        if (scaled > max) {
+          max = scaled
           argMax = i
-        } else {
-          val scaled = probabilities(i) / thresholds(i)
-          if (scaled > max) {
-            max = scaled
-            argMax = i
-          }
         }
         i += 1
       }
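For illustration (not part of the commit), a standalone sketch of the simplified selection loop above, using plain arrays in place of ml.linalg vectors and hypothetical values. It shows why the special case for zero thresholds could be dropped: division already handles it, including the 0/0 = NaN corner case noted in the comments.

```scala
// Mirrors the new while-loop: NaN (from 0.0 / 0.0) never satisfies `scaled > max`,
// so a zero-probability class with a zero threshold is never selected.
def argMaxScaled(probability: Array[Double], thresholds: Array[Double]): Int = {
  var argMax = 0
  var max = Double.NegativeInfinity
  var i = 0
  while (i < probability.length) {
    val scaled = probability(i) / thresholds(i)
    if (scaled > max) {
      max = scaled
      argMax = i
    }
    i += 1
  }
  argMax
}

argMaxScaled(Array(0.3, 0.7), Array(0.0, 0.5))  // 0: 0.3 / 0.0 = +Infinity always wins
argMaxScaled(Array(0.0, 0.7), Array(0.0, 0.5))  // 1: 0.0 / 0.0 = NaN, so class 0 is skipped
```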

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 5 additions & 3 deletions
@@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen {
        isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
      ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
        " to adjust the probability of predicting each class." +
-       " Array must have length equal to the number of classes, with values >= 0." +
+       " Array must have length equal to the number of classes, with values > 0" +
+       " excepting that at most one value may be 0." +
        " The class with largest value p/t is predicted, where p is the original probability" +
-       " of that class and t is the class' threshold",
-       isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
+       " of that class and t is the class's threshold",
+       isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1",
+       finalMethods = false),
      ParamDesc[String]("inputCol", "input column name"),
      ParamDesc[Array[String]]("inputCols", "input column names"),
      ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 2 additions & 2 deletions
@@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params {
 private[ml] trait HasThresholds extends Params {

   /**
-   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
    * @group param
    */
-  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1)

   /** @group getParam */
   def getThresholds: Array[Double] = $(thresholds)

mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala

Lines changed: 28 additions & 7 deletions
@@ -36,25 +36,46 @@ final class TestProbabilisticClassificationModel(
     rawPrediction
   }

-  def friendlyPredict(input: Vector): Double = {
-    predict(input)
+  def friendlyPredict(values: Double*): Double = {
+    predict(Vectors.dense(values.toArray))
   }
 }


 class ProbabilisticClassifierSuite extends SparkFunSuite {

   test("test thresholding") {
-    val thresholds = Array(0.5, 0.2)
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-      .setThresholds(thresholds)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+      .setThresholds(Array(0.5, 0.2))
+    assert(testModel.friendlyPredict(1.0, 1.0) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 0.2) === 0.0)
   }

   test("test thresholding not required") {
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 2.0) === 1.0)
+  }
+
+  test("test tiebreak") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.4, 0.4))
+    assert(testModel.friendlyPredict(0.6, 0.6) === 0.0)
+  }
+
+  test("test one zero threshold") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.0, 0.1))
+    assert(testModel.friendlyPredict(1.0, 10.0) === 0.0)
+    assert(testModel.friendlyPredict(0.0, 10.0) === 1.0)
+  }
+
+  test("bad thresholds") {
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0))
+    }
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
+    }
   }
 }

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 3 additions & 2 deletions
@@ -139,8 +139,9 @@ def get$Name(self):
          "model.", "True", "TypeConverters.toBoolean"),
         ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
          "predicting each class. Array must have length equal to the number of classes, with " +
-         "values >= 0. The class with largest value p/t is predicted, where p is the original " +
-         "probability of that class and t is the class' threshold.", None,
+         "values > 0, excepting that at most one value may be 0. " +
+         "The class with largest value p/t is predicted, where p is the original " +
+         "probability of that class and t is the class's threshold.", None,
          "TypeConverters.toListFloat"),
         ("weightCol", "weight column name. If this is not set or empty, we treat " +
          "all instance weights as 1.0.", None, "TypeConverters.toString"),

python/pyspark/ml/param/shared.py

Lines changed: 2 additions & 2 deletions
@@ -469,10 +469,10 @@ def getStandardization(self):

 class HasThresholds(Params):
     """
-    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
     """

-    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat)
+    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)

     def __init__(self):
         super(HasThresholds, self).__init__()
