Skip to content

Commit 278a193

Browse files
committed
Properly fix and regenerate generated code
1 parent 80934ed commit 278a193

File tree

4 files changed

+5
-18
lines changed

4 files changed

+5
-18
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ abstract class ProbabilisticClassificationModel[
206206
var i = 0
207207
val probabilitySize = probability.size
208208
while (i < probabilitySize) {
209+
// thresholds are all > 0
209210
val scaled = probability(i) / thresholds(i)
210211
if (scaled > max) {
211212
max = scaled

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ private[shared] object SharedParamsCodeGen {
5353
" Array must have length equal to the number of classes, with values > 0." +
5454
" The class with largest value p/t is predicted, where p is the original probability" +
5555
" of that class and t is the class's threshold",
56-
isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
56+
isValid = "(t: Array[Double]) => t.forall(_ > 0)", finalMethods = false),
5757
ParamDesc[String]("inputCol", "input column name"),
5858
ParamDesc[Array[String]]("inputCols", "input column names"),
5959
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -176,24 +176,10 @@ private[ml] trait HasThreshold extends Params {
176176
private[ml] trait HasThresholds extends Params {
177177

178178
/**
179-
* Defines "thresholds" for each class. These do not, actually, define the minimum
180-
* probability per class for that class to be chosen. They act like like 'cutoff' values in
181-
* R's [[https://cran.r-project.org/web/packages/randomForest/randomForest.pdf randomForest]]
182-
* package, which ironically are also not cutoffs. That is, a class may be selected even if its
183-
* probability does not exceed the threshold.
184-
*
185-
* Array must have length equal to the number of classes, with values > 0.
186-
* The class with largest value p/t is predicted, where p is the original probability of that
187-
* class and t is the class's threshold.
188-
*
179+
* Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
189180
* @group param
190181
*/
191-
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds",
192-
"Thresholds in multi-class classification to adjust the probability of predicting each class. " +
193-
"Array must have length equal to the number of classes, with values > 0. " +
194-
"The class with largest value p/t is predicted, where p is the original probability of that " +
195-
"class and t is the class's threshold",
196-
(t: Array[Double]) => t.forall(_ > 0.0))
182+
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ > 0))
197183

198184
/** @group getParam */
199185
def getThresholds: Array[Double] = $(thresholds)

python/pyspark/ml/param/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ class HasThresholds(Params):
472472
Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
473473
"""
474474

475-
thresholds = Param(Params._dummy(), "thresholds","Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
475+
thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
476476

477477
def __init__(self):
478478
super(HasThresholds, self).__init__()

0 commit comments

Comments
 (0)