@@ -19,18 +19,19 @@ package org.apache.spark.ml.evaluation

 import scala.reflect.ClassTag

-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.mllib.evaluation.RankingMetrics
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.apache.spark.sql.functions._

 /**
  * :: Experimental ::
  * Evaluator for ranking, which expects two input columns: prediction and label.
+ * Both prediction and label columns need to be instances of Array[T] where T is the ClassTag.
  */
 @Since("2.0.0")
 @Experimental
@@ -65,7 +66,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr
   @Since("2.0.0")
   val metricName: Param[String] = {
     val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr"))
-    new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg||mrr)", allowedParams)
+    new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg|mrr)", allowedParams)
   }

   /** @group getParam */
@@ -98,175 +99,21 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr
       s"must be of the same type, but Prediction column $predictionColName is $predictionType " +
       s"and Label column $labelColName is $labelType")

+    val predictionAndLabels = dataset
+      .select(col($(predictionCol)).cast(predictionType), col($(labelCol)).cast(labelType))
+      .rdd
+      .map { case Row(prediction: Seq[T], label: Seq[T]) => (prediction.toArray, label.toArray) }
+
+    val metrics = new RankingMetrics[T](predictionAndLabels)
     val metric = $(metricName) match {
-      case "map" => meanAveragePrecision(dataset)
-      case "ndcg" => normalizedDiscountedCumulativeGain(dataset)
-      case "mapk" => meanAveragePrecisionAtK(dataset)
-      case "mrr" => meanReciprocalRank(dataset)
+      case "map" => metrics.meanAveragePrecision
+      case "ndcg" => metrics.ndcgAt($(k))
+      case "mapk" => metrics.precisionAt($(k))
+      case "mrr" => metrics.meanReciprocalRank
     }
     metric
   }

-  /**
-   * Returns the mean average precision (MAP) of all the queries.
-   * If a query has an empty ground truth set, the average precision will be zero and a log
-   * warning is generated.
-   */
-  private def meanAveragePrecision(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        var i = 0
-        var cnt = 0
-        var precSum = 0.0
-        val n = prediction.length
-        while (i < n) {
-          if (labSet.contains(prediction(i))) {
-            cnt += 1
-            precSum += cnt.toDouble / (i + 1)
-          }
-          i += 1
-        }
-        precSum / labSet.size
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the average NDCG value of all the queries, truncated at ranking position k.
-   * The discounted cumulative gain at position k is computed as:
-   *    sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
-   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
-   * implementation, the relevance value is binary.
-
-   * If a query has an empty ground truth set, zero will be used as ndcg together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
-   */
-  private def normalizedDiscountedCumulativeGain(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        val labSetSize = labSet.size
-        val n = math.min(math.max(prediction.length, labSetSize), $(k))
-        var maxDcg = 0.0
-        var dcg = 0.0
-        var i = 0
-        while (i < n) {
-          val gain = 1.0 / math.log(i + 2)
-          if (labSet.contains(prediction(i))) {
-            dcg += gain
-          }
-          if (i < labSetSize) {
-            maxDcg += gain
-          }
-          i += 1
-        }
-        dcg / maxDcg
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the average precision of all the queries, truncated at ranking position k.
-   *
-   * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
-   * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
-   * ground truth set is less than k.
-   *
-   * If a query has an empty ground truth set, zero will be used as precision together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
-   */
-  private def meanAveragePrecisionAtK(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(prediction.length, $(k))
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(prediction(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / $(k)
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
-  /**
-   * Compute the mean reciprocal rank (MRR) of all the queries.
-   *
-   * MRR is the inverse position of the first relevant document, and is therefore well-suited
-   * to applications in which only the first result matters.The reciprocal rank is the
-   * multiplicative inverse of the rank of the first correct answer for a query response and
-   * the mean reciprocal rank is the average of the reciprocal ranks of results for a sample
-   * of queries. MRR is well-suited to applications in which only the first result matters.
-   *
-   * If a query has an empty ground truth set, zero will be used as precision together with
-   * a log warning.
-   *
-   * See the following paper for detail:
-   *
-   * Brian McFee, Gert R. G. Lanckriet Metric Learning to Rank. ICML 2010: 775-782
-   */
-  private def meanReciprocalRank(dataset: Dataset[_]): Double = {
-    val sc = SparkContext.getOrCreate()
-    val sqlContext = SQLContext.getOrCreate(sc)
-    import sqlContext.implicits._
-
-    dataset.map{ case (prediction: Array[T], label: Array[T]) =>
-      val labSet = label.toSet
-
-      if (labSet.nonEmpty) {
-        var i = 0
-        var reciprocalRank = 0.0
-        while (i < prediction.length && reciprocalRank == 0.0) {
-          if (labSet.contains(prediction(i))) {
-            reciprocalRank = 1.0 / (i + 1)
-          }
-          i += 1
-        }
-        reciprocalRank
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
-    }.reduce{ (a, b) => a + b } / dataset.count
-  }
-
   @Since("2.0.0")
   override def isLargerBetter: Boolean = $(metricName) match {
     case "map" => false
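
For reference, the refactored evaluate() now delegates to org.apache.spark.mllib.evaluation.RankingMetrics, feeding it (prediction, label) array pairs built from the two input columns. Below is a minimal, self-contained sketch of that underlying API using the same shape of data; the local SparkSession setup, object name, and sample values are illustrative assumptions and not part of this patch, and only the RankingMetrics calls known to exist in mllib (meanAveragePrecision, ndcgAt, precisionAt) are shown.

// Illustrative sketch only: exercises the mllib RankingMetrics API that
// the new evaluate() delegates to. Data and session setup are assumptions.
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.SparkSession

object RankingMetricsSketch {
  def main(args: Array[String]): Unit = {
    // Local session just for this standalone example.
    val spark = SparkSession.builder().master("local[*]").appName("RankingMetricsSketch").getOrCreate()

    // One element per query: (ranked predictions, ground-truth labels),
    // the same pair-of-arrays shape evaluate() extracts from the prediction/label columns.
    val predictionAndLabels = spark.sparkContext.parallelize(Seq(
      (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4, 5)),
      (Array(4, 1, 5, 6, 2), Array(1, 2, 3))
    ))

    val metrics = new RankingMetrics[Int](predictionAndLabels)
    println(s"map  = ${metrics.meanAveragePrecision}")  // metricName "map"
    println(s"ndcg = ${metrics.ndcgAt(5)}")             // metricName "ndcg", k = 5
    println(s"mapk = ${metrics.precisionAt(5)}")        // metricName "mapk", k = 5

    spark.stop()
  }
}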