@@ -339,28 +339,20 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
339339 Evaluator for Clustering results, which expects two input
340340 columns: prediction and features.
341341
342- >>> from sklearn import datasets
343- >>> from pyspark.sql.types import *
344- >>> from pyspark.ml.linalg import Vectors, VectorUDT
345- >>> from pyspark.ml.evaluation import ClusteringEvaluator
346- ...
347- >>> iris = datasets.load_iris()
348- >>> iris_rows = [(Vectors.dense(x), int(iris.target[i]))
349- ... for i, x in enumerate(iris.data)]
350- >>> schema = StructType([
351- ... StructField("features", VectorUDT(), True),
352- ... StructField("cluster_id", IntegerType(), True)])
353- >>> rdd = spark.sparkContext.parallelize(iris_rows)
354- >>> dataset = spark.createDataFrame(rdd, schema)
342+ >>> from pyspark.ml.linalg import Vectors
343+ >>> scoreAndLabels = map(lambda x: (Vectors.dense(x[0]), x[1]),
344+ ... [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
345+ ... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
346+ >>> dataset = spark.createDataFrame(scoreAndLabels, ["features", "prediction"])
355347 ...
356- >>> evaluator = ClusteringEvaluator(predictionCol="cluster_id ")
348+ >>> evaluator = ClusteringEvaluator(predictionCol="prediction ")
357349 >>> evaluator.evaluate(dataset)
358- 0.656 ...
350+ 0.9079 ...
359351 >>> ce_path = temp_path + "/ce"
360352 >>> evaluator.save(ce_path)
361353 >>> evaluator2 = ClusteringEvaluator.load(ce_path)
362354 >>> str(evaluator2.getPredictionCol())
363- 'cluster_id '
355+ 'prediction '
364356
365357 .. versionadded:: 2.3.0
366358 """
@@ -378,8 +370,7 @@ def __init__(self, predictionCol="prediction", featuresCol="features",
378370 super (ClusteringEvaluator , self ).__init__ ()
379371 self ._java_obj = self ._new_java_obj (
380372 "org.apache.spark.ml.evaluation.ClusteringEvaluator" , self .uid )
381- self ._setDefault (predictionCol = "prediction" , featuresCol = "features" ,
382- metricName = "silhouette" )
373+ self ._setDefault (metricName = "silhouette" )
383374 kwargs = self ._input_kwargs
384375 self ._set (** kwargs )
385376
0 commit comments