From bd324cdbddc68e764740ba788b26eee6dba86861 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 17 Apr 2016 19:55:09 -0700 Subject: [PATCH 1/4] [SPARK-14409][ML][WIP] Adding a RankingEvaluator to ML This patch tries to add the implementation of Mean Reciprocal Rank (MRR) in mllib.evaluation, as a first step toward adding a RankingEvaluator to ML. An additional test case has been added to cover Mean Reciprocal Rank (MRR). This patch is a work in progress. --- .../mllib/evaluation/RankingMetrics.scala | 26 +++++++++++++++++++ .../evaluation/RankingMetricsSuite.scala | 5 +++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index b98aa0534152..8de18e7d6632 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -155,6 +155,32 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] }.mean() } + /** + * Returns the mean reciprocal rank (MRR) of all the queries. + * If a query has an empty ground truth set, the reciprocal rank will be zero and a log + * warning is generated. + */ + lazy val meanReciprocalRank: Double = { + predictionAndLabels.map { case (pred, lab) => + val labSet = lab.toSet + + if (labSet.nonEmpty) { + var i = 0 + var reciprocalRank = 0.0 + while (i < pred.length && reciprocalRank == 0.0) { + if (labSet.contains(pred(i))) { + reciprocalRank = 1.0 / (i + 1) + } + i += 1 + } + reciprocalRank + } else { + logWarning("Empty ground truth set, check input data") + 0.0 + } + }.mean() + } + } object RankingMetrics { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index f334be2c2ba8..29de9111660b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("Ranking metrics: MAP, NDCG") { + test("Ranking metrics: MAP, NDCG, MRR, PRECK") { val predictionAndLabels = sc.parallelize( Seq( (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)), @@ -34,6 +34,7 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val metrics = new RankingMetrics(predictionAndLabels) val map = metrics.meanAveragePrecision + val mrr = metrics.meanReciprocalRank assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps) assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps) @@ -49,6 +50,8 @@ assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps) assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps) assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps) + + assert(mrr ~== 0.5 absTol eps) } test("MAP, NDCG with few predictions (SPARK-14886)") { From 05580e0631d1ca72cd92148ca25b8d28c5099fa0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 19 Apr 2016 18:21:29 -0700 Subject: [PATCH 2/4] [SPARK-14409][ML][WIP] Adding a RankingEvaluator to ML This patch tries to add the implementation of RankingEvaluator in ml.evaluation. An additional test case has been added to cover the implementation.
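Before moving on to the evaluator, it may help to see where the mrr ~== 0.5 expectation in the new test above comes from. The following standalone sketch (plain Scala, no Spark; all names are illustrative and not part of the patch) mirrors the reciprocal-rank loop in meanReciprocalRank on the three sample queries: the first relevant item appears at rank 1 (reciprocal rank 1.0), at rank 2 (0.5), and never (empty ground truth, 0.0), so the mean is 0.5.

    object MrrSketch {
      // Reciprocal rank of a single query: 1 / (1-based position of the first
      // relevant item), or 0.0 if the ground truth set is empty or never hit.
      def reciprocalRank(pred: Array[Int], lab: Array[Int]): Double = {
        val labSet = lab.toSet
        if (labSet.isEmpty) 0.0
        else {
          val firstHit = pred.indexWhere(p => labSet.contains(p)) // 0-based index
          if (firstHit < 0) 0.0 else 1.0 / (firstHit + 1)
        }
      }

      def main(args: Array[String]): Unit = {
        val queries = Seq(
          (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)), // hit at rank 1 -> 1.0
          (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),       // hit at rank 2 -> 0.5
          (Array(1, 2, 3, 4, 5), Array.empty[Int])                      // empty labels  -> 0.0
        )
        val mrr = queries.map { case (p, l) => reciprocalRank(p, l) }.sum / queries.size
        println(mrr) // prints 0.5, matching the assertion in RankingMetricsSuite
      }
    }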
--- .../ml/evaluation/RankingEvaluator.scala | 272 ++++++++++++++++++ .../ml/evaluation/RankingEvaluatorSuite.scala | 60 ++++ 2 files changed, 332 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala new file mode 100644 index 000000000000..1b9575324620 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql.functions._ + +/** + * :: Experimental :: + * Evaluator for ranking, which expects two input columns: prediction and label. + */ +@Since("2.0.0") +@Experimental +final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("rankingEval")) + + @Since("2.0.0") + final val k = new IntParam(this, "k", "Top-K cutoff", (x: Int) => x > 0) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + setDefault(k -> 1) + + /** + * Param for metric name in evaluation. 
Supports: + * - `"map"` (default): Mean Average Precision + * - `"mapk"`: Mean Average Precision@K + * - `"ndcg"`: Normalized Discounted Cumulative Gain + * - `"mrr"`: Mean Reciprocal Rank + * + * @group param + */ + @Since("2.0.0") + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr")) + new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg||mrr)", allowedParams) + } + + /** @group getParam */ + @Since("2.0.0") + def getMetricName: String = $(metricName) + + /** @group setParam */ + @Since("2.0.0") + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "map") + + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { + val schema = dataset.schema + val predictionColName = $(predictionCol) + val predictionType = schema($(predictionCol)).dataType + val labelColName = $(labelCol) + val labelType = schema($(labelCol)).dataType + require(predictionType == labelType, + s"Prediction column $predictionColName and Label column $labelColName " + + s"must be of the same type, but Prediction column $predictionColName is $predictionType " + + s"and Label column $labelColName is $labelType") + + val metric = $(metricName) match { + case "map" => meanAveragePrecision(dataset) + case "ndcg" => normalizedDiscountedCumulativeGain(dataset) + case "mapk" => meanAveragePrecisionAtK(dataset) + case "mrr" => meanReciprocalRank(dataset) + } + metric + } + + /** + * Returns the mean average precision (MAP) of all the queries. + * If a query has an empty ground truth set, the average precision will be zero and a log + * warning is generated. + */ + private def meanAveragePrecision(dataset: Dataset[_]): Double = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + dataset.map{ case (prediction: Array[T], label: Array[T]) => + val labSet = label.toSet + + if (labSet.nonEmpty) { + var i = 0 + var cnt = 0 + var precSum = 0.0 + val n = prediction.length + while (i < n) { + if (labSet.contains(prediction(i))) { + cnt += 1 + precSum += cnt.toDouble / (i + 1) + } + i += 1 + } + precSum / labSet.size + } else { + 0.0 + } + }.reduce{ (a, b) => a + b } / dataset.count + } + + /** + * Compute the average NDCG value of all the queries, truncated at ranking position k. + * The discounted cumulative gain at position k is computed as: + * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current + * implementation, the relevance value is binary. + + * If a query has an empty ground truth set, zero will be used as ndcg together with + * a log warning. + * + * See the following paper for detail: + * + * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. 
Kekalainen + * + * @param k the position to compute the truncated ndcg, must be positive + * @return the average ndcg at the first k ranking positions + */ + private def normalizedDiscountedCumulativeGain(dataset: Dataset[_]): Double = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + dataset.map{ case (prediction: Array[T], label: Array[T]) => + val labSet = label.toSet + + if (labSet.nonEmpty) { + val labSetSize = labSet.size + val n = math.min(math.max(prediction.length, labSetSize), $(k)) + var maxDcg = 0.0 + var dcg = 0.0 + var i = 0 + while (i < n) { + val gain = 1.0 / math.log(i + 2) + if (labSet.contains(prediction(i))) { + dcg += gain + } + if (i < labSetSize) { + maxDcg += gain + } + i += 1 + } + dcg / maxDcg + } else { + 0.0 + } + }.reduce{ (a, b) => a + b } / dataset.count + } + + /** + * Compute the average precision of all the queries, truncated at ranking position k. + * + * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be + * computed as #(relevant items retrieved) / k. This formula also applies when the size of the + * ground truth set is less than k. + * + * If a query has an empty ground truth set, zero will be used as precision together with + * a log warning. + * + * See the following paper for detail: + * + * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen + * + * @param k the position to compute the truncated precision, must be positive + * @return the average precision at the first k ranking positions + */ + private def meanAveragePrecisionAtK(dataset: Dataset[_]): Double = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + dataset.map{ case (prediction: Array[T], label: Array[T]) => + val labSet = label.toSet + + if (labSet.nonEmpty) { + val n = math.min(prediction.length, $(k)) + var i = 0 + var cnt = 0 + while (i < n) { + if (labSet.contains(prediction(i))) { + cnt += 1 + } + i += 1 + } + cnt.toDouble / $(k) + } else { + 0.0 + } + }.reduce{ (a, b) => a + b } / dataset.count + } + + private def meanReciprocalRank(dataset: Dataset[_]): Double = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + dataset.map{ case (prediction: Array[T], label: Array[T]) => + val labSet = label.toSet + + if (labSet.nonEmpty) { + var i = 0 + var reciprocalRank = 0.0 + while (i < prediction.length && reciprocalRank == 0.0) { + if (labSet.contains(prediction(i))) { + reciprocalRank = 1.0 / (i + 1) + } + i += 1 + } + reciprocalRank + } else { + 0.0 + } + }.reduce{ (a, b) => a + b } / dataset.count + } + + @Since("2.0.0") + override def isLargerBetter: Boolean = $(metricName) match { + case "map" => false + case "ndcg" => false + case "mapk" => false + case "mrr" => false + } + + @Since("2.0.0") + override def copy(extra: ParamMap): RankingEvaluator[T] = defaultCopy(extra) +} + +@Since("2.0.0") +object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator[_]] { + + @Since("2.0.0") + override def load(path: String): RankingEvaluator[_] = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala new file mode 100644 index 000000000000..ed4551a86f25 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala @@ 
-0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class RankingEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("params") { + ParamsSuite.checkParams(new RankingEvaluator) + } + + test("Ranking Evaluator: default params") { + val sqlContext = new org.apache.spark.sql.SQLContext(sc) + import sqlContext.implicits._ + + val predictionAndLabels = sqlContext.createDataFrame(sc.parallelize( + Seq( + (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)), + (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)), + (Array[Int](1, 2, 3, 4, 5), Array[Int]()) + ), 2)).toDF(Seq("prediction", "label"): _*).as[(Array[Int], Array[Int])] + + // default = map, k = 1 + val evaluator = new RankingEvaluator() + assert(evaluator.evaluate(predictionAndLabels) ~== 0.355026 absTol 0.01) + + // mapk, k = 5 + evaluator.setMetricName("mapk").setK(5) + assert(evaluator.evaluate(predictionAndLabels) ~== 0.8/3 absTol 0.01) + + // ndcg, k = 5 + evaluator.setMetricName("ndcg") + assert(evaluator.evaluate(predictionAndLabels) ~== 0.328788 absTol 0.01) + + // mrr + evaluator.setMetricName("mrr") + assert(evaluator.evaluate(predictionAndLabels) ~== 0.5 absTol 0.01) + } +} From 19ea63b61cea0023572e1356b4aa8fa608b19082 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 19 Apr 2016 19:15:45 -0700 Subject: [PATCH 3/4] [SPARK-14409][ML][WIP] Adding a RankingEvaluator to ML This patch tries to update the comment for mean reciprocal rank. 
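For reference, the quantity the updated comment describes is (notation mine, not part of the patch):

    MRR = \frac{1}{|Q|} \sum_{q=1}^{|Q|} \frac{1}{\operatorname{rank}_q}

where |Q| is the number of queries and rank_q is the 1-based position of the first relevant item in the prediction list for query q; the term is taken to be 0 when the ground truth set is empty or no relevant item is retrieved.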
--- .../ml/evaluation/RankingEvaluator.scala | 29 ++++++++++++++----- .../mllib/evaluation/RankingMetrics.scala | 19 ++++++++++-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala index 1b9575324620..92f4c83f5afb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} @@ -34,7 +35,7 @@ import org.apache.spark.sql.functions._ @Since("2.0.0") @Experimental final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable with Logging { @Since("2.0.0") def this() = this(Identifiable.randomUID("rankingEval")) @@ -133,6 +134,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr } precSum / labSet.size } else { + logWarning("Empty ground truth set, check input data") 0.0 } }.reduce{ (a, b) => a + b } / dataset.count @@ -151,9 +153,6 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr * See the following paper for detail: * * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen - * - * @param k the position to compute the truncated ndcg, must be positive - * @return the average ndcg at the first k ranking positions */ private def normalizedDiscountedCumulativeGain(dataset: Dataset[_]): Double = { val sc = SparkContext.getOrCreate() @@ -181,6 +180,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr } dcg / maxDcg } else { + logWarning("Empty ground truth set, check input data") 0.0 } }.reduce{ (a, b) => a + b } / dataset.count @@ -199,9 +199,6 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr * See the following paper for detail: * * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen - * - * @param k the position to compute the truncated precision, must be positive - * @return the average precision at the first k ranking positions */ private def meanAveragePrecisionAtK(dataset: Dataset[_]): Double = { val sc = SparkContext.getOrCreate() @@ -223,11 +220,28 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr } cnt.toDouble / $(k) } else { + logWarning("Empty ground truth set, check input data") 0.0 } }.reduce{ (a, b) => a + b } / dataset.count } + /** + * Compute the mean reciprocal rank (MRR) of all the queries. 
+ * + * MRR is the inverse position of the first relevant document, and is therefore well-suited + * to applications in which only the first result matters. The reciprocal rank is the + * multiplicative inverse of the rank of the first correct answer for a query response, and + * the mean reciprocal rank is the average of the reciprocal ranks of results for a sample + * of queries. + * + * If a query has an empty ground truth set, zero will be used as the reciprocal rank together + * with a log warning. + * + * See the following paper for detail: + * + * Brian McFee and Gert R. G. Lanckriet. Metric Learning to Rank. ICML 2010: 775-782 + */ private def meanReciprocalRank(dataset: Dataset[_]): Double = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) @@ -247,6 +261,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr } reciprocalRank } else { + logWarning("Empty ground truth set, check input data") 0.0 } }.reduce{ (a, b) => a + b } / dataset.count diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 8de18e7d6632..22fdfb6f7bbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -156,9 +156,22 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] } /** - * Returns the mean reciprocal rank (MRR) of all the queries. - * If a query has an empty ground truth set, the reciprocal rank will be zero and a log - * warning is generated. + * Compute the mean reciprocal rank (MRR) of all the queries. + * + * MRR is the inverse position of the first relevant document, and is therefore well-suited + * to applications in which only the first result matters. The reciprocal rank is the + * multiplicative inverse of the rank of the first correct answer for a query response, and + * the mean reciprocal rank is the average of the reciprocal ranks of results for a sample + * of queries. + * + * If a query has an empty ground truth set, zero will be used as the reciprocal rank together + * with a log warning. + * + * See the following paper for detail: + * + * Brian McFee and Gert R. G. Lanckriet. Metric Learning to Rank. ICML 2010: 775-782 + * + * @return the mean reciprocal rank of all the queries. */ lazy val meanReciprocalRank: Double = { predictionAndLabels.map { case (pred, lab) => From 78f13adf5a8b750ceea827e904f6b1761304f887 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 25 Apr 2016 19:45:44 -0700 Subject: [PATCH 4/4] [SPARK-14409][ML] Adding a RankingEvaluator to ML This patch tries to consolidate ml.evaluation and mllib.evaluation so that RankingEvaluator wraps RankingMetrics.
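In other words, after this commit the evaluator's only remaining work is converting the two array columns into an RDD of (Array, Array) pairs and delegating to mllib's RankingMetrics. A standalone sketch of that delegation for the "mrr" case (a hypothetical helper, not code from the patch; df is assumed to have Int array columns named "prediction" and "label", as in the test suite):

    import org.apache.spark.mllib.evaluation.RankingMetrics
    import org.apache.spark.sql.DataFrame

    // Hypothetical helper mirroring the rewritten evaluate() below:
    // DataFrame columns -> RDD[(Array[Int], Array[Int])] -> RankingMetrics.
    def mrrOf(df: DataFrame): Double = {
      val predictionAndLabels = df.select("prediction", "label").rdd.map { row =>
        (row.getSeq[Int](0).toArray, row.getSeq[Int](1).toArray)
      }
      new RankingMetrics[Int](predictionAndLabels).meanReciprocalRank
    }

The actual evaluate() additionally checks that the two columns have the same type and dispatches on metricName, as shown in the diff below.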
--- .../ml/evaluation/RankingEvaluator.scala | 179 ++---------------- .../ml/evaluation/RankingEvaluatorSuite.scala | 2 +- 2 files changed, 14 insertions(+), 167 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala index 92f4c83f5afb..37a078c5f95e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala @@ -19,18 +19,19 @@ package org.apache.spark.ml.evaluation import scala.reflect.ClassTag -import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.mllib.evaluation.RankingMetrics import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.functions._ /** * :: Experimental :: * Evaluator for ranking, which expects two input columns: prediction and label. + * Both prediction and label columns need to be instances of Array[T] where T is the ClassTag. */ @Since("2.0.0") @Experimental @@ -65,7 +66,7 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr @Since("2.0.0") val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr")) - new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg||mrr)", allowedParams) + new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg|mrr)", allowedParams) } /** @group getParam */ @@ -98,175 +99,21 @@ final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") overr s"must be of the same type, but Prediction column $predictionColName is $predictionType " + s"and Label column $labelColName is $labelType") + val predictionAndLabels = dataset + .select(col($(predictionCol)).cast(predictionType), col($(labelCol)).cast(labelType)) + .rdd. + map { case Row(prediction: Seq[T], label: Seq[T]) => (prediction.toArray, label.toArray) } + + val metrics = new RankingMetrics[T](predictionAndLabels) val metric = $(metricName) match { - case "map" => meanAveragePrecision(dataset) - case "ndcg" => normalizedDiscountedCumulativeGain(dataset) - case "mapk" => meanAveragePrecisionAtK(dataset) - case "mrr" => meanReciprocalRank(dataset) + case "map" => metrics.meanAveragePrecision + case "ndcg" => metrics.ndcgAt($(k)) + case "mapk" => metrics.precisionAt($(k)) + case "mrr" => metrics.meanReciprocalRank } metric } - /** - * Returns the mean average precision (MAP) of all the queries. - * If a query has an empty ground truth set, the average precision will be zero and a log - * warning is generated. 
- */ - private def meanAveragePrecision(dataset: Dataset[_]): Double = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - - dataset.map{ case (prediction: Array[T], label: Array[T]) => - val labSet = label.toSet - - if (labSet.nonEmpty) { - var i = 0 - var cnt = 0 - var precSum = 0.0 - val n = prediction.length - while (i < n) { - if (labSet.contains(prediction(i))) { - cnt += 1 - precSum += cnt.toDouble / (i + 1) - } - i += 1 - } - precSum / labSet.size - } else { - logWarning("Empty ground truth set, check input data") - 0.0 - } - }.reduce{ (a, b) => a + b } / dataset.count - } - - /** - * Compute the average NDCG value of all the queries, truncated at ranking position k. - * The discounted cumulative gain at position k is computed as: - * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), - * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current - * implementation, the relevance value is binary. - - * If a query has an empty ground truth set, zero will be used as ndcg together with - * a log warning. - * - * See the following paper for detail: - * - * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen - */ - private def normalizedDiscountedCumulativeGain(dataset: Dataset[_]): Double = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - - dataset.map{ case (prediction: Array[T], label: Array[T]) => - val labSet = label.toSet - - if (labSet.nonEmpty) { - val labSetSize = labSet.size - val n = math.min(math.max(prediction.length, labSetSize), $(k)) - var maxDcg = 0.0 - var dcg = 0.0 - var i = 0 - while (i < n) { - val gain = 1.0 / math.log(i + 2) - if (labSet.contains(prediction(i))) { - dcg += gain - } - if (i < labSetSize) { - maxDcg += gain - } - i += 1 - } - dcg / maxDcg - } else { - logWarning("Empty ground truth set, check input data") - 0.0 - } - }.reduce{ (a, b) => a + b } / dataset.count - } - - /** - * Compute the average precision of all the queries, truncated at ranking position k. - * - * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be - * computed as #(relevant items retrieved) / k. This formula also applies when the size of the - * ground truth set is less than k. - * - * If a query has an empty ground truth set, zero will be used as precision together with - * a log warning. - * - * See the following paper for detail: - * - * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen - */ - private def meanAveragePrecisionAtK(dataset: Dataset[_]): Double = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - - dataset.map{ case (prediction: Array[T], label: Array[T]) => - val labSet = label.toSet - - if (labSet.nonEmpty) { - val n = math.min(prediction.length, $(k)) - var i = 0 - var cnt = 0 - while (i < n) { - if (labSet.contains(prediction(i))) { - cnt += 1 - } - i += 1 - } - cnt.toDouble / $(k) - } else { - logWarning("Empty ground truth set, check input data") - 0.0 - } - }.reduce{ (a, b) => a + b } / dataset.count - } - - /** - * Compute the mean reciprocal rank (MRR) of all the queries. 
- * - * MRR is the inverse position of the first relevant document, and is therefore well-suited - * to applications in which only the first result matters.The reciprocal rank is the - * multiplicative inverse of the rank of the first correct answer for a query response and - * the mean reciprocal rank is the average of the reciprocal ranks of results for a sample - * of queries. MRR is well-suited to applications in which only the first result matters. - * - * If a query has an empty ground truth set, zero will be used as precision together with - * a log warning. - * - * See the following paper for detail: - * - * Brian McFee, Gert R. G. Lanckriet Metric Learning to Rank. ICML 2010: 775-782 - */ - private def meanReciprocalRank(dataset: Dataset[_]): Double = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - - dataset.map{ case (prediction: Array[T], label: Array[T]) => - val labSet = label.toSet - - if (labSet.nonEmpty) { - var i = 0 - var reciprocalRank = 0.0 - while (i < prediction.length && reciprocalRank == 0.0) { - if (labSet.contains(prediction(i))) { - reciprocalRank = 1.0 / (i + 1) - } - i += 1 - } - reciprocalRank - } else { - logWarning("Empty ground truth set, check input data") - 0.0 - } - }.reduce{ (a, b) => a + b } / dataset.count - } - @Since("2.0.0") override def isLargerBetter: Boolean = $(metricName) match { case "map" => false diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala index ed4551a86f25..76833e0dd9fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala @@ -39,7 +39,7 @@ class RankingEvaluatorSuite (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)), (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)), (Array[Int](1, 2, 3, 4, 5), Array[Int]()) - ), 2)).toDF(Seq("prediction", "label"): _*).as[(Array[Int], Array[Int])] + ), 2)).toDF(Seq("prediction", "label"): _*) // default = map, k = 1 val evaluator = new RankingEvaluator()
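To close out the series, a minimal end-to-end usage sketch of the consolidated evaluator (assumes an existing SparkContext sc; the data and column names mirror the test suite and are illustrative only):

    import org.apache.spark.ml.evaluation.RankingEvaluator
    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    val predictionAndLabels = sqlContext.createDataFrame(Seq(
      (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
      (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))
    )).toDF("prediction", "label")

    val evaluator = new RankingEvaluator[Int]()
      .setMetricName("mrr")
      .setPredictionCol("prediction")
      .setLabelCol("label")

    // Reciprocal ranks are 1.0 and 0.5 for the two queries, so this prints ~0.75.
    println(evaluator.evaluate(predictionAndLabels))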