Skip to content

Commit a571adc

Browse files
committed
updated python code
1 parent 079e114 commit a571adc

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

python/pyspark/ml/evaluation.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def isLargerBetter(self):
106106

107107

108108
@inherit_doc
109-
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
109+
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
110110
JavaMLReadable, JavaMLWritable):
111111
"""
112112
.. note:: Experimental
@@ -130,6 +130,14 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
130130
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
131131
>>> str(evaluator2.getRawPredictionCol())
132132
'raw'
133+
>>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
134+
... [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
135+
... (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
136+
>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
137+
...
138+
>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
139+
>>> evaluator.evaluate(dataset)
140+
0.70...
133141
134142
.. versionadded:: 1.4.0
135143
"""
@@ -140,10 +148,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
140148

141149
@keyword_only
142150
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
143-
metricName="areaUnderROC"):
151+
metricName="areaUnderROC", weightCol=None):
144152
"""
145153
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
146-
metricName="areaUnderROC")
154+
metricName="areaUnderROC", weightCol=None)
147155
"""
148156
super(BinaryClassificationEvaluator, self).__init__()
149157
self._java_obj = self._new_java_obj(
@@ -169,10 +177,10 @@ def getMetricName(self):
169177
@keyword_only
170178
@since("1.4.0")
171179
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
172-
metricName="areaUnderROC"):
180+
metricName="areaUnderROC", weightCol=None):
173181
"""
174182
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
175-
metricName="areaUnderROC")
183+
metricName="areaUnderROC", weightCol=None)
176184
Sets params for binary classification evaluator.
177185
"""
178186
kwargs = self._input_kwargs

python/pyspark/mllib/evaluation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
3030
"""
3131
Evaluator for binary classification.
3232
33-
:param scoreAndLabels: an RDD of (score, label) pairs
33+
:param scoreAndLabelsWithOptWeight: an RDD of score, label and optional weight.
3434
3535
>>> scoreAndLabels = sc.parallelize([
3636
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
@@ -40,16 +40,28 @@ class BinaryClassificationMetrics(JavaModelWrapper):
4040
>>> metrics.areaUnderPR
4141
0.83...
4242
>>> metrics.unpersist()
43+
>>> scoreAndLabelsWithOptWeight = sc.parallelize([
44+
... (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
45+
... (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
46+
>>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
47+
>>> metrics.areaUnderROC
48+
0.70...
49+
>>> metrics.areaUnderPR
50+
0.83...
4351
4452
.. versionadded:: 1.4.0
4553
"""
4654

47-
def __init__(self, scoreAndLabels):
48-
sc = scoreAndLabels.ctx
55+
def __init__(self, scoreAndLabelsWithOptWeight):
56+
sc = scoreAndLabelsWithOptWeight.ctx
4957
sql_ctx = SQLContext.getOrCreate(sc)
50-
df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
58+
numCol = len(scoreAndLabelsWithOptWeight.first())
59+
schema = StructType([
5160
StructField("score", DoubleType(), nullable=False),
52-
StructField("label", DoubleType(), nullable=False)]))
61+
StructField("label", DoubleType(), nullable=False)])
62+
if (numCol == 3):
63+
schema.add("weight", DoubleType(), False)
64+
df = sql_ctx.createDataFrame(scoreAndLabelsWithOptWeight, schema=schema)
5365
java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
5466
java_model = java_class(df._jdf)
5567
super(BinaryClassificationMetrics, self).__init__(java_model)

0 commit comments

Comments
 (0)