Skip to content

Commit 2fa331e

Browse files
committed
Update MultinomialLogisticRegression test output to match new threshold meaning
1 parent 08dbe43 commit 2fa331e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -988,22 +988,22 @@ class MultinomialLogisticRegressionSuite
988988
val basePredictions = model.transform(dataset).select("prediction").collect()
989989

990990
// should predict all zeros
991-
model.setThresholds(Array(1, 1000, 1000))
991+
model.setThresholds(Array(0, 1, 1))
992992
val zeroPredictions = model.transform(dataset).select("prediction").collect()
993993
assert(zeroPredictions.forall(_.getDouble(0) === 0.0))
994994

995995
// should predict all ones
996-
model.setThresholds(Array(1000, 1, 1000))
996+
model.setThresholds(Array(1, 0, 1))
997997
val onePredictions = model.transform(dataset).select("prediction").collect()
998998
assert(onePredictions.forall(_.getDouble(0) === 1.0))
999999

10001000
// should predict all twos
1001-
model.setThresholds(Array(1000, 1000, 1))
1001+
model.setThresholds(Array(1, 1, 0))
10021002
val twoPredictions = model.transform(dataset).select("prediction").collect()
10031003
assert(twoPredictions.forall(_.getDouble(0) === 2.0))
10041004

10051005
// constant threshold scaling is the same as no thresholds
1006-
model.setThresholds(Array(1000, 1000, 1000))
1006+
model.setThresholds(Array(0.1, 0.1, 0.1))
10071007
val scaledPredictions = model.transform(dataset).select("prediction").collect()
10081008
assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
10091009
scaled.getDouble(0) === base.getDouble(0)

0 commit comments

Comments
 (0)