From c93ab86d35984e9f70a3b4f543fb88f5541333f0 Mon Sep 17 00:00:00 2001
From: Danilo Ascione
Date: Thu, 5 Jan 2017 10:34:23 +0100
Subject: [PATCH 1/5] [SPARK-14409][ML] Add RankingEvaluator and MPR metric

---
 .../MeanPercentileRankMetrics.scala           |  53 +++++++
 .../ml/evaluation/RankingEvaluator.scala      | 134 ++++++++++++++++++
 .../ml/evaluation/RankingEvaluatorSuite.scala |  53 +++++++
 3 files changed, 240 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
new file mode 100644
index 0000000000000..9942f3e02f00f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.functions.sum
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.types.DoubleType
+
+@Since("2.2.0")
+class MeanPercentileRankMetrics (
+    predictionAndObservations: DataFrame, predictionCol: String, labelCol: String)
+  extends Logging {
+
+  def meanPercentileRank: Double = {
+
+    def rank_ui = udf((recs: Seq[Long], item: Long) => {
+      val l_i = recs.indexOf(item)
+
+      if (l_i == -1) {
+        1
+      } else {
+        l_i.toDouble / recs.size
+      }
+    }, DoubleType)
+
+    val R_prime = predictionAndObservations.count()
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    val rankSum: Double = predictionAndObservations
+      .withColumn("rank_ui", rank_ui(predictionColumn, labelColumn))
+      .agg(sum("rank_ui")).first().getDouble(0)
+
+    rankSum / R_prime
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
new file mode 100644
index 0000000000000..9b85485fbe164
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{col, collect_list, row_number}
+import org.apache.spark.sql.types.LongType
+
+/**
+ * Evaluator for ranking.
+ */
+@Since("2.2.0")
+@Experimental
+final class RankingEvaluator @Since("2.2.0")(@Since("2.2.0") override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("rankingEval"))
+
+  @Since("2.2.0")
+  val k = new IntParam(this, "k", "Top-K cutoff", (x: Int) => x > 0)
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getK: Int = $(k)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setK(value: Int): this.type = set(k, value)
+
+  setDefault(k -> 1)
+
+  @Since("2.2.0")
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("mpr"))
+    new Param(this, "metricName", "metric name in evaluation (mpr)", allowedParams)
+  }
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Param for query column name.
+   * @group param
+   */
+  val queryCol: Param[String] = new Param[String](this, "queryCol", "query column name")
+
+  setDefault(queryCol, "query")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getQueryCol: String = $(queryCol)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setQueryCol(value: String): this.type = set(queryCol, value)
+
+  setDefault(metricName -> "mpr")
+
+  @Since("2.2.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
+    val schema = dataset.schema
+    SchemaUtils.checkNumericType(schema, $(predictionCol))
+    SchemaUtils.checkNumericType(schema, $(labelCol))
+    SchemaUtils.checkNumericType(schema, $(queryCol))
+
+    val w = Window.partitionBy(col($(queryCol))).orderBy(col($(predictionCol)).desc)
+
+    val topAtk: DataFrame = dataset
+      .select(col($(predictionCol)), col($(labelCol)).cast(LongType), col($(queryCol)))
+      .withColumn("rn", row_number().over(w)).where(col("rn") <= $(k))
+      .drop("rn")
+      .groupBy(col($(queryCol)))
+      .agg(collect_list($(labelCol)).as("topAtk"))
+
+    val predictionAndLabels: DataFrame = dataset
+      .join(topAtk, $(queryCol))
+      .select($(labelCol), "topAtk")
+
+    val metrics = new MeanPercentileRankMetrics(predictionAndLabels, "topAtk", $(labelCol))
+    val metric = $(metricName) match {
+      case "mpr" => metrics.meanPercentileRank
+    }
+    metric
+  }
+
+  @Since("2.2.0")
+  override def isLargerBetter: Boolean = $(metricName) match {
+    case "mpr" => false
+  }
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
+}
+
+@Since("2.2.0")
+object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator] {
+
+  @Since("2.2.0")
+  override def load(path: String): RankingEvaluator = super.load(path)
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
new file mode 100644
index 0000000000000..ba574318bad98
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RankingEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("params") {
+    ParamsSuite.checkParams(new RankingEvaluator)
+  }
+
+  test("Ranking Evaluator: default params") {
+
+    val predictionAndLabels =
+      Seq(
+        (1L, 1L, 0.5f),
+        (1L, 2L, Float.NaN),
+        (2L, 1L, 0.1f),
+        (2L, 2L, 0.7f)
+      ).toDF(Seq("query", "label", "prediction"): _*)
+
+    // default = mpr, k = 1
+    val evaluator = new RankingEvaluator()
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.5 absTol 0.01)
+
+    // mpr, k = 5
+    evaluator.setMetricName("mpr").setK(5)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.25 absTol 0.01)
+  }
+}

From bfd7dc5f3d08cbef311b7e4828c22efedf2117d8 Mon Sep 17 00:00:00 2001
From: Danilo Ascione
Date: Tue, 10 Jan 2017 18:30:15 +0100
Subject: [PATCH 2/5] [SPARK-14409][ML] Handle NaN in predictions

---
 .../spark/ml/evaluation/RankingEvaluator.scala      |  8 ++++++--
 .../spark/ml/evaluation/RankingEvaluatorSuite.scala | 13 ++++++++++++-
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
index 9b85485fbe164..be867fa25a1c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{col, collect_list, row_number}
+import org.apache.spark.sql.functions.{coalesce, col, collect_list, row_number, udf}
 import org.apache.spark.sql.types.LongType
 
 /**
@@ -99,14 +99,18 @@ final class RankingEvaluator @Since("2.2.0")(@Since("2.2.0") override val uid: S
     val w = Window.partitionBy(col($(queryCol))).orderBy(col($(predictionCol)).desc)
 
     val topAtk: DataFrame = dataset
+      .na.drop("all", Seq($(predictionCol)))
       .select(col($(predictionCol)), col($(labelCol)).cast(LongType), col($(queryCol)))
      .withColumn("rn", row_number().over(w)).where(col("rn") <= $(k))
       .drop("rn")
       .groupBy(col($(queryCol)))
       .agg(collect_list($(labelCol)).as("topAtk"))
 
+    val mapToEmptyArray_ = udf(() => Array.empty[Long])
+
     val predictionAndLabels: DataFrame = dataset
-      .join(topAtk, $(queryCol))
+      .join(topAtk, Seq($(queryCol)), "outer")
+      .withColumn("topAtk", coalesce(col("topAtk"), mapToEmptyArray_()))
       .select($(labelCol), "topAtk")
 
     val metrics = new MeanPercentileRankMetrics(predictionAndLabels, "topAtk", $(labelCol))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
index ba574318bad98..2965c55b1beb3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
@@ -48,6 +48,17 @@ class RankingEvaluatorSuite
 
     // mpr, k = 5
     evaluator.setMetricName("mpr").setK(5)
-    assert(evaluator.evaluate(predictionAndLabels) ~== 0.25 absTol 0.01)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.375 absTol 0.01)
+  }
+
+  test("Ranking Evaluator: no predictions") {
+    val predictionAndLabels =
+      Seq(
+        (1L, 2L, Float.NaN)
+      ).toDF(Seq("query", "label", "prediction"): _*)
+
+    // default = mpr, k = 1
+    val evaluator = new RankingEvaluator()
+    assert(evaluator.evaluate(predictionAndLabels) == 1.0)
   }
 }

From 3b23bfb035514ce2a039b03d3e4ecb881f68a0f6 Mon Sep 17 00:00:00 2001
From: Danilo Ascione
Date: Tue, 17 Jan 2017 16:22:57 +0100
Subject: [PATCH 3/5] [SPARK-14409][ML] Add basic test

---
 .../MeanPercentileRankMetricsSuite.scala      | 37 +++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala

diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
new file mode 100644
index 0000000000000..e76eb0621a16f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class MeanPercentileRankMetricsSuite extends SparkFunSuite with MLlibTestSparkContext{
+
+  import testImplicits._
+
+  test("Mean Percentile Rank metrics") {
+    val predictionAndLabels = Seq(
+      (111216304L, Array(111216304L)),
+      (108848657L, Array.empty[Long])
+    ).toDF(Seq("label", "prediction"): _*)
+    predictionAndLabels.show()
+    val mpr = new MeanPercentileRankMetrics(predictionAndLabels, "prediction", "label")
+    print(mpr.meanPercentileRank)
+  }
+
+}

From ad69499116e75c7851356e99865d7b8361cbfb20 Mon Sep 17 00:00:00 2001
From: Danilo Ascione
Date: Sun, 12 Mar 2017 15:35:58 +0100
Subject: [PATCH 4/5] [SPARK-14409][ML] MeanPercentileRankMetrics : Support generic type

---
 .../ml/evaluation/MeanPercentileRankMetrics.scala | 10 +++++-----
 .../MeanPercentileRankMetricsSuite.scala          | 15 ++++++++++++---
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
index 9942f3e02f00f..73827499c2678 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
@@ -30,13 +30,13 @@ class MeanPercentileRankMetrics (
 
   def meanPercentileRank: Double = {
 
-    def rank_ui = udf((recs: Seq[Long], item: Long) => {
-      val l_i = recs.indexOf(item)
+    def rank = udf((predicted: Seq[Any], actual: Any) => {
+      val l_i = predicted.indexOf(actual)
 
       if (l_i == -1) {
         1
       } else {
-        l_i.toDouble / recs.size
+        l_i.toDouble / predicted.size
       }
     }, DoubleType)
 
@@ -45,8 +45,8 @@ class MeanPercentileRankMetrics (
     val labelColumn: Column = predictionAndObservations.col(labelCol)
 
     val rankSum: Double = predictionAndObservations
-      .withColumn("rank_ui", rank_ui(predictionColumn, labelColumn))
-      .agg(sum("rank_ui")).first().getDouble(0)
+      .withColumn("rank", rank(predictionColumn, labelColumn))
+      .agg(sum("rank")).first().getDouble(0)
 
     rankSum / R_prime
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
index e76eb0621a16f..1066548c2e4cb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
@@ -24,14 +24,23 @@ class MeanPercentileRankMetricsSuite extends SparkFunSuite with MLlibTestSparkCo
 
   import testImplicits._
 
-  test("Mean Percentile Rank metrics") {
+  test("Mean Percentile Rank metrics : Long type") {
     val predictionAndLabels = Seq(
       (111216304L, Array(111216304L)),
       (108848657L, Array.empty[Long])
     ).toDF(Seq("label", "prediction"): _*)
-    predictionAndLabels.show()
     val mpr = new MeanPercentileRankMetrics(predictionAndLabels, "prediction", "label")
-    print(mpr.meanPercentileRank)
+    assert(mpr.meanPercentileRank === 0.5)
+  }
+
+  test("Mean Percentile Rank metrics : String type") {
+    val predictionAndLabels = Seq(
+      ("item1", Array("item2", "item1")),
+      ("item2", Array("item1")),
+      ("item3", Array("item3", "item1", "item2"))
+    ).toDF(Seq("label", "prediction"): _*)
+    val mpr = new MeanPercentileRankMetrics(predictionAndLabels, "prediction", "label")
+    assert(mpr.meanPercentileRank === 0.5)
   }
 
 }

From fa2155af8947347a2fc1e565cf05a19529022266 Mon Sep 17 00:00:00 2001
From: Danilo Ascione
Date: Sun, 12 Mar 2017 17:48:33 +0100
Subject: [PATCH 5/5] [SPARK-14409][ML] Write the ranking metrics computations as UDFs

---
 .../MeanPercentileRankMetrics.scala           |  53 -----
 .../ml/evaluation/RankingEvaluator.scala      |   2 +-
 .../spark/ml/evaluation/RankingMetrics.scala  | 202 ++++++++++++++++++
 .../MeanPercentileRankMetricsSuite.scala      |  46 ----
 .../ml/evaluation/RankingMetricsSuite.scala   |  75 +++++++
 5 files changed, 278 insertions(+), 100 deletions(-)
 delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala
 delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
deleted file mode 100644
index 73827499c2678..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetrics.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.ml.evaluation
-
-import org.apache.spark.annotation.Since
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.functions.sum
-import org.apache.spark.sql.functions.udf
-import org.apache.spark.sql.types.DoubleType
-
-@Since("2.2.0")
-class MeanPercentileRankMetrics (
-    predictionAndObservations: DataFrame, predictionCol: String, labelCol: String)
-  extends Logging {
-
-  def meanPercentileRank: Double = {
-
-    def rank = udf((predicted: Seq[Any], actual: Any) => {
-      val l_i = predicted.indexOf(actual)
-
-      if (l_i == -1) {
-        1
-      } else {
-        l_i.toDouble / predicted.size
-      }
-    }, DoubleType)
-
-    val R_prime = predictionAndObservations.count()
-    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
-    val labelColumn: Column = predictionAndObservations.col(labelCol)
-
-    val rankSum: Double = predictionAndObservations
-      .withColumn("rank", rank(predictionColumn, labelColumn))
-      .agg(sum("rank")).first().getDouble(0)
-
-    rankSum / R_prime
-  }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
index be867fa25a1c0..28e1e88ea5d80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -113,7 +113,7 @@ final class RankingEvaluator @Since("2.2.0")(@Since("2.2.0") override val uid: S
       .withColumn("topAtk", coalesce(col("topAtk"), mapToEmptyArray_()))
       .select($(labelCol), "topAtk")
 
-    val metrics = new MeanPercentileRankMetrics(predictionAndLabels, "topAtk", $(labelCol))
+    val metrics = new RankingMetrics(predictionAndLabels, "topAtk", $(labelCol))
     val metric = $(metricName) match {
       case "mpr" => metrics.meanPercentileRank
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..2e721e1df3b9e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.functions.{mean, sum}
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.types.DoubleType
+
+@Since("2.2.0")
+class RankingMetrics(
+    predictionAndObservations: DataFrame, predictionCol: String, labelCol: String)
+  extends Logging with Serializable {
+
+  /**
+   * Compute the Mean Percentile Rank (MPR) of all the queries.
+   *
+   * See the following paper for detail ("Expected percentile rank" in the paper):
+   * Hu, Y., Y. Koren, and C. Volinsky. “Collaborative Filtering for Implicit Feedback Datasets.”
+   * In 2008 Eighth IEEE International Conference on Data Mining, 263–72, 2008.
+   * doi:10.1109/ICDM.2008.22.
+   *
+   * @return the mean percentile rank
+   */
+  lazy val meanPercentileRank: Double = {
+
+    def rank = udf((predicted: Seq[Any], actual: Any) => {
+      val l_i = predicted.indexOf(actual)
+
+      if (l_i == -1) {
+        1
+      } else {
+        l_i.toDouble / predicted.size
+      }
+    }, DoubleType)
+
+    val R_prime = predictionAndObservations.count()
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    val rankSum: Double = predictionAndObservations
+      .withColumn("rank", rank(predictionColumn, labelColumn))
+      .agg(sum("rank")).first().getDouble(0)
+
+    rankSum / R_prime
+  }
+
+  /**
+   * Compute the average precision of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n (n is less than k) results, the precision
+   * value will be computed as #(relevant items retrieved) / k. This formula also applies when
+   * the size of the ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as precision together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated precision, must be positive
+   * @return the average precision at the first k ranking positions
+   */
+  @Since("2.2.0")
+  def precisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+
+    def precisionAtK = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+      if (actualSet.nonEmpty) {
+        val n = math.min(predicted.length, k)
+        var i = 0
+        var cnt = 0
+        while (i < n) {
+          if (actualSet.contains(predicted(i))) {
+            cnt += 1
+          }
+          i += 1
+        }
+        cnt.toDouble / k
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("predictionAtK", precisionAtK(predictionColumn, labelColumn))
+      .agg(mean("predictionAtK")).first().getDouble(0)
+  }
+
+  /**
+   * Returns the mean average precision (MAP) of all the queries.
+   * If a query has an empty ground truth set, the average precision will be zero and a log
+   * warning is generated.
+   */
+  lazy val meanAveragePrecision: Double = {
+
+    def map = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+      if (actualSet.nonEmpty) {
+        var i = 0
+        var cnt = 0
+        var precSum = 0.0
+        val n = predicted.length
+        while (i < n) {
+          if (actualSet.contains(predicted(i))) {
+            cnt += 1
+            precSum += cnt.toDouble / (i + 1)
+          }
+          i += 1
+        }
+        precSum / actualSet.size
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("MAP", map(predictionColumn, labelColumn))
+      .agg(mean("MAP")).first().getDouble(0)
+  }
+
+  /**
+   * Compute the average NDCG value of all the queries, truncated at ranking position k.
+   * The discounted cumulative gain at position k is computed as:
+   * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
+   * implementation, the relevance value is binary.
+
+   * If a query has an empty ground truth set, zero will be used as ndcg together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated ndcg, must be positive
+   * @return the average ndcg at the first k ranking positions
+   */
+  @Since("2.2.0")
+  def ndcgAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+
+    def ndcgAtK = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+
+      if (actualSet.nonEmpty) {
+        val labSetSize = actualSet.size
+        val n = math.min(math.max(predicted.length, labSetSize), k)
+        var maxDcg = 0.0
+        var dcg = 0.0
+        var i = 0
+        while (i < n) {
+          val gain = 1.0 / math.log(i + 2)
+          if (i < predicted.length && actualSet.contains(predicted(i))) {
+            dcg += gain
+          }
+          if (i < labSetSize) {
+            maxDcg += gain
+          }
+          i += 1
+        }
+        dcg / maxDcg
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("ndcgAtK", ndcgAtK(predictionColumn, labelColumn))
+      .agg(mean("ndcgAtK")).first().getDouble(0)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
deleted file mode 100644
index 1066548c2e4cb..0000000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MeanPercentileRankMetricsSuite.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.evaluation
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-
-class MeanPercentileRankMetricsSuite extends SparkFunSuite with MLlibTestSparkContext{
-
-  import testImplicits._
-
-  test("Mean Percentile Rank metrics : Long type") {
-    val predictionAndLabels = Seq(
-      (111216304L, Array(111216304L)),
-      (108848657L, Array.empty[Long])
-    ).toDF(Seq("label", "prediction"): _*)
-    val mpr = new MeanPercentileRankMetrics(predictionAndLabels, "prediction", "label")
-    assert(mpr.meanPercentileRank === 0.5)
-  }
-
-  test("Mean Percentile Rank metrics : String type") {
-    val predictionAndLabels = Seq(
-      ("item1", Array("item2", "item1")),
-      ("item2", Array("item1")),
-      ("item3", Array("item3", "item1", "item2"))
-    ).toDF(Seq("label", "prediction"): _*)
-    val mpr = new MeanPercentileRankMetrics(predictionAndLabels, "prediction", "label")
-    assert(mpr.meanPercentileRank === 0.5)
-  }
-
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000000..29b0f70000271
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext{ + + import testImplicits._ + + test("Ranking metrics: precision@K, MAP, NDCG") { + val predictionAndLabels = sc.parallelize( + Seq( + (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)), + (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)), + (Array(1, 2, 3, 4, 5), Array.empty[Int]) + ), 2).toDF(Seq("prediction", "label"): _*) + val eps = 1.0E-5 + + val metrics = new RankingMetrics(predictionAndLabels, "prediction", "label") + val map = metrics.meanAveragePrecision + + assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps) + assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps) + + assert(map ~== 0.355026 absTol eps) + + assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps) + assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps) + assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps) + assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps) + } + + test("Ranking metrics: Mean Percentile Rank (Long type)") { + val predictionAndLabels = Seq( + (111216304L, Array(111216304L)), + (108848657L, Array.empty[Long]) + ).toDF(Seq("label", "prediction"): _*) + val mpr = new RankingMetrics(predictionAndLabels, "prediction", "label") + assert(mpr.meanPercentileRank === 0.5) + } + + test("Ranking metrics: Mean Percentile Rank (String type)") { + val predictionAndLabels = Seq( + ("item1", Array("item2", "item1")), + ("item2", Array("item1")), + ("item3", Array("item3", "item1", "item2")) + ).toDF(Seq("label", "prediction"): _*) + val mpr = new RankingMetrics(predictionAndLabels, "prediction", "label") + assert(mpr.meanPercentileRank === 0.5) + } + +}
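
A minimal usage sketch of the evaluator after all five patches are applied, for reviewers who want to exercise the series end to end. Everything beyond the patched RankingEvaluator API is an assumption of the sketch: it presumes an active SparkSession named `spark` (e.g. in spark-shell). The data mirrors RankingEvaluatorSuite, and the comments walk through why patch 2 moves the k = 5 expectation from 0.25 to 0.375.

import org.apache.spark.ml.evaluation.RankingEvaluator

// Assumed: an active SparkSession named `spark`.
import spark.implicits._

// Same data as the "Ranking Evaluator: default params" test: two queries,
// one NaN prediction.
val predictionAndLabels = Seq(
  (1L, 1L, 0.5f),
  (1L, 2L, Float.NaN),
  (2L, 1L, 0.1f),
  (2L, 2L, 0.7f)
).toDF("query", "label", "prediction")

val evaluator = new RankingEvaluator()
  .setMetricName("mpr")
  .setQueryCol("query")
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setK(5)

// After patch 2 the NaN row is dropped before the row_number window, so the
// top-5 lists are [1] for query 1 and [2, 1] for query 2. The per-row
// percentile ranks are 0/1, 1 (label 2 is absent from query 1's list),
// 1/2 and 0/2, hence MPR = (0 + 1 + 0.5 + 0) / 4 = 0.375.
// Lower is better: isLargerBetter is false for "mpr".
println(evaluator.evaluate(predictionAndLabels))  // 0.375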
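The per-query loops inside RankingMetrics are plain Scala once the UDF wrapper is peeled away. As a sanity check on the numbers asserted in RankingMetricsSuite, here is a standalone mirror of the precision@k loop; it is an illustration only, not part of the patch, and needs no Spark session:

object PrecisionAtKCheck extends App {
  // Mirrors precisionAtK: hits within the first min(length, k) predictions,
  // divided by k; an empty ground truth set scores 0.0.
  def precisionAt(predicted: Seq[Int], actual: Set[Int], k: Int): Double =
    if (actual.isEmpty) 0.0
    else predicted.take(k).count(actual.contains).toDouble / k

  val queries = Seq(
    (Seq(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Set(1, 2, 3, 4, 5)),
    (Seq(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Set(1, 2, 3)),
    (Seq(1, 2, 3, 4, 5), Set.empty[Int])
  )

  // Per-query p@1 is 1.0, 0.0, 0.0, so the mean is 1/3, the value the suite
  // asserts for metrics.precisionAt(1).
  val meanP1 = queries.map { case (p, a) => precisionAt(p, a, 1) }.sum / queries.size
  println(meanP1)  // 0.3333...
}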