@@ -988,22 +988,22 @@ class MultinomialLogisticRegressionSuite
988988 val basePredictions = model.transform(dataset).select(" prediction" ).collect()
989989
990990 // should predict all zeros
991- model.setThresholds(Array (1 , 1000 , 1000 ))
991+ model.setThresholds(Array (0 , 1 , 1 ))
992992 val zeroPredictions = model.transform(dataset).select(" prediction" ).collect()
993993 assert(zeroPredictions.forall(_.getDouble(0 ) === 0.0 ))
994994
995995 // should predict all ones
996- model.setThresholds(Array (1000 , 1 , 1000 ))
996+ model.setThresholds(Array (1 , 0 , 1 ))
997997 val onePredictions = model.transform(dataset).select(" prediction" ).collect()
998998 assert(onePredictions.forall(_.getDouble(0 ) === 1.0 ))
999999
10001000 // should predict all twos
1001- model.setThresholds(Array (1000 , 1000 , 1 ))
1001+ model.setThresholds(Array (1 , 1 , 0 ))
10021002 val twoPredictions = model.transform(dataset).select(" prediction" ).collect()
10031003 assert(twoPredictions.forall(_.getDouble(0 ) === 2.0 ))
10041004
10051005 // constant threshold scaling is the same as no thresholds
1006- model.setThresholds(Array (1000 , 1000 , 1000 ))
1006+ model.setThresholds(Array (0.1 , 0.1 , 0.1 ))
10071007 val scaledPredictions = model.transform(dataset).select(" prediction" ).collect()
10081008 assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
10091009 scaled.getDouble(0 ) === base.getDouble(0 )
0 commit comments