@@ -141,6 +141,135 @@ def r2(self):
141141 return self .call ("r2" )
142142
143143
class MulticlassMetrics(JavaModelWrapper):
    """
    Evaluator for multiclass classification.

    >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
    ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
    >>> metrics = MulticlassMetrics(predictionAndLabels)
    >>> metrics.falsePositiveRate(0.0)
    0.2...
    >>> metrics.precision(1.0)
    0.75...
    >>> metrics.recall(2.0)
    1.0...
    >>> metrics.fMeasure(0.0, 2.0)
    0.52...
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.66...
    >>> metrics.weightedFalsePositiveRate
    0.19...
    >>> metrics.weightedPrecision
    0.68...
    >>> metrics.weightedRecall
    0.66...
    >>> metrics.weightedFMeasure()
    0.66...
    >>> metrics.weightedFMeasure(2.0)
    0.65...
    """

    def __init__(self, predictionAndLabels):
        """
        :param predictionAndLabels: an RDD of (prediction, label) pairs.
        """
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        # Convert the pair RDD to a DataFrame so it can be handed to the
        # JVM-side MulticlassMetrics constructor.
        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model)

    def truePositiveRate(self, label):
        """
        Returns true positive rate for a given label (category).

        :param label: the label (category) to compute the rate for.
        """
        # float() keeps int labels working, consistent with precision()/recall().
        return self.call("truePositiveRate", float(label))

    def falsePositiveRate(self, label):
        """
        Returns false positive rate for a given label (category).

        :param label: the label (category) to compute the rate for.
        """
        return self.call("falsePositiveRate", float(label))

    def precision(self, label=None):
        """
        Returns precision or precision for a given label (category) if specified.

        :param label: optional label (category); when None, the overall
                      precision is returned.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    def recall(self, label=None):
        """
        Returns recall or recall for a given label (category) if specified.

        :param label: optional label (category); when None, the overall
                      recall is returned.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    def fMeasure(self, label=None, beta=None):
        """
        Returns f-measure or f-measure for a given label (category) if specified.

        :param label: optional label (category); when None, the overall
                      f-measure is returned.
        :param beta: optional beta parameter for the f-measure; requires
                     label to also be specified.
        :raise ValueError: if beta is specified without a label.
        """
        if beta is None:
            if label is None:
                return self.call("fMeasure")
            else:
                return self.call("fMeasure", float(label))
        else:
            if label is None:
                # ValueError is the idiomatic signal for a bad argument
                # combination (still catchable as Exception by old callers).
                raise ValueError("If the beta parameter is specified, label can not be none")
            else:
                return self.call("fMeasure", float(label), float(beta))

    @property
    def weightedTruePositiveRate(self):
        """
        Returns weighted true positive rate.
        (equals to precision, recall and f-measure)
        """
        return self.call("weightedTruePositiveRate")

    @property
    def weightedFalsePositiveRate(self):
        """
        Returns weighted false positive rate.
        """
        return self.call("weightedFalsePositiveRate")

    @property
    def weightedRecall(self):
        """
        Returns weighted averaged recall.
        (equals to precision, recall and f-measure)
        """
        return self.call("weightedRecall")

    @property
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.
        """
        return self.call("weightedPrecision")

    def weightedFMeasure(self, beta=None):
        """
        Returns weighted averaged f-measure.

        :param beta: optional beta parameter for the f-measure; when None,
                     the default (beta = 1.0) f-measure is returned.
        """
        if beta is None:
            return self.call("weightedFMeasure")
        else:
            return self.call("weightedFMeasure", float(beta))
272+
144273def _test ():
145274 import doctest
146275 from pyspark import SparkContext
0 commit comments