Skip to content

Commit bf7e81a

Browse files
yanboliang authored and mengxr committed
[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
JIRA: https://issues.apache.org/jira/browse/SPARK-6091

Author: Yanbo Liang <[email protected]>

Closes apache#6011 from yanboliang/spark-6091 and squashes the following commits:

bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
1 parent b13162b commit bf7e81a

File tree

2 files changed

+137
-0
lines changed

2 files changed

+137
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
2323
import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
2525
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.sql.DataFrame
2627

2728
/**
2829
* ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
3334
@Experimental
3435
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
3536

37+
/**
38+
* An auxiliary constructor taking a DataFrame.
39+
* @param predictionAndLabels a DataFrame with two double columns: prediction and label
40+
*/
41+
private[mllib] def this(predictionAndLabels: DataFrame) =
42+
this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
43+
3644
private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
3745
private lazy val labelCount: Long = labelCountByClass.values.sum
3846
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels

python/pyspark/mllib/evaluation.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,135 @@ def r2(self):
141141
return self.call("r2")
142142

143143

144+
class MulticlassMetrics(JavaModelWrapper):
    """
    Evaluator for multiclass classification.

    Wraps the JVM ``org.apache.spark.mllib.evaluation.MulticlassMetrics``
    object; all metric computations are delegated to the JVM side via
    ``self.call``.

    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
    >>> metrics = MulticlassMetrics(predictionAndLabels)
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    """

    def __init__(self, predictionAndLabels):
        """
        :param predictionAndLabels: an RDD of (prediction, label) pairs.
        """
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        # Convert the pair RDD to a two-double-column DataFrame so it can be
        # handed to the JVM MulticlassMetrics auxiliary (DataFrame) constructor.
        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model)

    def truePositiveRate(self, label):
        """
        Returns true positive rate for a given label (category).
        """
        # Cast so integer labels resolve to the JVM (Double) overload.
        return self.call("truePositiveRate", float(label))

    def falsePositiveRate(self, label):
        """
        Returns false positive rate for a given label (category).
        """
        return self.call("falsePositiveRate", float(label))

    def precision(self, label=None):
        """
        Returns precision or precision for a given label (category) if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    def recall(self, label=None):
        """
        Returns recall or recall for a given label (category) if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    def fMeasure(self, label=None, beta=None):
        """
        Returns f-measure or f-measure for a given label (category) if specified.

        :param label: optional label (category) to compute the f-measure for.
        :param beta: optional beta parameter; requires `label` to be given.
        :raise ValueError: if `beta` is specified without `label`.
        """
        if beta is None:
            if label is None:
                return self.call("fMeasure")
            else:
                # float() casts for consistency with precision()/recall(): they
                # ensure int arguments match the JVM Double method signature.
                return self.call("fMeasure", float(label))
        else:
            if label is None:
                # ValueError (a subclass of Exception) is the precise type for
                # an invalid argument combination.
                raise ValueError("If the beta parameter is specified, label can not be none")
            else:
                return self.call("fMeasure", float(label), float(beta))

    @property
    def weightedTruePositiveRate(self):
        """
        Returns weighted true positive rate.
        (equals to precision, recall and f-measure)
        """
        return self.call("weightedTruePositiveRate")

    @property
    def weightedFalsePositiveRate(self):
        """
        Returns weighted false positive rate.
        """
        return self.call("weightedFalsePositiveRate")

    @property
    def weightedRecall(self):
        """
        Returns weighted averaged recall.
        (equals to precision, recall and f-measure)
        """
        return self.call("weightedRecall")

    @property
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.
        """
        return self.call("weightedPrecision")

    def weightedFMeasure(self, beta=None):
        """
        Returns weighted averaged f-measure.

        :param beta: optional beta parameter for the f-measure.
        """
        if beta is None:
            return self.call("weightedFMeasure")
        else:
            return self.call("weightedFMeasure", float(beta))

144273
def _test():
145274
import doctest
146275
from pyspark import SparkContext

0 commit comments

Comments
 (0)