|
19 | 19 |
|
20 | 20 | from pyspark.ml.wrapper import JavaWrapper |
21 | 21 | from pyspark.ml.param import Param, Params |
22 | | -from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol |
| 22 | +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol |
23 | 23 | from pyspark.ml.util import keyword_only |
24 | 24 | from pyspark.mllib.common import inherit_doc |
25 | 25 |
|
26 | | -__all__ = ['Evaluator', 'BinaryClassificationEvaluator'] |
| 26 | +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] |
27 | 27 |
|
28 | 28 |
|
29 | 29 | @inherit_doc |
@@ -148,6 +148,70 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", |
148 | 148 | return self._set(**kwargs) |
149 | 149 |
|
150 | 150 |
|
@inherit_doc
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
    """
    Evaluator for Regression, which expects two input
    columns: prediction and label.

    The metric that is reported is chosen through :py:attr:`metricName`
    (one of ``mse``, ``rmse``, ``r2``, ``mae``); ``"rmse"`` is the default.

    >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
    ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
    ...
    >>> evaluator = RegressionEvaluator(predictionCol="raw")
    >>> evaluator.evaluate(dataset)
    2.842...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
    0.993...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
    2.649...
    """

    # Class-level stand-in so that ``metricName`` shows up in the generated
    # documentation; every instance replaces it with its own Param in __init__.
    metricName = Param(
        Params._dummy(),
        "metricName",
        "metric name in evaluation (mse|rmse|r2|mae)")

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="rmse"):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="rmse")

        Builds ``_java_obj`` through ``_new_java_obj``, attaches the
        instance-level :py:attr:`metricName` param, registers the defaults
        and finally applies the keyword arguments recorded in
        ``_input_kwargs``.
        """
        super(RegressionEvaluator, self).__init__()

        java_class = "org.apache.spark.ml.evaluation.RegressionEvaluator"
        self._java_obj = self._new_java_obj(java_class, self.uid)

        #: param for metric name in evaluation (mse|rmse|r2|mae)
        self.metricName = Param(
            self,
            "metricName",
            "metric name in evaluation (mse|rmse|r2|mae)")

        # Defaults are registered before the caller-supplied values are set.
        defaults = dict(predictionCol="prediction", labelCol="label",
                        metricName="rmse")
        self._setDefault(**defaults)
        self._set(**self.__init__._input_kwargs)

    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.

        Returns this evaluator so that calls can be chained.
        """
        metric_param = self.metricName
        self._paramMap[metric_param] = value
        return self

    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        metric_param = self.metricName
        return self.getOrDefault(metric_param)

    @keyword_only
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="rmse"):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="rmse")
        Sets params for regression evaluator.
        """
        return self._set(**self.setParams._input_kwargs)
| 214 | + |
151 | 215 | if __name__ == "__main__": |
152 | 216 | import doctest |
153 | 217 | from pyspark.context import SparkContext |
|
0 commit comments