Skip to content

Commit 00bfec1

Browse files
committed
updated based on new comments
1 parent a571adc commit 00bfec1

File tree

6 files changed

+59
-52
lines changed

6 files changed

+59
-52
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ import org.apache.spark.annotation.Since
2121
import org.apache.spark.internal.Logging
2222
import org.apache.spark.mllib.evaluation.binary._
2323
import org.apache.spark.rdd.{RDD, UnionRDD}
24-
import org.apache.spark.sql.DataFrame
24+
import org.apache.spark.sql.{DataFrame, Row}
2525

2626
/**
2727
* Evaluator for binary classification.
2828
*
29-
* @param scoreAndLabelsWithOptWeight an RDD of (score, label) or (score, label, weight) tuples.
29+
* @param scoreAndLabels an RDD of (score, label) or (score, label, weight) tuples.
3030
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
3131
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
3232
* This is useful because the curve contains a point for each distinct score
@@ -42,10 +42,10 @@ import org.apache.spark.sql.DataFrame
4242
*/
4343
@Since("1.0.0")
4444
class BinaryClassificationMetrics @Since("3.0.0") (
45-
@Since("1.3.0") val scoreAndLabelsWithOptWeight: RDD[_ <: Product],
45+
@Since("1.3.0") val scoreAndLabels: RDD[_ <: Product],
4646
@Since("1.3.0") val numBins: Int = 1000)
4747
extends Logging {
48-
val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabelsWithOptWeight.map {
48+
val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabels.map {
4949
case (prediction: Double, label: Double, weight: Double) =>
5050
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
5151
(prediction, (label, weight))
@@ -63,21 +63,19 @@ class BinaryClassificationMetrics @Since("3.0.0") (
6363
@Since("1.0.0")
6464
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
6565

66-
/**
67-
* Retrieves the score and labels (for binary compatibility).
68-
* @return The score and labels.
69-
*/
70-
@Since("1.3.0")
71-
def scoreAndLabels: RDD[(Double, Double)] = {
72-
scoreLabelsWeight.map { case (prediction, (label, _)) => (prediction, label) }
73-
}
74-
7566
/**
7667
* An auxiliary constructor taking a DataFrame.
7768
* @param scoreAndLabels a DataFrame with double columns: score, label, optional weight
7869
*/
7970
private[mllib] def this(scoreAndLabels: DataFrame) =
80-
this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
71+
this(scoreAndLabels.rdd.map {
72+
case Row(prediction: Double, label: Double, weight: Double) =>
73+
(prediction, label, weight)
74+
case Row(prediction: Double, label: Double) =>
75+
(prediction, label, 1.0)
76+
case other =>
77+
throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
78+
})
8179

8280
/**
8381
* Unpersist intermediate RDDs used in the computation.

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ import scala.collection.Map
2222
import org.apache.spark.annotation.Since
2323
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
2424
import org.apache.spark.rdd.RDD
25-
import org.apache.spark.sql.DataFrame
25+
import org.apache.spark.sql.{DataFrame, Row}
2626

2727
/**
2828
* Evaluator for multiclass classification.
2929
*
30-
* @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
31-
* (prediction, label) pairs.
30+
* @param predictionAndLabels an RDD of (prediction, label, weight) or
31+
* (prediction, label) tuples.
3232
*/
3333
@Since("1.1.0")
34-
class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
35-
val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
34+
class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product]) {
35+
val predLabelsWeight: RDD[(Double, Double, Double)] = predictionAndLabels.map {
3636
case (prediction: Double, label: Double, weight: Double) =>
3737
(prediction, label, weight)
3838
case (prediction: Double, label: Double) =>
@@ -46,7 +46,14 @@ class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Pr
4646
* @param predictionAndLabels a DataFrame with double columns: prediction, label, optional weight
4747
*/
4848
private[mllib] def this(predictionAndLabels: DataFrame) =
49-
this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
49+
this(predictionAndLabels.rdd.map {
50+
case Row(prediction: Double, label: Double, weight: Double) =>
51+
(prediction, label, weight)
52+
case Row(prediction: Double, label: Double) =>
53+
(prediction, label, 1.0)
54+
case other =>
55+
throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
56+
})
5057

5158
private lazy val labelCountByClass: Map[Double, Double] =
5259
predLabelsWeight.map {

mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ class BinaryClassificationEvaluatorSuite
4545
.setMetricName("areaUnderPR")
4646

4747
val vectorDF = Seq(
48-
(0d, Vectors.dense(12, 2.5)),
49-
(1d, Vectors.dense(1, 3)),
50-
(0d, Vectors.dense(10, 2))
48+
(0.0, Vectors.dense(12, 2.5)),
49+
(1.0, Vectors.dense(1, 3)),
50+
(0.0, Vectors.dense(10, 2))
5151
).toDF("label", "rawPrediction")
5252
assert(evaluator.evaluate(vectorDF) === 1.0)
5353

5454
val doubleDF = Seq(
55-
(0d, 0d),
56-
(1d, 1d),
57-
(0d, 0d)
55+
(0.0, 0.0),
56+
(1.0, 1.0),
57+
(0.0, 0.0)
5858
).toDF("label", "rawPrediction")
5959
assert(evaluator.evaluate(doubleDF) === 1.0)
6060

6161
val stringDF = Seq(
62-
(0d, "0d"),
63-
(1d, "1d"),
64-
(0d, "0d")
62+
(0.0, "0.0"),
63+
(1.0, "1.0"),
64+
(0.0, "0.0")
6565
).toDF("label", "rawPrediction")
6666
val thrown = intercept[IllegalArgumentException] {
6767
evaluator.evaluate(stringDF)
@@ -77,9 +77,9 @@ class BinaryClassificationEvaluatorSuite
7777
val evaluator = new BinaryClassificationEvaluator()
7878
.setMetricName("areaUnderROC").setWeightCol(weightCol)
7979
val vectorDF = Seq(
80-
(0d, Vectors.dense(2.5, 12), 1.0),
81-
(1d, Vectors.dense(1, 3), 1.0),
82-
(0d, Vectors.dense(10, 2), 1.0)
80+
(0.0, Vectors.dense(2.5, 12), 1.0),
81+
(1.0, Vectors.dense(1, 3), 1.0),
82+
(0.0, Vectors.dense(10, 2), 1.0)
8383
).toDF("label", "rawPrediction", weightCol)
8484
val result = evaluator.evaluate(vectorDF)
8585
// without weight column
@@ -89,9 +89,9 @@ class BinaryClassificationEvaluatorSuite
8989
assert(result === result2)
9090
// use different weights, validate metrics change
9191
val vectorDF2 = Seq(
92-
(0d, Vectors.dense(2.5, 12), 2.5),
93-
(1d, Vectors.dense(1, 3), 0.1),
94-
(0d, Vectors.dense(10, 2), 2.0)
92+
(0.0, Vectors.dense(2.5, 12), 2.5),
93+
(1.0, Vectors.dense(1, 3), 0.1),
94+
(0.0, Vectors.dense(10, 2), 2.0)
9595
).toDF("label", "rawPrediction", weightCol)
9696
val result3 = evaluator.evaluate(vectorDF2)
9797
// Since the wrong result is weighted more heavily, expect the score to be lower

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
9292
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, 0)
9393
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
9494
val numTruePositives =
95-
Seq(1.0 * w1, 1.0 * w1 + 2.0 * w2, 1.0 * w1 + 2.0 * w2, 3.0 * w2 + 1.0 * w1)
95+
Seq(1 * w1, 1 * w1 + 2 * w2, 1 * w1 + 2 * w2, 3 * w2 + 1 * w1)
9696
val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1)
9797
val numPositives = 3 * w2 + 1 * w1
9898
val numNegatives = 2 * w1 + w3
9999
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
100100
t.toDouble / (t + f)
101101
}
102-
val recalls = numTruePositives.map(t => t / numPositives)
103-
val fpr = numFalsePositives.map(f => f / numNegatives)
102+
val recalls = numTruePositives.map(_ / numPositives)
103+
val fpr = numFalsePositives.map(_ / numNegatives)
104104
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
105105
val pr = recalls.zip(precisions)
106106
val prCurve = Seq((0.0, 1.0)) ++ pr

python/pyspark/ml/evaluation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
138138
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
139139
>>> evaluator.evaluate(dataset)
140140
0.70...
141+
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
142+
0.82...
141143
142144
.. versionadded:: 1.4.0
143145
"""

python/pyspark/mllib/evaluation.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
3030
"""
3131
Evaluator for binary classification.
3232
33-
:param scoreAndLabelsWithOptWeight: an RDD of score, label and optional weight.
33+
:param scoreAndLabels: an RDD of score, label and optional weight.
3434
3535
>>> scoreAndLabels = sc.parallelize([
3636
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
@@ -45,23 +45,23 @@ class BinaryClassificationMetrics(JavaModelWrapper):
4545
... (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
4646
>>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
4747
>>> metrics.areaUnderROC
48-
0.70...
48+
0.79...
4949
>>> metrics.areaUnderPR
50-
0.83...
50+
0.88...
5151
5252
.. versionadded:: 1.4.0
5353
"""
5454

55-
def __init__(self, scoreAndLabelsWithOptWeight):
56-
sc = scoreAndLabelsWithOptWeight.ctx
55+
def __init__(self, scoreAndLabels):
56+
sc = scoreAndLabels.ctx
5757
sql_ctx = SQLContext.getOrCreate(sc)
58-
numCol = len(scoreAndLabelsWithOptWeight.first())
58+
numCol = len(scoreAndLabels.first())
5959
schema = StructType([
6060
StructField("score", DoubleType(), nullable=False),
6161
StructField("label", DoubleType(), nullable=False)])
62-
if (numCol == 3):
62+
if numCol == 3:
6363
schema.add("weight", DoubleType(), False)
64-
df = sql_ctx.createDataFrame(scoreAndLabelsWithOptWeight, schema=schema)
64+
df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema)
6565
java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
6666
java_model = java_class(df._jdf)
6767
super(BinaryClassificationMetrics, self).__init__(java_model)
@@ -174,7 +174,7 @@ class MulticlassMetrics(JavaModelWrapper):
174174
"""
175175
Evaluator for multiclass classification.
176176
177-
:param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight.
177+
:param predictionAndLabels: an RDD of prediction, label and optional weight.
178178
179179
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
180180
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
@@ -235,16 +235,16 @@ class MulticlassMetrics(JavaModelWrapper):
235235
.. versionadded:: 1.4.0
236236
"""
237237

238-
def __init__(self, predAndLabelsWithOptWeight):
239-
sc = predAndLabelsWithOptWeight.ctx
238+
def __init__(self, predictionAndLabels):
239+
sc = predictionAndLabels.ctx
240240
sql_ctx = SQLContext.getOrCreate(sc)
241-
numCol = len(predAndLabelsWithOptWeight.first())
241+
numCol = len(predictionAndLabels.first())
242242
schema = StructType([
243243
StructField("prediction", DoubleType(), nullable=False),
244244
StructField("label", DoubleType(), nullable=False)])
245-
if (numCol == 3):
245+
if numCol == 3:
246246
schema.add("weight", DoubleType(), False)
247-
df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema)
247+
df = sql_ctx.createDataFrame(predictionAndLabels, schema)
248248
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
249249
java_model = java_class(df._jdf)
250250
super(MulticlassMetrics, self).__init__(java_model)

0 commit comments

Comments
 (0)