Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.2.0")
def setMetricName(value: String): this.type = set(metricName, value)

/**
* param for number of bins to down-sample the curves (ROC curve, PR curve) in area
* computation. If 0, no down-sampling will occur.
* Default: 1000.
* @group expertParam
*/
@Since("3.0.0")
val numBins: IntParam = new IntParam(this, "numBins", "Number of bins to down-sample " +
"the curves (ROC curve, PR curve) in area computation. If 0, no down-sampling will occur. " +
"Must be >= 0.",
ParamValidators.gtEq(0))

/** @group expertGetParam */
@Since("3.0.0")
def getNumBins: Int = $(numBins)

/** @group expertSetParam */
@Since("3.0.0")
def setNumBins(value: Int): this.type = set(numBins, value)

setDefault(numBins -> 1000)

/** @group setParam */
@Since("1.5.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
Expand Down Expand Up @@ -94,7 +116,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
case Row(rawPrediction: Double, label: Double, weight: Double) =>
(rawPrediction, label, weight)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins))
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
case "areaUnderPR" => metrics.areaUnderPR()
Expand All @@ -104,10 +126,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
}

@Since("1.5.0")
override def isLargerBetter: Boolean = $(metricName) match {
case "areaUnderROC" => true
case "areaUnderPR" => true
}
override def isLargerBetter: Boolean = true

@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
Expand All @@ -43,13 +43,14 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
* - `"mse"`: mean squared error
* - `"r2"`: R^2^ metric
* - `"mae"`: mean absolute error
* - `"var"`: explained variance
*
* @group param
*/
@Since("1.4.0")
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams)
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae", "var"))
new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae|var)", allowedParams)
}

/** @group getParam */
Expand All @@ -60,6 +61,25 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setMetricName(value: String): this.type = set(metricName, value)

/**
* param for whether the regression is through the origin.
* Default: false.
* @group expertParam
*/
@Since("3.0.0")
val throughOrigin: BooleanParam = new BooleanParam(this, "throughOrigin",
"Whether the regression is through the origin.")

/** @group expertGetParam */
@Since("3.0.0")
def getThroughOrigin: Boolean = $(throughOrigin)

/** @group expertSetParam */
@Since("3.0.0")
def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value)

setDefault(throughOrigin -> false)

/** @group setParam */
@Since("1.4.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
Expand All @@ -86,22 +106,20 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
.rdd
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
$(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
case "r2" => metrics.r2
case "mae" => metrics.meanAbsoluteError
case "var" => metrics.explainedVariance
}
metric
}

@Since("1.4.0")
override def isLargerBetter: Boolean = $(metricName) match {
case "rmse" => false
case "mse" => false
case "r2" => true
case "mae" => false
case "r2" | "var" => true
case _ => false
}

@Since("1.5.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class RegressionEvaluatorSuite
// mae
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01)

// var
evaluator.setMetricName("var")
assert(evaluator.evaluate(predictions) ~== 63.6944519 absTol 0.01)
}

test("read/write") {
Expand Down
64 changes: 53 additions & 11 deletions python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.82...
>>> evaluator.getNumBins()
1000

.. versionadded:: 1.4.0
"""
Expand All @@ -147,17 +149,22 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
"metric name in evaluation (areaUnderROC|areaUnderPR)",
typeConverter=TypeConverters.toString)

numBins = Param(Params._dummy(), "numBins", "Number of bins to down-sample the curves "
"(ROC curve, PR curve) in area computation. If 0, no down-sampling will "
"occur. Must be >= 0.",
typeConverter=TypeConverters.toInt)

@keyword_only
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC", weightCol=None):
metricName="areaUnderROC", weightCol=None, numBins=1000):
"""
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC", weightCol=None)
metricName="areaUnderROC", weightCol=None, numBins=1000)
"""
super(BinaryClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
self._setDefault(metricName="areaUnderROC")
self._setDefault(metricName="areaUnderROC", numBins=1000)
kwargs = self._input_kwargs
self._set(**kwargs)

Expand All @@ -175,13 +182,27 @@ def getMetricName(self):
"""
return self.getOrDefault(self.metricName)

@since("3.0.0")
def setNumBins(self, value):
"""
Sets the value of :py:attr:`numBins`.
"""
return self._set(numBins=value)

@since("3.0.0")
def getNumBins(self):
"""
Gets the value of numBins or its default value.
"""
return self.getOrDefault(self.numBins)

@keyword_only
@since("1.4.0")
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
metricName="areaUnderROC", weightCol=None):
metricName="areaUnderROC", weightCol=None, numBins=1000):
"""
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC", weightCol=None)
metricName="areaUnderROC", weightCol=None, numBins=1000)
Sets params for binary classification evaluator.
"""
kwargs = self._input_kwargs
Expand Down Expand Up @@ -218,6 +239,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
>>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
>>> evaluator.evaluate(dataset)
2.740...
>>> evaluator.getThroughOrigin()
False

.. versionadded:: 1.4.0
"""
Expand All @@ -226,20 +249,25 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeigh
rmse - root mean squared error (default)
mse - mean squared error
r2 - r^2 metric
mae - mean absolute error.""",
mae - mean absolute error
var - explained variance.""",
typeConverter=TypeConverters.toString)

throughOrigin = Param(Params._dummy(), "throughOrigin",
"whether the regression is through the origin.",
typeConverter=TypeConverters.toBoolean)

@keyword_only
def __init__(self, predictionCol="prediction", labelCol="label",
metricName="rmse", weightCol=None):
metricName="rmse", weightCol=None, throughOrigin=False):
"""
__init__(self, predictionCol="prediction", labelCol="label", \
metricName="rmse", weightCol=None)
metricName="rmse", weightCol=None, throughOrigin=False)
"""
super(RegressionEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
self._setDefault(metricName="rmse")
self._setDefault(metricName="rmse", throughOrigin=False)
kwargs = self._input_kwargs
self._set(**kwargs)

Expand All @@ -257,13 +285,27 @@ def getMetricName(self):
"""
return self.getOrDefault(self.metricName)

@since("3.0.0")
def setThroughOrigin(self, value):
"""
Sets the value of :py:attr:`throughOrigin`.
"""
return self._set(throughOrigin=value)

@since("3.0.0")
def getThroughOrigin(self):
"""
Gets the value of throughOrigin or its default value.
"""
return self.getOrDefault(self.throughOrigin)

@keyword_only
@since("1.4.0")
def setParams(self, predictionCol="prediction", labelCol="label",
metricName="rmse", weightCol=None):
metricName="rmse", weightCol=None, throughOrigin=False):
"""
setParams(self, predictionCol="prediction", labelCol="label", \
metricName="rmse", weightCol=None)
metricName="rmse", weightCol=None, throughOrigin=False)
Sets params for regression evaluator.
"""
kwargs = self._input_kwargs
Expand Down