
Commit 1712a7c

yanboliang authored and mengxr committed

[SPARK-6093] [MLLIB] Add RegressionMetrics in PySpark/MLlib

https://issues.apache.org/jira/browse/SPARK-6093

Author: Yanbo Liang <[email protected]>

Closes #5941 from yanboliang/spark-6093 and squashes the following commits:

6934af3 [Yanbo Liang] change to @property
aac3bc5 [Yanbo Liang] Add RegressionMetrics in PySpark/MLlib
1 parent 068c315 commit 1712a7c

File tree

2 files changed: +85 -2 lines changed


mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 9 additions & 0 deletions

@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
+import org.apache.spark.sql.DataFrame

 /**
  * :: Experimental ::
@@ -32,6 +33,14 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 @Experimental
 class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {

+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndObservations a DataFrame with two double columns:
+   *                                  prediction and observation
+   */
+  private[mllib] def this(predictionAndObservations: DataFrame) =
+    this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1))))
+
   /**
    * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
    */
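The auxiliary constructor above is the bridge the new PySpark wrapper relies on: a Python RDD of tuples cannot be handed to the JVM directly over Py4J, but a two-column DataFrame can, via its underlying Java object. A minimal sketch of that round trip, assuming a live SparkContext sc (this simply mirrors the __init__ added to evaluation.py below):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, DoubleType

# Put (prediction, observation) pairs into a DataFrame with an explicit
# schema so the JVM side sees exactly two double columns.
sql_ctx = SQLContext(sc)
rdd = sc.parallelize([(2.5, 3.0), (0.0, -0.5)])
df = sql_ctx.createDataFrame(rdd, schema=StructType([
    StructField("prediction", DoubleType(), nullable=False),
    StructField("observation", DoubleType(), nullable=False)]))

# df._jdf is the wrapped Java DataFrame; passing it selects the
# auxiliary constructor defined in the hunk above.
java_model = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics(df._jdf)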

python/pyspark/mllib/evaluation.py

Lines changed: 76 additions & 2 deletions
@@ -27,9 +27,9 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     >>> scoreAndLabels = sc.parallelize([
     ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
     >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
-    >>> metrics.areaUnderROC()
+    >>> metrics.areaUnderROC
     0.70...
-    >>> metrics.areaUnderPR()
+    >>> metrics.areaUnderPR
     0.83...
     >>> metrics.unpersist()
     """
@@ -47,13 +47,15 @@ def __init__(self, scoreAndLabels):
         java_model = java_class(df._jdf)
         super(BinaryClassificationMetrics, self).__init__(java_model)

+    @property
     def areaUnderROC(self):
         """
         Computes the area under the receiver operating characteristic
         (ROC) curve.
         """
         return self.call("areaUnderROC")

+    @property
     def areaUnderPR(self):
         """
         Computes the area under the precision-recall curve.
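With the @property decorators added above, areaUnderROC and areaUnderPR become read-only attributes rather than methods, which is what the doctest change in the first hunk reflects:

>>> metrics = BinaryClassificationMetrics(scoreAndLabels)
>>> metrics.areaUnderROC    # attribute access; no parentheses
0.70...

Note this is a breaking change for existing callers: the old form metrics.areaUnderROC() now raises TypeError, since it attempts to call the float returned by the property.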
@@ -67,6 +69,78 @@ def unpersist(self):
         self.call("unpersist")


+class RegressionMetrics(JavaModelWrapper):
+    """
+    Evaluator for regression.
+
+    >>> predictionAndObservations = sc.parallelize([
+    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
+    >>> metrics = RegressionMetrics(predictionAndObservations)
+    >>> metrics.explainedVariance
+    0.95...
+    >>> metrics.meanAbsoluteError
+    0.5...
+    >>> metrics.meanSquaredError
+    0.37...
+    >>> metrics.rootMeanSquaredError
+    0.61...
+    >>> metrics.r2
+    0.94...
+    """
+
+    def __init__(self, predictionAndObservations):
+        """
+        :param predictionAndObservations: an RDD of (prediction, observation) pairs.
+        """
+        sc = predictionAndObservations.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
+            StructField("prediction", DoubleType(), nullable=False),
+            StructField("observation", DoubleType(), nullable=False)]))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
+        java_model = java_class(df._jdf)
+        super(RegressionMetrics, self).__init__(java_model)
+
+    @property
+    def explainedVariance(self):
+        """
+        Returns the explained variance regression score.
+        explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+        """
+        return self.call("explainedVariance")
+
+    @property
+    def meanAbsoluteError(self):
+        """
+        Returns the mean absolute error, which is a risk function corresponding to the
+        expected value of the absolute error loss or l1-norm loss.
+        """
+        return self.call("meanAbsoluteError")
+
+    @property
+    def meanSquaredError(self):
+        """
+        Returns the mean squared error, which is a risk function corresponding to the
+        expected value of the squared error loss or quadratic loss.
+        """
+        return self.call("meanSquaredError")
+
+    @property
+    def rootMeanSquaredError(self):
+        """
+        Returns the root mean squared error, which is defined as the square root of
+        the mean squared error.
+        """
+        return self.call("rootMeanSquaredError")
+
+    @property
+    def r2(self):
+        """
+        Returns R^2^, the coefficient of determination.
+        """
+        return self.call("r2")
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
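As a sanity check, the doctest values above follow directly from the formulas named in the docstrings; a standalone sketch in plain Python (no Spark required) for the four-point example:

import math

preds = [2.5, 0.0, 2.0, 8.0]
obs = [3.0, -0.5, 2.0, 7.0]
n = len(obs)

errors = [p - o for p, o in zip(preds, obs)]

mae = sum(abs(e) for e in errors) / n    # 0.5
mse = sum(e * e for e in errors) / n     # 0.375
rmse = math.sqrt(mse)                    # 0.612...

# r2 = 1 - SS_res / SS_tot
mean_obs = sum(obs) / n
ss_tot = sum((o - mean_obs) ** 2 for o in obs)
r2 = 1.0 - sum(e * e for e in errors) / ss_tot    # 0.948...

# explainedVariance = 1 - variance(y - y_hat) / variance(y)
def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

residuals = [o - p for p, o in zip(preds, obs)]          # y - y_hat
explained = 1.0 - variance(residuals) / variance(obs)    # 0.957...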
