Commit 14bb639
[SPARK-30938][ML][MLLIB] BinaryClassificationMetrics optimization
### What changes were proposed in this pull request?
1. Avoid `Iterator.grouped(size: Int)`, which needs to maintain an ArrayBuffer of `size` entries.
2. Keep the number of partitions in curve computation.

### Why are the changes needed?
1. `BinaryClassificationMetrics` tends to fail (OOM) when `grouping = count / numBins` is too large, because `Iterator.grouped(size: Int)` must maintain an ArrayBuffer with `size` entries; `BinaryClassificationMetrics` does not actually need such a big array.
2. It makes the sizes of partitions more even.

This PR makes the metrics computation more stable and a little faster.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes apache#27682 from zhengruifeng/grouped_opt.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: zhengruifeng <[email protected]>
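The heart of change (1) can be shown outside Spark. Below is a minimal sketch (plain Scala, hypothetical helper names, not code from this patch) contrasting the two grouping strategies: `Iterator.grouped(n)` materializes an `ArrayBuffer` of up to `n` elements for every chunk, while a streaming fold keeps a single running aggregate regardless of chunk size. Note that `Iterator.++` takes its argument by name, so the leftover-chunk tail is only built after `flatMap` has drained the input; the patch relies on the same laziness.

```scala
// Hypothetical illustration: sum an iterator in chunks of n.
def chunkSumsBuffered(it: Iterator[Long], n: Int): Iterator[Long] =
  it.grouped(n).map(_.sum) // buffers up to n elements per chunk; risks OOM for huge n

def chunkSumsStreaming(it: Iterator[Long], n: Long): Iterator[Long] = {
  var acc = 0L // one running aggregate: O(1) memory per partition
  var cnt = 0L
  it.flatMap { x =>
    acc += x
    cnt += 1
    if (cnt == n) { val r = acc; acc = 0L; cnt = 0L; Some(r) } else None
  } ++ {
    // evaluated lazily, once flatMap has exhausted the input
    if (cnt > 0) Iterator.single(acc) else Iterator.empty
  }
}
```

The streaming variant also accepts `n` as a `Long`, which is why the patch can drop the old `Int.MaxValue` cap on `grouping`.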
1 parent 1383bd4 commit 14bb639

File tree: 1 file changed (+49, −22)

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala

Lines changed: 49 additions & 22 deletions
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.evaluation.binary._
-import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}

 /**
```
```diff
@@ -101,10 +101,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
   @Since("1.0.0")
   def roc(): RDD[(Double, Double)] = {
     val rocCurve = createCurve(FalsePositiveRate, Recall)
-    val sc = confusions.context
-    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
-    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
-    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+    val numParts = rocCurve.getNumPartitions
+    rocCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (numParts == 1) {
+        require(pid == 0)
+        Iterator.single((0.0, 0.0)) ++ iter ++ Iterator.single((1.0, 1.0))
+      } else if (pid == 0) {
+        Iterator.single((0.0, 0.0)) ++ iter
+      } else if (pid == numParts - 1) {
+        iter ++ Iterator.single((1.0, 1.0))
+      } else {
+        iter
+      }
+    }
   }

   /**
```
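As a side note, here is a hedged standalone sketch (hypothetical object name, not part of the patch) of the endpoint-injection pattern used in `roc()` above: the (0, 0) and (1, 1) points are spliced into the first and last partitions via `mapPartitionsWithIndex`, instead of union-ing two extra single-element RDDs, so the result keeps the curve's partition count.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo of injecting endpoints without changing partitioning.
object EndpointInjectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext

    val curve = sc.parallelize(Seq((0.2, 0.4), (0.5, 0.7), (0.8, 0.9)), 2)
    val numParts = curve.getNumPartitions
    val withEndpoints = curve.mapPartitionsWithIndex { case (pid, iter) =>
      val head = if (pid == 0) Iterator.single((0.0, 0.0)) else Iterator.empty
      val tail = if (pid == numParts - 1) Iterator.single((1.0, 1.0)) else Iterator.empty
      head ++ iter ++ tail
    }
    println(withEndpoints.collect().mkString(", ")) // (0.0,0.0), ..., (1.0,1.0) in place
    println(withEndpoints.getNumPartitions)         // still 2, unlike the old UnionRDD approach
    spark.stop()
  }
}
```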
```diff
@@ -124,7 +133,13 @@ class BinaryClassificationMetrics @Since("3.0.0") (
   def pr(): RDD[(Double, Double)] = {
     val prCurve = createCurve(Recall, Precision)
     val (_, firstPrecision) = prCurve.first()
-    confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
+    prCurve.mapPartitionsWithIndex { case (pid, iter) =>
+      if (pid == 0) {
+        Iterator.single((0.0, firstPrecision)) ++ iter
+      } else {
+        iter
+      }
+    }
   }

   /**
```
```diff
@@ -182,28 +197,40 @@ class BinaryClassificationMetrics @Since("3.0.0") (
       val countsSize = counts.count()
       // Group the iterator into chunks of about countsSize / numBins points,
       // so that the resulting number of bins is about numBins
-      var grouping = countsSize / numBins
+      val grouping = countsSize / numBins
       if (grouping < 2) {
         // numBins was more than half of the size; no real point in down-sampling to bins
         logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
         counts
       } else {
-        if (grouping >= Int.MaxValue) {
-          logWarning(
-            s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
-          grouping = Int.MaxValue
+        counts.mapPartitions { iter =>
+          if (iter.hasNext) {
+            var score = Double.NaN
+            var agg = new BinaryLabelCounter()
+            var cnt = 0L
+            iter.flatMap { pair =>
+              score = pair._1
+              agg += pair._2
+              cnt += 1
+              if (cnt == grouping) {
+                // The score of the combined point will be just the last one's score,
+                // which is also the minimal in each chunk since all scores are already
+                // sorted in descending.
+                // The combined point will contain all counts in this chunk. Thus, calculated
+                // metrics (like precision, recall, etc.) on its score (or so-called threshold)
+                // are the same as those without sampling.
+                val ret = (score, agg)
+                agg = new BinaryLabelCounter()
+                cnt = 0
+                Some(ret)
+              } else None
+            } ++ {
+              if (cnt > 0) {
+                Iterator.single((score, agg))
+              } else Iterator.empty
+            }
+          } else Iterator.empty
         }
-        counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
-          // The score of the combined point will be just the last one's score, which is also
-          // the minimal in each chunk since all scores are already sorted in descending.
-          val lastScore = pairs.last._1
-          // The combined point will contain all counts in this chunk. Thus, calculated
-          // metrics (like precision, recall, etc.) on its score (or so-called threshold) are
-          // the same as those without sampling.
-          val agg = new BinaryLabelCounter()
-          pairs.foreach(pair => agg += pair._2)
-          (lastScore, agg)
-        })
       }
     }

```
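Nothing user-facing changes: downsampling is still driven by the existing `numBins` constructor parameter. A hedged usage sketch (toy data, assuming a live `SparkContext` named `sc`):

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// (score, label) pairs; with 4 distinct scores and numBins = 2,
// grouping = 4 / 2 = 2, so thresholds are aggregated in chunks of 2.
val scoreAndLabels = sc.parallelize(
  Seq((0.9, 1.0), (0.6, 0.0), (0.4, 1.0), (0.1, 0.0)), numSlices = 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels, numBins = 2)

metrics.roc().collect() // (0,0) and (1,1) endpoints injected in-place, per the patch
metrics.pr().collect()  // (0.0, firstPrecision) prepended to the first partition
println(metrics.areaUnderROC())
```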