Commit 78f13ad

[SPARK-14409][ML] Adding a RankingEvaluator to ML

This patch consolidates ml.evaluation and mllib.evaluation so that RankingEvaluator wraps mllib's RankingMetrics instead of reimplementing each ranking metric.

1 parent 19ea63b commit 78f13ad

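For orientation, here is a minimal usage sketch of the evaluator this commit touches. It assumes a live SparkSession named spark, and the setMetricName/setK setters, which follow the standard ml Param convention but are not shown in this diff:

import org.apache.spark.ml.evaluation.RankingEvaluator

// Each row pairs a ranked prediction list with its ground-truth label set.
val df = spark.createDataFrame(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3)),
  (Array(4, 1, 5, 6, 2), Array(1, 2))
)).toDF("prediction", "label")

val evaluator = new RankingEvaluator[Int]()
  .setMetricName("ndcg") // one of map | mapk | ndcg | mrr, per the diff below
  .setK(5)               // cutoff used by ndcg and mapk (assumed setter)

val score: Double = evaluator.evaluate(df)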
File tree

2 files changed: +14 −167 lines

mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala

Lines changed: 13 additions & 166 deletions

@@ -19,18 +19,19 @@ package org.apache.spark.ml.evaluation
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.mllib.evaluation.RankingMetrics
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.apache.spark.sql.functions._
 
 /**
  * :: Experimental ::
  * Evaluator for ranking, which expects two input columns: prediction and label.
+ * Both prediction and label columns need to be instances of Array[T] where T is the ClassTag.
  */
 @Since("2.0.0")
 @Experimental
@@ -65,7 +66,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr
   @Since("2.0.0")
   val metricName: Param[String] = {
     val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr"))
-    new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg||mrr)", allowedParams)
+    new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg|mrr)", allowedParams)
   }
 
   /** @group getParam */
@@ -98,175 +99,21 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr
       s"must be of the same type, but Prediction column $predictionColName is $predictionType " +
       s"and Label column $labelColName is $labelType")
 
+    val predictionAndLabels = dataset
+      .select(col($(predictionCol)).cast(predictionType), col($(labelCol)).cast(labelType))
+      .rdd.
+      map { case Row(prediction: Seq[T], label: Seq[T]) => (prediction.toArray, label.toArray) }
+
+    val metrics = new RankingMetrics[T](predictionAndLabels)
     val metric = $(metricName) match {
-      case "map" => meanAveragePrecision(dataset)
-      case "ndcg" => normalizedDiscountedCumulativeGain(dataset)
-      case "mapk" => meanAveragePrecisionAtK(dataset)
-      case "mrr" => meanReciprocalRank(dataset)
+      case "map" => metrics.meanAveragePrecision
+      case "ndcg" => metrics.ndcgAt($(k))
+      case "mapk" => metrics.precisionAt($(k))
+      case "mrr" => metrics.meanReciprocalRank
     }
     metric
   }
 
-  /**
-   * Returns the mean average precision (MAP) of all the queries.
-   * If a query has an empty ground truth set, the average precision will be zero and a log
-   * warning is generated.
-   */
-  private def meanAveragePrecision(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        var i = 0
-        var cnt = 0
-        var precSum = 0.0
-        val n = prediction.length
-        while (i < n) {
-          if (labSet.contains(prediction(i))) {
-            cnt += 1
-            precSum += cnt.toDouble / (i + 1)
-          }
-          i += 1
-        }
-        precSum / labSet.size
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the average NDCG value of all the queries, truncated at ranking position k.
-   * The discounted cumulative gain at position k is computed as:
-   * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
-   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
-   * implementation, the relevance value is binary.
-
-   * If a query has an empty ground truth set, zero will be used as ndcg together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
-   */
-  private def normalizedDiscountedCumulativeGain(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        val labSetSize = labSet.size
-        val n = math.min(math.max(prediction.length, labSetSize), $(k))
-        var maxDcg = 0.0
-        var dcg = 0.0
-        var i = 0
-        while (i < n) {
-          val gain = 1.0 / math.log(i + 2)
-          if (labSet.contains(prediction(i))) {
-            dcg += gain
-          }
-          if (i < labSetSize) {
-            maxDcg += gain
-          }
-          i += 1
-        }
-        dcg / maxDcg
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the average precision of all the queries, truncated at ranking position k.
-   *
-   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
-   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
-   * ground truth set is less than k.
-   *
-   * If a query has an empty ground truth set, zero will be used as precision together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
-   */
-  private def meanAveragePrecisionAtK(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(prediction.length, $(k))
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(prediction(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / $(k)
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the mean reciprocal rank (MRR) of all the queries.
-   *
-   * MRR is the inverse position of the first relevant document, and is therefore well-suited
-   * to applications in which only the first result matters. The reciprocal rank is the
-   * multiplicative inverse of the rank of the first correct answer for a query response and
-   * the mean reciprocal rank is the average of the reciprocal ranks of results for a sample
-   * of queries. MRR is well-suited to applications in which only the first result matters.
-   *
-   * If a query has an empty ground truth set, zero will be used as precision together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * Brian McFee, Gert R. G. Lanckriet Metric Learning to Rank. ICML 2010: 775-782
-   */
-  private def meanReciprocalRank(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        var i = 0
-        var reciprocalRank = 0.0
-        while (i < prediction.length && reciprocalRank == 0.0) {
-          if (labSet.contains(prediction(i))) {
-            reciprocalRank = 1.0 / (i + 1)
-          }
-          i += 1
-        }
-        reciprocalRank
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
   @Since("2.0.0")
   override def isLargerBetter: Boolean = $(metricName) match {
     case "map" => false

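The substance of the change above is that evaluateDataset stops reimplementing the four metrics and instead delegates to mllib's RankingMetrics over an RDD of (prediction, label) array pairs. A standalone sketch of that delegation path; meanAveragePrecision, ndcgAt, and precisionAt are standard RankingMetrics methods, while meanReciprocalRank is assumed to exist on this branch, as the diff implies:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.RankingMetrics

val sc = new SparkContext(new SparkConf().setAppName("ranking-metrics-demo").setMaster("local[2]"))

// Same shape the new evaluateDataset body produces: RDD[(Array[Int], Array[Int])].
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))
))

val metrics = new RankingMetrics[Int](predictionAndLabels)
val map   = metrics.meanAveragePrecision // "map"
val ndcg5 = metrics.ndcgAt(5)            // "ndcg" with k = 5
val p5    = metrics.precisionAt(5)       // "mapk" in the evaluator's naming

Note that the evaluator wires "mapk" to precisionAt, which matches the deleted meanAveragePrecisionAtK body: that method computed truncated precision (#hits / k), not a true mean average precision at k, so the delegation preserves the old behavior.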
mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ class RankingEvaluatorSuite
       (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
       (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
       (Array[Int](1, 2, 3, 4, 5), Array[Int]())
-    ), 2)).toDF(Seq("prediction", "label"): _*).as[(Array[Int], Array[Int])]
+    ), 2)).toDF(Seq("prediction", "label"): _*)
 
     // default = map, k = 1
     val evaluator = new RankingEvaluator()

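The one-line suite change follows from the new evaluateDataset: evaluate now takes a plain DataFrame and casts the columns itself, so the typed .as[(Array[Int], Array[Int])] view is no longer needed. A sketch of how the suite would exercise the default metric from here; the df name and the assertion bound are illustrative, not taken from this commit:

// default = map, k = 1, per the suite's own comment
val mapScore = evaluator.evaluate(df)      // df: the three-row DataFrame built above (name assumed)
assert(mapScore >= 0.0 && mapScore <= 1.0) // sanity bound only; the expected value is not in this diff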