@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.evaluation

import scala.reflect.ClassTag

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._

/**
* :: Experimental ::
* Evaluator for ranking, which expects two input columns: prediction and label.
Contributor:
I think we need a little more documentation here, perhaps mentioning the exact input types of these columns (as they're not just Double, they're Arrays, and this differs from the other evaluators).

* Both the prediction and label columns must be of ArrayType, i.e. arrays whose elements are of
* type T (unlike most other evaluators, whose input columns are Double-valued).
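*
* A minimal usage sketch (illustrative only; assumes a SparkSession `spark` in scope and the
* default column names "prediction" and "label"):
* {{{
*   import spark.implicits._
*
*   // each row pairs the ranked predictions for one query with its ground-truth relevant items
*   val df = Seq(
*     (Array(1, 6, 2, 7, 8), Array(1, 2, 3)),
*     (Array(4, 1, 5, 6, 2), Array(1, 2))
*   ).toDF("prediction", "label")
*
*   val ndcgAt3 = new RankingEvaluator[Int]()
*     .setMetricName("ndcg")
*     .setK(3)
*     .evaluate(df)
* }}}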
*/
@Since("2.0.0")
@Experimental
final class RankingEvaluator[T: ClassTag] @Since("2.0.0") (@Since("2.0.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable with Logging {

@Since("2.0.0")
def this() = this(Identifiable.randomUID("rankingEval"))

@Since("2.0.0")
final val k = new IntParam(this, "k", "Top-K cutoff", (x: Int) => x > 0)

/** @group getParam */
@Since("2.0.0")
def getK: Int = $(k)

/** @group setParam */
@Since("2.0.0")
def setK(value: Int): this.type = set(k, value)

setDefault(k -> 1)

/**
* Param for metric name in evaluation. Supports:
* - `"map"` (default): Mean Average Precision
* - `"mapk"`: Mean Average Precision@K
* - `"ndcg"`: Normalized Discounted Cumulative Gain
* - `"mrr"`: Mean Reciprocal Rank
*
* @group param
*/
@Since("2.0.0")
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("map", "mapk", "ndcg", "mrr"))
new Param(this, "metricName", "metric name in evaluation (map|mapk|ndcg|mrr)", allowedParams)
}

/** @group getParam */
@Since("2.0.0")
def getMetricName: String = $(metricName)

/** @group setParam */
@Since("2.0.0")
def setMetricName(value: String): this.type = set(metricName, value)

/** @group setParam */
@Since("2.0.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group setParam */
@Since("2.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

setDefault(metricName -> "map")

@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
val labelColName = $(labelCol)
val labelType = schema($(labelCol)).dataType
require(predictionType == labelType,
s"Prediction column $predictionColName and Label column $labelColName " +
s"must be of the same type, but Prediction column $predictionColName is $predictionType " +
s"and Label column $labelColName is $labelType")

Contributor:
these columns must in fact be ArrayType

val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(predictionType), col($(labelCol)).cast(labelType))
.rdd
.map { case Row(prediction: Seq[T], label: Seq[T]) => (prediction.toArray, label.toArray) }

val metrics = new RankingMetrics[T](predictionAndLabels)
val metric = $(metricName) match {
case "map" => metrics.meanAveragePrecision
case "ndcg" => metrics.ndcgAt($(k))
case "mapk" => metrics.precisionAt($(k))
case "mrr" => metrics.meanReciprocalRank
}
metric
}

@Since("2.0.0")
override def isLargerBetter: Boolean = $(metricName) match {
case "map" => true
case "ndcg" => true
case "mapk" => true
case "mrr" => true
}

@Since("2.0.0")
override def copy(extra: ParamMap): RankingEvaluator[T] = defaultCopy(extra)
}

@Since("2.0.0")
object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator[_]] {

@Since("2.0.0")
override def load(path: String): RankingEvaluator[_] = super.load(path)
}
@@ -155,6 +155,45 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
}.mean()
}

/**
* Compute the mean reciprocal rank (MRR) of all the queries.
*
* MRR is the average, over all queries, of the reciprocal rank: the multiplicative inverse of
* the rank of the first relevant document in the predicted ranking. It is therefore well-suited
* to applications in which only the first result matters.
*
* If a query has an empty ground truth set, zero is used as its reciprocal rank, together with
* a log warning.
*
* See the following paper for detail:
*
* Brian McFee, Gert R. G. Lanckriet. Metric Learning to Rank. ICML 2010: 775-782.
*
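* For example (an illustrative sketch, assuming a SparkContext `sc` in scope):
* {{{
*   val predictionAndLabels = sc.parallelize(Seq(
*     (Array(1, 6, 2), Array(1, 2, 3)),  // first relevant item at rank 1 => RR = 1.0
*     (Array(4, 1, 5), Array(1, 2, 3))   // first relevant item at rank 2 => RR = 0.5
*   ))
*   new RankingMetrics(predictionAndLabels).meanReciprocalRank  // (1.0 + 0.5) / 2 = 0.75
* }}}
*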
* @return the mean reciprocal rank of all the queries.
*/
lazy val meanReciprocalRank: Double = {
predictionAndLabels.map { case (pred, lab) =>
val labSet = lab.toSet

if (labSet.nonEmpty) {
var i = 0
var reciprocalRank = 0.0
while (i < pred.length && reciprocalRank == 0.0) {
if (labSet.contains(pred(i))) {
reciprocalRank = 1.0 / (i + 1)
}
i += 1
}
reciprocalRank
} else {
logWarning("Empty ground truth set, check input data")
0.0
}
}.mean()
}

}

object RankingMetrics {
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

class RankingEvaluatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

test("params") {
ParamsSuite.checkParams(new RankingEvaluator)
}

test("Ranking Evaluator: default params") {
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._

val predictionAndLabels = sqlContext.createDataFrame(sc.parallelize(
Seq(
(Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
(Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
(Array[Int](1, 2, 3, 4, 5), Array[Int]())
), 2)).toDF(Seq("prediction", "label"): _*)

// default = map, k = 1
val evaluator = new RankingEvaluator[Int]()
assert(evaluator.evaluate(predictionAndLabels) ~== 0.355026 absTol 0.01)

// mapk, k = 5
evaluator.setMetricName("mapk").setK(5)
assert(evaluator.evaluate(predictionAndLabels) ~== 0.8/3 absTol 0.01)

// ndcg, k = 5
evaluator.setMetricName("ndcg")
assert(evaluator.evaluate(predictionAndLabels) ~== 0.328788 absTol 0.01)

// mrr
evaluator.setMetricName("mrr")
assert(evaluator.evaluate(predictionAndLabels) ~== 0.5 absTol 0.01)
}
}
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._

class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

test("Ranking metrics: MAP, NDCG") {
test("Ranking metrics: MAP, NDCG, MRR, PRECK") {
val predictionAndLabels = sc.parallelize(
Seq(
(Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
@@ -34,6 +34,7 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

val metrics = new RankingMetrics(predictionAndLabels)
val map = metrics.meanAveragePrecision
val mrr = metrics.meanReciprocalRank

assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
@@ -49,6 +50,8 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)

assert(mrr ~== 0.5 absTol eps)
}

test("MAP, NDCG with few predictions (SPARK-14886)") {