Skip to content

Commit 354f8f1

Browse files
wangmiao1981srowen
authored andcommitted
[SPARK-15096][ML] LogisticRegression MultiClassSummarizer numClasses can fail if no valid labels are found
## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) Throw better exception when numClasses is empty and empty.max is thrown. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Add a new unit test, which calls histogram with empty numClasses. Author: [email protected] <[email protected]> Closes #12969 from wangmiao1981/logisticR.
1 parent 0f1f31d commit 354f8f1

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
745745
def countInvalid: Long = totalInvalidCnt
746746

747747
/** @return The number of distinct labels in the input dataset. */
748-
def numClasses: Int = distinctMap.keySet.max + 1
748+
def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1
749749

750750
/** @return The weightSum of each label in the input dataset. */
751751
def histogram: Array[Double] = {

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ class LogisticRegressionSuite
256256
assert(summarizer4.countInvalid === 2)
257257
assert(summarizer4.numClasses === 4)
258258

259+
val summarizer5 = new MultiClassSummarizer
260+
assert(summarizer5.histogram.isEmpty)
261+
assert(summarizer5.numClasses === 0)
262+
259263
// small map merges large one
260264
val summarizerA = summarizer1.merge(summarizer2)
261265
assert(summarizerA.hashCode() === summarizer2.hashCode())

0 commit comments

Comments
 (0)