Commit ca8243f

shahidki31 authored and srowen committed
[MINOR][ML] Minor correction in the powerIterationSuite
## What changes were proposed in this pull request?

Currently, the power iteration clustering test in Spark ML maps the results to the labels 0 and 1 for its assertion. Because the clustering output need not match the mapped labels, the test can fail. Even when the mapping happens to be correct, there is no theoretical guarantee of which set receives which cluster label: KMeans can assign label 0 to either set.

PowerIterationClusteringSuite in MLlib checks the clustering results without mapping them to a particular cluster label, as shown below.

``
val predictions = Array.fill(2)(mutable.Set.empty[Long])
model.assignments.collect().foreach { a =>
  predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet))
``

## How was this patch tested?

Existing tests

Author: Shahid <[email protected]>

Closes #21689 from shahidki31/picTestSuiteMinorCorrection.
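To illustrate why the set-of-sets comparison is insensitive to label order, here is a minimal sketch in plain Scala with no Spark dependency. The object name `LabelInvariantDemo`, the helper `groupByCluster`, and the sample data are hypothetical and are not part of this commit; they only mimic the shape of the collected `(id, cluster)` rows.

```scala
import scala.collection.mutable

object LabelInvariantDemo {
  // Hypothetical helper: bucket ids by cluster index, then discard the index.
  def groupByCluster(assignments: Seq[(Long, Int)], k: Int): Set[Set[Long]] = {
    val predictions = Array.fill(k)(mutable.Set.empty[Long])
    assignments.foreach { case (id, cluster) => predictions(cluster) += id }
    // A Set of immutable Sets carries no ordering, so cluster labels drop out.
    predictions.map(_.toSet).toSet
  }

  // Made-up sample data: the same partition with the two labels swapped.
  val run1: Seq[(Long, Int)] = Seq((0L, 0), (1L, 0), (2L, 1), (3L, 1))
  val run2: Seq[(Long, Int)] = Seq((0L, 1), (1L, 1), (2L, 0), (3L, 0))

  def main(args: Array[String]): Unit = {
    // Comparing the partitions ignores which label each group received.
    println(groupByCluster(run1, 2) == groupByCluster(run2, 2))
    // Comparing raw (id, cluster) pairs does not.
    println(run1.toSet == run2.toSet)
  }
}
```

The old assertion corresponds to the second comparison, which is why it could fail on a correct clustering whose labels came out swapped.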
1 parent 1a2655a commit ca8243f

1 file changed, +20 -10 lines changed


mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala

Lines changed: 20 additions & 10 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.collection.mutable
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -76,23 +78,31 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setMaxIter(40)
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments = assignments
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
-      (n1 until n).map(x => (x, 0)).toSet
-    assert(localAssignments === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions = Array.fill(2)(mutable.Set.empty[Long])
+    assignments.foreach {
+      case (id, cluster) => predictions(cluster) += id
+    }
+    assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
 
     val assignments2 = new PowerIterationClustering()
       .setK(2)
       .setMaxIter(10)
       .setInitMode("degree")
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments2 = assignments2
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    assert(localAssignments2 === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+    assignments2.foreach {
+      case (id, cluster) => predictions2(cluster) += id
+    }
+    assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
   }
 
   test("supported input types") {
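The diff repeats the same bucketing logic for both runs. One possible way to factor it out is sketched below in plain Scala. The names `ClusterPartition`, `partitionOf`, and `expectedPartition` are hypothetical and do not appear in the commit or in Spark; the sketch assumes the assignments have already been collected into `(id, cluster)` pairs.

```scala
object ClusterPartition {
  // Hypothetical helper: turn collected (id, cluster) pairs into a
  // label-free partition by grouping on the cluster index and dropping it.
  def partitionOf(assignments: Seq[(Long, Int)]): Set[Set[Long]] =
    assignments
      .groupBy { case (_, cluster) => cluster }
      .values
      .map(rows => rows.map { case (id, _) => id }.toSet)
      .toSet

  // Hypothetical helper: the expected partition for two groups of ids,
  // [0, n1) and [n1, n), mirroring the expression used in the test.
  def expectedPartition(n1: Int, n: Int): Set[Set[Long]] =
    Set(
      (0 until n1).map(_.toLong).toSet,
      (n1 until n).map(_.toLong).toSet
    )
}
```

With such a helper, each run would reduce to a single comparison of `partitionOf(...)` against `expectedPartition(n1, n)`. One difference from the committed code is worth noting: `groupBy` omits clusters that received no ids, whereas `Array.fill(2)` keeps an empty set for them, so the two approaches differ when a cluster is empty.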
