Skip to content

Commit 887279c

Browse files
imatiach-msftsrowen
authored andcommitted
[SPARK-24102][ML][MLLIB][PYSPARK][FOLLOWUP] Added weight column to pyspark API for regression evaluator and metrics
## What changes were proposed in this pull request? Followup to PR #17085 This PR adds the weight column to the pyspark side, which was already added to the scala API. The PR also undoes a name change in the scala side corresponding to a change in another similar PR as noted here: #17084 (comment) ## How was this patch tested? This patch adds python tests for the changes to the pyspark API. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #24197 from imatiach-msft/ilmat/regressor-eval-python. Authored-by: Ilya Matiach <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 0e16a6f commit 887279c

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ import org.apache.spark.internal.Logging
2222
import org.apache.spark.mllib.linalg.Vectors
2323
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
2424
import org.apache.spark.rdd.RDD
25-
import org.apache.spark.sql.DataFrame
25+
import org.apache.spark.sql.{DataFrame, Row}
2626

2727
/**
2828
* Evaluator for regression.
2929
*
30-
* @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight)
30+
* @param predictionAndObservations an RDD of either (prediction, observation, weight)
3131
* or (prediction, observation) pairs
3232
* @param throughOrigin True if the regression is through the origin. For example, in linear
3333
* regression, it will be true without fitting intercept.
3434
*/
3535
@Since("1.2.0")
3636
class RegressionMetrics @Since("2.0.0") (
37-
predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
37+
predictionAndObservations: RDD[_ <: Product], throughOrigin: Boolean)
3838
extends Logging {
3939

4040
@Since("1.2.0")
@@ -47,13 +47,20 @@ class RegressionMetrics @Since("2.0.0") (
4747
* prediction and observation
4848
*/
4949
private[mllib] def this(predictionAndObservations: DataFrame) =
50-
this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
50+
this(predictionAndObservations.rdd.map {
51+
case Row(prediction: Double, label: Double, weight: Double) =>
52+
(prediction, label, weight)
53+
case Row(prediction: Double, label: Double) =>
54+
(prediction, label, 1.0)
55+
case other =>
56+
throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
57+
})
5158

5259
/**
5360
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
5461
*/
5562
private lazy val summary: MultivariateStatisticalSummary = {
56-
val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
63+
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
5764
case (prediction: Double, observation: Double, weight: Double) =>
5865
(Vectors.dense(observation, observation - prediction), weight)
5966
case (prediction: Double, observation: Double) =>
@@ -70,7 +77,7 @@ class RegressionMetrics @Since("2.0.0") (
7077
private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
7178
private lazy val SSreg = {
7279
val yMean = summary.mean(0)
73-
predAndObsWithOptWeight.map {
80+
predictionAndObservations.map {
7481
case (prediction: Double, _: Double, weight: Double) =>
7582
math.pow(prediction - yMean, 2) * weight
7683
case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)

python/pyspark/ml/evaluation.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,13 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
190190

191191

192192
@inherit_doc
193-
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
193+
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol, HasWeightCol,
194194
JavaMLReadable, JavaMLWritable):
195195
"""
196196
.. note:: Experimental
197197
198-
Evaluator for Regression, which expects two input
199-
columns: prediction and label.
198+
Evaluator for Regression, which expects input columns prediction, label
199+
and an optional weight column.
200200
201201
>>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
202202
... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
@@ -214,6 +214,13 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
214214
>>> evaluator2 = RegressionEvaluator.load(re_path)
215215
>>> str(evaluator2.getPredictionCol())
216216
'raw'
217+
>>> scoreAndLabelsAndWeight = [(-28.98343821, -27.0, 1.0), (20.21491975, 21.5, 0.8),
218+
... (-25.98418959, -22.0, 1.0), (30.69731842, 33.0, 0.6), (74.69283752, 71.0, 0.2)]
219+
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
220+
...
221+
>>> evaluator = RegressionEvaluator(predictionCol="raw", weightCol="weight")
222+
>>> evaluator.evaluate(dataset)
223+
2.740...
217224
218225
.. versionadded:: 1.4.0
219226
"""
@@ -227,10 +234,10 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
227234

228235
@keyword_only
229236
def __init__(self, predictionCol="prediction", labelCol="label",
230-
metricName="rmse"):
237+
metricName="rmse", weightCol=None):
231238
"""
232239
__init__(self, predictionCol="prediction", labelCol="label", \
233-
metricName="rmse")
240+
metricName="rmse", weightCol=None)
234241
"""
235242
super(RegressionEvaluator, self).__init__()
236243
self._java_obj = self._new_java_obj(
@@ -256,10 +263,10 @@ def getMetricName(self):
256263
@keyword_only
257264
@since("1.4.0")
258265
def setParams(self, predictionCol="prediction", labelCol="label",
259-
metricName="rmse"):
266+
metricName="rmse", weightCol=None):
260267
"""
261268
setParams(self, predictionCol="prediction", labelCol="label", \
262-
metricName="rmse")
269+
metricName="rmse", weightCol=None)
263270
Sets params for regression evaluator.
264271
"""
265272
kwargs = self._input_kwargs

python/pyspark/mllib/evaluation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ class RegressionMetrics(JavaModelWrapper):
9595
"""
9696
Evaluator for regression.
9797
98-
:param predictionAndObservations: an RDD of (prediction,
99-
observation) pairs.
98+
:param predictionAndObservations: an RDD of prediction, observation and optional weight.
10099
101100
>>> predictionAndObservations = sc.parallelize([
102101
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
@@ -111,16 +110,25 @@ class RegressionMetrics(JavaModelWrapper):
111110
0.61...
112111
>>> metrics.r2
113112
0.94...
113+
>>> predictionAndObservationsWithOptWeight = sc.parallelize([
114+
... (2.5, 3.0, 0.5), (0.0, -0.5, 1.0), (2.0, 2.0, 0.3), (8.0, 7.0, 0.9)])
115+
>>> metrics = RegressionMetrics(predictionAndObservationsWithOptWeight)
116+
>>> metrics.rootMeanSquaredError
117+
0.68...
114118
115119
.. versionadded:: 1.4.0
116120
"""
117121

118122
def __init__(self, predictionAndObservations):
119123
sc = predictionAndObservations.ctx
120124
sql_ctx = SQLContext.getOrCreate(sc)
121-
df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
125+
numCol = len(predictionAndObservations.first())
126+
schema = StructType([
122127
StructField("prediction", DoubleType(), nullable=False),
123-
StructField("observation", DoubleType(), nullable=False)]))
128+
StructField("observation", DoubleType(), nullable=False)])
129+
if numCol == 3:
130+
schema.add("weight", DoubleType(), False)
131+
df = sql_ctx.createDataFrame(predictionAndObservations, schema=schema)
124132
java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
125133
java_model = java_class(df._jdf)
126134
super(RegressionMetrics, self).__init__(java_model)

0 commit comments

Comments
 (0)