From c5eadb6b6113352bc3741eb1be3ae8c730b6ba8e Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 9 May 2018 08:18:50 +0530 Subject: [PATCH 1/2] Example code for Power Iteration Clustering --- .../clustering/PowerIterationClustering.scala | 12 ++------ .../PowerIterationClusteringSuite.scala | 28 +++++++++++++++++-- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa94..9077a5e82d1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -226,17 +226,9 @@ class PowerIterationClustering private[clustering] ( val predictionsSchema = StructType(Seq( StructField($(idCol), LongType, nullable = false), StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } + val predictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.join(predictions, $(idCol)) + predictions } @Since("2.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baf..66eeee710b78 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -84,7 +84,7 @@ class PowerIterationClusteringSuite extends SparkFunSuite result.select("id", "prediction").collect().foreach { case Row(id: Long, cluster: Integer) => predictions(cluster) += id } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val result2 = new PowerIterationClustering() .setK(2) @@ -95,7 +95,31 @@ class PowerIterationClusteringSuite extends SparkFunSuite result2.select("id", "prediction").collect().foreach { case Row(id: Long, cluster: Integer) => predictions2(cluster) += id } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) + } + + test("power iteration clustering: random init mode") { + + val data = spark.createDataFrame(Seq( + (0, Array(1), Array(0.9)), + (1, Array(2), Array(0.9)), + (2, Array(3), Array(0.9)), + (3, Array(4), Array(0.1)), + (4, Array(5), Array(0.9)) + )).toDF("id", "neighbors", "similarities") + + val result = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("random") + .transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((0 until 4).toSet, Set(4, 5))) + assert(result.columns(1).equals("prediction")) } test("supported input types") { From 213ca9fc6be8f93db3e979346125ab516e5ae42e Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 9 May 2018 08:24:37 +0530 Subject: [PATCH 2/2] Example code for Power Iteration Clustering --- .../spark/ml/clustering/PowerIterationClusteringSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 66eeee710b78..ad07cccaab46 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -119,7 +119,6 @@ class PowerIterationClusteringSuite extends SparkFunSuite case Row(id: Long, cluster: Integer) => predictions(cluster) += id } assert(predictions.toSet == Set((0 until 4).toSet, Set(4, 5))) - assert(result.columns(1).equals("prediction")) } test("supported input types") {