Skip to content

Commit a980e6b

Browse files
committed
Address review comments
1 parent 7e8bcc7 commit a980e6b

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

python/pyspark/ml/evaluation.py

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -339,28 +339,20 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
339 339
Evaluator for Clustering results, which expects two input
340 340
columns: prediction and features.
341 341
342-
>>> from sklearn import datasets
343-
>>> from pyspark.sql.types import *
344-
>>> from pyspark.ml.linalg import Vectors, VectorUDT
345-
>>> from pyspark.ml.evaluation import ClusteringEvaluator
346-
...
347-
>>> iris = datasets.load_iris()
348-
>>> iris_rows = [(Vectors.dense(x), int(iris.target[i]))
349-
... for i, x in enumerate(iris.data)]
350-
>>> schema = StructType([
351-
... StructField("features", VectorUDT(), True),
352-
... StructField("cluster_id", IntegerType(), True)])
353-
>>> rdd = spark.sparkContext.parallelize(iris_rows)
354-
>>> dataset = spark.createDataFrame(rdd, schema)
342+
>>> from pyspark.ml.linalg import Vectors
343+
>>> scoreAndLabels = map(lambda x: (Vectors.dense(x[0]), x[1]),
344+
... [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
345+
... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
346+
>>> dataset = spark.createDataFrame(scoreAndLabels, ["features", "prediction"])
355 347
...
356-
>>> evaluator = ClusteringEvaluator(predictionCol="cluster_id")
348+
>>> evaluator = ClusteringEvaluator(predictionCol="prediction")
357 349
>>> evaluator.evaluate(dataset)
358-
0.656...
350+
0.9079...
359 351
>>> ce_path = temp_path + "/ce"
360 352
>>> evaluator.save(ce_path)
361 353
>>> evaluator2 = ClusteringEvaluator.load(ce_path)
362 354
>>> str(evaluator2.getPredictionCol())
363-
'cluster_id'
355+
'prediction'
364 356
365 357
.. versionadded:: 2.3.0
366 358
"""
@@ -378,8 +370,7 @@ def __init__(self, predictionCol="prediction", featuresCol="features",
378 370
super(ClusteringEvaluator, self).__init__()
379 371
self._java_obj = self._new_java_obj(
380 372
"org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
381-
self._setDefault(predictionCol="prediction", featuresCol="features",
382-
metricName="silhouette")
373+
self._setDefault(metricName="silhouette")
383 374
kwargs = self._input_kwargs
384 375
self._set(**kwargs)
385 376

0 commit comments

Comments
 (0)