@@ -27,9 +27,9 @@ class BinaryClassificationMetrics(JavaModelWrapper):
2727 >>> scoreAndLabels = sc.parallelize([
2828 ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
2929 >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
30- >>> metrics.areaUnderROC()
30+ >>> metrics.areaUnderROC
3131 0.70...
32- >>> metrics.areaUnderPR()
32+ >>> metrics.areaUnderPR
3333 0.83...
3434 >>> metrics.unpersist()
3535 """
@@ -47,13 +47,15 @@ def __init__(self, scoreAndLabels):
4747 java_model = java_class (df ._jdf )
4848 super (BinaryClassificationMetrics , self ).__init__ (java_model )
4949
50+ @property
5051 def areaUnderROC (self ):
5152 """
5253 Computes the area under the receiver operating characteristic
5354 (ROC) curve.
5455 """
5556 return self .call ("areaUnderROC" )
5657
58+ @property
5759 def areaUnderPR (self ):
5860 """
5961 Computes the area under the precision-recall curve.
@@ -67,6 +69,78 @@ def unpersist(self):
6769 self .call ("unpersist" )
6870
6971
72+ class RegressionMetrics (JavaModelWrapper ):
73+ """
74+ Evaluator for regression.
75+
76+ >>> predictionAndObservations = sc.parallelize([
77+ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
78+ >>> metrics = RegressionMetrics(predictionAndObservations)
79+ >>> metrics.explainedVariance
80+ 0.95...
81+ >>> metrics.meanAbsoluteError
82+ 0.5...
83+ >>> metrics.meanSquaredError
84+ 0.37...
85+ >>> metrics.rootMeanSquaredError
86+ 0.61...
87+ >>> metrics.r2
88+ 0.94...
89+ """
90+
91+ def __init__ (self , predictionAndObservations ):
92+ """
93+ :param predictionAndObservations: an RDD of (prediction, observation) pairs.
94+ """
95+ sc = predictionAndObservations .ctx
96+ sql_ctx = SQLContext (sc )
97+ df = sql_ctx .createDataFrame (predictionAndObservations , schema = StructType ([
98+ StructField ("prediction" , DoubleType (), nullable = False ),
99+ StructField ("observation" , DoubleType (), nullable = False )]))
100+ java_class = sc ._jvm .org .apache .spark .mllib .evaluation .RegressionMetrics
101+ java_model = java_class (df ._jdf )
102+ super (RegressionMetrics , self ).__init__ (java_model )
103+
104+ @property
105+ def explainedVariance (self ):
106+ """
107+ Returns the explained variance regression score.
108+ explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
109+ """
110+ return self .call ("explainedVariance" )
111+
112+ @property
113+ def meanAbsoluteError (self ):
114+ """
115+ Returns the mean absolute error, which is a risk function corresponding to the
116+ expected value of the absolute error loss or l1-norm loss.
117+ """
118+ return self .call ("meanAbsoluteError" )
119+
120+ @property
121+ def meanSquaredError (self ):
122+ """
123+ Returns the mean squared error, which is a risk function corresponding to the
124+ expected value of the squared error loss or quadratic loss.
125+ """
126+ return self .call ("meanSquaredError" )
127+
128+ @property
129+ def rootMeanSquaredError (self ):
130+ """
131+ Returns the root mean squared error, which is defined as the square root of
132+ the mean squared error.
133+ """
134+ return self .call ("rootMeanSquaredError" )
135+
136+ @property
137+ def r2 (self ):
138+ """
139+ Returns R^2^, the coefficient of determination.
140+ """
141+ return self .call ("r2" )
142+
143+
70144def _test ():
71145 import doctest
72146 from pyspark import SparkContext
0 commit comments