Skip to content

Commit 65c696e

Browse files
Ram Sriharsha authored and mengxr committed
[SPARK-7833] [ML] Add python wrapper for RegressionEvaluator
Author: Ram Sriharsha <[email protected]>

Closes apache#6365 from harsha2010/SPARK-7833 and squashes the following commits:

923f288 [Ram Sriharsha] cleanup
7623b7d [Ram Sriharsha] python style fix
9743f83 [Ram Sriharsha] [SPARK-7833][ml] Add python wrapper for RegressionEvaluator
1 parent ed21476 commit 65c696e

File tree

3 files changed

+69
-4
lines changed

3 files changed

+69
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,14 +31,14 @@ import org.apache.spark.sql.types.DoubleType
3131
* Evaluator for regression, which expects two input columns: prediction and label.
3232
*/
3333
@AlphaComponent
34-
class RegressionEvaluator(override val uid: String)
34+
final class RegressionEvaluator(override val uid: String)
3535
extends Evaluator with HasPredictionCol with HasLabelCol {
3636

3737
def this() = this(Identifiable.randomUID("regEval"))
3838

3939
/**
4040
* param for metric name in evaluation
41-
* @group param
41+
* @group param supports mse, rmse, r2, mae as valid metric names.
4242
*/
4343
val metricName: Param[String] = {
4444
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))

mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
3939
val dataset = sqlContext.createDataFrame(
4040
sc.parallelize(LinearDataGenerator.generateLinearInput(
4141
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
42+
4243
/**
4344
* Using the following R code to load the data, train the model and evaluate metrics.
4445
*

python/pyspark/ml/evaluation.py

Lines changed: 66 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,11 @@
1919

2020
from pyspark.ml.wrapper import JavaWrapper
2121
from pyspark.ml.param import Param, Params
22-
from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
22+
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
2323
from pyspark.ml.util import keyword_only
2424
from pyspark.mllib.common import inherit_doc
2525

26-
__all__ = ['Evaluator', 'BinaryClassificationEvaluator']
26+
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator']
2727

2828

2929
@inherit_doc
@@ -148,6 +148,70 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
148148
return self._set(**kwargs)
149149

150150

151+
@inherit_doc
152+
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
153+
"""
154+
Evaluator for Regression, which expects two input
155+
columns: prediction and label.
156+
157+
>>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
158+
... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
159+
>>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
160+
...
161+
>>> evaluator = RegressionEvaluator(predictionCol="raw")
162+
>>> evaluator.evaluate(dataset)
163+
2.842...
164+
>>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
165+
0.993...
166+
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
167+
2.649...
168+
"""
169+
# a placeholder to make it appear in the generated doc
170+
metricName = Param(Params._dummy(), "metricName",
171+
"metric name in evaluation (mse|rmse|r2|mae)")
172+
173+
@keyword_only
174+
def __init__(self, predictionCol="prediction", labelCol="label",
175+
metricName="rmse"):
176+
"""
177+
__init__(self, predictionCol="prediction", labelCol="label", \
178+
metricName="rmse")
179+
"""
180+
super(RegressionEvaluator, self).__init__()
181+
self._java_obj = self._new_java_obj(
182+
"org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
183+
#: param for metric name in evaluation (mse|rmse|r2|mae)
184+
self.metricName = Param(self, "metricName",
185+
"metric name in evaluation (mse|rmse|r2|mae)")
186+
self._setDefault(predictionCol="prediction", labelCol="label",
187+
metricName="rmse")
188+
kwargs = self.__init__._input_kwargs
189+
self._set(**kwargs)
190+
191+
def setMetricName(self, value):
192+
"""
193+
Sets the value of :py:attr:`metricName`.
194+
"""
195+
self._paramMap[self.metricName] = value
196+
return self
197+
198+
def getMetricName(self):
199+
"""
200+
Gets the value of metricName or its default value.
201+
"""
202+
return self.getOrDefault(self.metricName)
203+
204+
@keyword_only
205+
def setParams(self, predictionCol="prediction", labelCol="label",
206+
metricName="rmse"):
207+
"""
208+
setParams(self, predictionCol="prediction", labelCol="label",
209+
metricName="rmse")
210+
Sets params for regression evaluator.
211+
"""
212+
kwargs = self.setParams._input_kwargs
213+
return self._set(**kwargs)
214+
151215
if __name__ == "__main__":
152216
import doctest
153217
from pyspark.context import SparkContext

0 commit comments

Comments
 (0)