diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
new file mode 100644
index 0000000000000..28e1e88ea5d80
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingEvaluator.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{coalesce, col, collect_list, row_number, udf}
+import org.apache.spark.sql.types.LongType
+
+/**
+ * Evaluator for ranking.
+ */
+@Since("2.2.0")
+@Experimental
+final class RankingEvaluator @Since("2.2.0")(@Since("2.2.0") override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("rankingEval"))
+
+  @Since("2.2.0")
+  val k = new IntParam(this, "k", "Top-K cutoff", (x: Int) => x > 0)
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getK: Int = $(k)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setK(value: Int): this.type = set(k, value)
+
+  setDefault(k -> 1)
+
+  @Since("2.2.0")
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("mpr"))
+    new Param(this, "metricName", "metric name in evaluation (mpr)", allowedParams)
+  }
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /**
+   * Param for query column name.
+   * @group param
+   */
+  val queryCol: Param[String] = new Param[String](this, "queryCol", "query column name")
+
+  setDefault(queryCol, "query")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getQueryCol: String = $(queryCol)
+
+  /** @group setParam */
+  @Since("2.2.0")
+  def setQueryCol(value: String): this.type = set(queryCol, value)
+
+  setDefault(metricName -> "mpr")
+
+  @Since("2.2.0")
+  override def evaluate(dataset: Dataset[_]): Double = {
+    val schema = dataset.schema
+    SchemaUtils.checkNumericType(schema, $(predictionCol))
+    SchemaUtils.checkNumericType(schema, $(labelCol))
+    SchemaUtils.checkNumericType(schema, $(queryCol))
+
+    val w = Window.partitionBy(col($(queryCol))).orderBy(col($(predictionCol)).desc)
+
+    val topAtk: DataFrame = dataset
+      .na.drop("all", Seq($(predictionCol)))
+      .select(col($(predictionCol)), col($(labelCol)).cast(LongType), col($(queryCol)))
+      .withColumn("rn", row_number().over(w)).where(col("rn") <= $(k))
+      .drop("rn")
+      .groupBy(col($(queryCol)))
+      .agg(collect_list($(labelCol)).as("topAtk"))
+
+    val mapToEmptyArray_ = udf(() => Array.empty[Long])
+
+    val predictionAndLabels: DataFrame = dataset
+      .join(topAtk, Seq($(queryCol)), "outer")
+      .withColumn("topAtk", coalesce(col("topAtk"), mapToEmptyArray_()))
+      .select($(labelCol), "topAtk")
+
+    val metrics = new RankingMetrics(predictionAndLabels, "topAtk", $(labelCol))
+    val metric = $(metricName) match {
+      case "mpr" => metrics.meanPercentileRank
+    }
+    metric
+  }
+
+  @Since("2.2.0")
+  override def isLargerBetter: Boolean = $(metricName) match {
+    case "mpr" => false
+  }
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
+}
+
+@Since("2.2.0")
+object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator] {
+
+  @Since("2.2.0")
+  override def load(path: String): RankingEvaluator = super.load(path)
+
+}
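For illustration only (not part of the patch), a minimal sketch of how the evaluator added above might be driven; the toy data and values are invented, the column names follow the evaluator's defaults, and a SparkSession with `spark.implicits._` in scope is assumed:

```scala
import org.apache.spark.ml.evaluation.RankingEvaluator

// Invented example: one row per (query, item) pair with the model's score.
val scored = Seq(
  (1L, 10L, 0.9f),  // query 1, item 10, score 0.9
  (1L, 11L, 0.3f),
  (2L, 10L, 0.2f),
  (2L, 12L, 0.8f)
).toDF("query", "label", "prediction")

// Mean Percentile Rank over the top 2 items per query; since isLargerBetter
// is false for "mpr", smaller values indicate better rankings.
val evaluator = new RankingEvaluator()
  .setMetricName("mpr")
  .setK(2)

println(evaluator.evaluate(scored))
```

Because it only relies on the `Evaluator` contract, the same object could in principle be passed to model-selection tools such as `CrossValidator` like any other evaluator.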
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..2e721e1df3b9e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RankingMetrics.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.functions.{mean, sum}
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.types.DoubleType
+
+@Since("2.2.0")
+class RankingMetrics(
+    predictionAndObservations: DataFrame, predictionCol: String, labelCol: String)
+  extends Logging with Serializable {
+
+  /**
+   * Compute the Mean Percentile Rank (MPR) of all the queries.
+   *
+   * See the following paper for detail ("Expected percentile rank" in the paper):
+   * Hu, Y., Y. Koren, and C. Volinsky. “Collaborative Filtering for Implicit Feedback Datasets.”
+   * In 2008 Eighth IEEE International Conference on Data Mining, 263–72, 2008.
+   * doi:10.1109/ICDM.2008.22.
+   *
+   * @return the mean percentile rank
+   */
+  lazy val meanPercentileRank: Double = {
+
+    def rank = udf((predicted: Seq[Any], actual: Any) => {
+      val l_i = predicted.indexOf(actual)
+
+      if (l_i == -1) {
+        1.0
+      } else {
+        l_i.toDouble / predicted.size
+      }
+    }, DoubleType)
+
+    val R_prime = predictionAndObservations.count()
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    val rankSum: Double = predictionAndObservations
+      .withColumn("rank", rank(predictionColumn, labelColumn))
+      .agg(sum("rank")).first().getDouble(0)
+
+    rankSum / R_prime
+  }
+
+  /**
+   * Compute the average precision of all the queries, truncated at ranking position k.
+   *
+   * If for a query, the ranking algorithm returns n (n is less than k) results, the precision
+   * value will be computed as #(relevant items retrieved) / k. This formula also applies when
+   * the size of the ground truth set is less than k.
+   *
+   * If a query has an empty ground truth set, zero will be used as precision together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated precision, must be positive
+   * @return the average precision at the first k ranking positions
+   */
+  @Since("2.2.0")
+  def precisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+
+    def precisionAtK = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+      if (actualSet.nonEmpty) {
+        val n = math.min(predicted.length, k)
+        var i = 0
+        var cnt = 0
+        while (i < n) {
+          if (actualSet.contains(predicted(i))) {
+            cnt += 1
+          }
+          i += 1
+        }
+        cnt.toDouble / k
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("precisionAtK", precisionAtK(predictionColumn, labelColumn))
+      .agg(mean("precisionAtK")).first().getDouble(0)
+  }
+
+  /**
+   * Returns the mean average precision (MAP) of all the queries.
+   * If a query has an empty ground truth set, the average precision will be zero and a log
+   * warning is generated.
+   */
+  lazy val meanAveragePrecision: Double = {
+
+    def map = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+      if (actualSet.nonEmpty) {
+        var i = 0
+        var cnt = 0
+        var precSum = 0.0
+        val n = predicted.length
+        while (i < n) {
+          if (actualSet.contains(predicted(i))) {
+            cnt += 1
+            precSum += cnt.toDouble / (i + 1)
+          }
+          i += 1
+        }
+        precSum / actualSet.size
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("MAP", map(predictionColumn, labelColumn))
+      .agg(mean("MAP")).first().getDouble(0)
+  }
+
+  /**
+   * Compute the average NDCG value of all the queries, truncated at ranking position k.
+   * The discounted cumulative gain at position k is computed as:
+   * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+   * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
+   * implementation, the relevance value is binary.
+   *
+   * If a query has an empty ground truth set, zero will be used as ndcg together with
+   * a log warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated ndcg, must be positive
+   * @return the average ndcg at the first k ranking positions
+   */
+  @Since("2.2.0")
+  def ndcgAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+
+    def ndcgAtK = udf((predicted: Seq[Any], actual: Seq[Any]) => {
+      val actualSet = actual.toSet
+
+      if (actualSet.nonEmpty) {
+        val labSetSize = actualSet.size
+        val n = math.min(math.max(predicted.length, labSetSize), k)
+        var maxDcg = 0.0
+        var dcg = 0.0
+        var i = 0
+        while (i < n) {
+          val gain = 1.0 / math.log(i + 2)
+          if (i < predicted.length && actualSet.contains(predicted(i))) {
+            dcg += gain
+          }
+          if (i < labSetSize) {
+            maxDcg += gain
+          }
+          i += 1
+        }
+        dcg / maxDcg
+      } else {
+        logWarning("Empty ground truth set, check input data")
+        0.0
+      }
+    }, DoubleType)
+
+    val predictionColumn: Column = predictionAndObservations.col(predictionCol)
+    val labelColumn: Column = predictionAndObservations.col(labelCol)
+
+    predictionAndObservations
+      .withColumn("ndcgAtK", ndcgAtK(predictionColumn, labelColumn))
+      .agg(mean("ndcgAtK")).first().getDouble(0)
+  }
+}
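The metrics above take a DataFrame with one row per query or observation. Note that `meanPercentileRank` expects the label column to hold a single item, while `precisionAt`, `meanAveragePrecision` and `ndcgAt` expect it to hold the full ground-truth array. A minimal sketch of direct use, not part of the patch, with invented data and names and a SparkSession with `spark.implicits._` in scope assumed:

```scala
import org.apache.spark.ml.evaluation.RankingMetrics

// Hypothetical data: the ranked predictions for each query and the single
// held-out item that was actually consumed.
val observations = Seq(
  (Array(10L, 11L, 12L), 11L),  // found at index 1 of 3 -> percentile rank 1/3
  (Array(20L, 21L), 20L)        // found at index 0 of 2 -> percentile rank 0
).toDF("prediction", "label")

val metrics = new RankingMetrics(observations, "prediction", "label")
// meanPercentileRank = (1/3 + 0) / 2 = 1/6; lower is better, and about 0.5 is
// what a random ranking would give on average.
println(metrics.meanPercentileRank)
```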
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
new file mode 100644
index 0000000000000..2965c55b1beb3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingEvaluatorSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RankingEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("params") {
+    ParamsSuite.checkParams(new RankingEvaluator)
+  }
+
+  test("Ranking Evaluator: default params") {
+
+    val predictionAndLabels =
+      Seq(
+        (1L, 1L, 0.5f),
+        (1L, 2L, Float.NaN),
+        (2L, 1L, 0.1f),
+        (2L, 2L, 0.7f)
+      ).toDF(Seq("query", "label", "prediction"): _*)
+
+    // default = mpr, k = 1
+    val evaluator = new RankingEvaluator()
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.5 absTol 0.01)
+
+    // mpr, k = 5
+    evaluator.setMetricName("mpr").setK(5)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 0.375 absTol 0.01)
+  }
+
+  test("Ranking Evaluator: no predictions") {
+    val predictionAndLabels =
+      Seq(
+        (1L, 2L, Float.NaN)
+      ).toDF(Seq("query", "label", "prediction"): _*)
+
+    // default = mpr, k = 1
+    val evaluator = new RankingEvaluator()
+    assert(evaluator.evaluate(predictionAndLabels) == 1.0)
+  }
+}
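A quick cross-check of the expected values in the default-params test above, using the MPR definition from RankingMetrics (the NaN-scored row is dropped from the top-k ranking but still counted as an observation): with k = 1 the per-row percentile ranks are 0.0, 1.0, 1.0 and 0.0, giving (0 + 1 + 1 + 0) / 4 = 0.5; with k = 5 query 2's top list becomes [2, 1], so the ranks are 0.0, 1.0, 0.5 and 0.0, giving 1.5 / 4 = 0.375. In the no-predictions test the top-k list is empty, the single observation gets rank 1, and the metric is 1.0.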
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000000..29b0f70000271
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import testImplicits._
+
+  test("Ranking metrics: precision@K, MAP, NDCG") {
+    val predictionAndLabels = sc.parallelize(
+      Seq(
+        (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
+        (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),
+        (Array(1, 2, 3, 4, 5), Array.empty[Int])
+      ), 2).toDF(Seq("prediction", "label"): _*)
+    val eps = 1.0E-5
+
+    val metrics = new RankingMetrics(predictionAndLabels, "prediction", "label")
+    val map = metrics.meanAveragePrecision
+
+    assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
+    assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
+    assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
+    assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
+    assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
+
+    assert(map ~== 0.355026 absTol eps)
+
+    assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
+    assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
+    assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
+    assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+  }
+
+  test("Ranking metrics: Mean Percentile Rank (Long type)") {
+    val predictionAndLabels = Seq(
+      (111216304L, Array(111216304L)),
+      (108848657L, Array.empty[Long])
+    ).toDF(Seq("label", "prediction"): _*)
+    val mpr = new RankingMetrics(predictionAndLabels, "prediction", "label")
+    assert(mpr.meanPercentileRank === 0.5)
+  }
+
+  test("Ranking metrics: Mean Percentile Rank (String type)") {
+    val predictionAndLabels = Seq(
+      ("item1", Array("item2", "item1")),
+      ("item2", Array("item1")),
+      ("item3", Array("item3", "item1", "item2"))
+    ).toDF(Seq("label", "prediction"): _*)
+    val mpr = new RankingMetrics(predictionAndLabels, "prediction", "label")
+    assert(mpr.meanPercentileRank === 0.5)
+  }
+
+}
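As a cross-check of the expected values above: in the Long-type MPR test the label is found at index 0 of a one-element prediction list in the first row (rank 0) and is absent in the second (rank 1), so the mean is (0 + 1) / 2 = 0.5; in the String-type test the per-row ranks are 1/2, 1 and 0, again averaging 0.5. Likewise, precisionAt(1) in the first test is (1 + 0 + 0) / 3 = 1/3, since only the first query has a relevant item at position 1 and the query with an empty ground-truth set contributes 0.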