
Commit 97c96b6

Converted all clustering tests to check streaming
1 parent 1f3d933 commit 97c96b6

5 files changed: +69 additions, -58 deletions

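The change is mechanical across the suites: each one drops SparkFunSuite with MLlibTestSparkContext in favour of MLTest, and direct model.transform(...).collect() assertions become calls to the MLTest helpers, which (per the commit message) also exercise the transformer on streaming input. A minimal before/after sketch of the pattern, excerpted from the KMeansSuite change below; the helper's signature is assumed from these call sites, not from its declaration:

    // Before: batch-only assertion on the transformed DataFrame.
    val transformed = model.transform(dataset)
    val clusters =
      transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
    assert(clusters === Set(0, 1, 2, 3, 4))

    // After: MLTest helper; the Vector type parameter needs an implicit
    // Encoder[Vector], supplied by the new shared Encoders object.
    testTransformerByGlobalCheckFunc[Vector](dataset.toDF(), model,
      "features", predictionColName) { rows =>
      val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
      assert(clusters === Set(0, 1, 2, 3, 4))
    }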

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 17 additions & 18 deletions

@@ -17,14 +17,15 @@
 
 package org.apache.spark.ml.clustering
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.Dataset
 
 class BisectingKMeansSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+  extends MLTest with DefaultReadWriteTest {
+
+  import Encoders._
 
   final val k = 5
   @transient var dataset: Dataset[_] = _

@@ -63,10 +64,12 @@ class BisectingKMeansSuite
 
     // Verify fit does not fail on very sparse data
     val model = bkm.fit(sparseDataset)
-    val result = model.transform(sparseDataset)
-    val numClusters = result.select("prediction").distinct().collect().length
-    // Verify we hit the edge case
-    assert(numClusters < k && numClusters > 1)
+
+    testTransformerByGlobalCheckFunc[Vector](sparseDataset.toDF(), model, "prediction") { rows =>
+      val numClusters = rows.distinct.length
+      // Verify we hit the edge case
+      assert(numClusters < k && numClusters > 1)
+    }
   }
 
   test("setter/getter") {

@@ -100,17 +103,13 @@ class BisectingKMeansSuite
     val model = bkm.fit(dataset)
     assert(model.clusterCenters.length === k)
 
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
+    testTransformerByGlobalCheckFunc[Vector](dataset.toDF(), model,
+      "features", predictionColName) { rows =>
+      val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+      assert(clusters === Set(0, 1, 2, 3, 4))
+      assert(model.computeCost(dataset) < 0.1)
+      assert(model.hasParent)
     }
-    val clusters =
-      transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
-    assert(clusters.size === k)
-    assert(clusters === Set(0, 1, 2, 3, 4))
-    assert(model.computeCost(dataset) < 0.1)
-    assert(model.hasParent)
 
     // Check validity of model summary
     val numRows = dataset.count()
mllib/src/test/scala/org/apache/spark/ml/clustering/Encoders.scala

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+private[clustering] object Encoders {
+  implicit val vectorEncoder = ExpressionEncoder[Vector]()
+}
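The MLTest helpers take a type parameter describing the input rows and need an implicit org.apache.spark.sql.Encoder for it; ml.linalg.Vector does not get one from spark.implicits._, so the clustering suites share a single ExpressionEncoder[Vector] through this small object. A rough sketch of what the implicit enables, illustrative only and not part of the patch, assuming a SparkSession named spark is in scope:

    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import Encoders._  // implicit ExpressionEncoder[Vector] now in scope

    // Build a one-column DataFrame of vectors, then view it as a typed Dataset[Vector];
    // the .as[Vector] call is what requires the implicit Encoder.
    val df = spark.createDataFrame(Seq(Tuple1(Vectors.dense(1.0, 2.0)))).toDF("features")
    val vectors = df.select("features").as[Vector]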

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 6 additions & 11 deletions

@@ -21,16 +21,16 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.stat.distribution.MultivariateGaussian
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 
 
-class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
+class GaussianMixtureSuite extends MLTest
   with DefaultReadWriteTest {
 
   import testImplicits._
+  import Encoders._
   import GaussianMixtureSuite._
 
   final val k = 5

@@ -118,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(model.weights.length === k)
     assert(model.gaussians.length === k)
 
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName, probabilityColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
-
     // Check prediction matches the highest probability, and probabilities sum to one.
-    transformed.select(predictionColName, probabilityColName).collect().foreach {
-      case Row(pred: Int, prob: Vector) =>
+    testTransformer[Vector](dataset.toDF(), model,
+      "features", predictionColName, probabilityColName) {
+      case Row(_, pred: Int, prob: Vector) =>
         val probArray = prob.toArray
         val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
         assert(pred === predFromProb)
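Judging from the call sites in this commit, the two helpers split along per-row versus whole-output checks: testTransformer runs the check against each output Row, which suits per-row invariants such as the prediction matching the arg-max of the probability vector, while testTransformerByGlobalCheckFunc receives all collected rows at once, which is needed for global properties such as the set of distinct cluster ids that no single row can establish. The two shapes as used above and below, with signatures assumed from usage rather than from the helpers' declarations:

    // Per-row invariant (GaussianMixtureSuite above).
    testTransformer[Vector](dataset.toDF(), model,
      "features", predictionColName, probabilityColName) {
      case Row(_, pred: Int, prob: Vector) =>
        assert(pred === prob.toArray.zipWithIndex.maxBy(_._1)._2)
    }

    // Whole-output invariant (BisectingKMeansSuite above, KMeansSuite below).
    testTransformerByGlobalCheckFunc[Vector](sparseDataset.toDF(), model, "prediction") { rows =>
      val numClusters = rows.distinct.length
      assert(numClusters < k && numClusters > 1)
    }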

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 11 additions & 15 deletions

@@ -19,17 +19,17 @@ package org.apache.spark.ml.clustering
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
 private[clustering] case class TestRow(features: Vector)
 
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class KMeansSuite extends MLTest with DefaultReadWriteTest {
+
+  import Encoders._
 
   final val k = 5
   @transient var dataset: Dataset[_] = _

@@ -97,15 +97,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     val model = kmeans.fit(dataset)
     assert(model.clusterCenters.length === k)
 
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", predictionColName)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
+    testTransformerByGlobalCheckFunc[Vector](dataset.toDF(), model,
+      "features", predictionColName) { rows =>
+      val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+      assert(clusters.size === k)
+      assert(clusters === Set(0, 1, 2, 3, 4))
     }
-    val clusters =
-      transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
-    assert(clusters.size === k)
-    assert(clusters === Set(0, 1, 2, 3, 4))
+
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
 

@@ -137,9 +135,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
 
     val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
-    Seq(featuresColName, predictionColName).foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
+    assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName))
     assert(model.getFeaturesCol == featuresColName)
     assert(model.getPredictionCol == predictionColName)
   }

mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala

Lines changed: 10 additions & 14 deletions

@@ -19,12 +19,11 @@ package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 
 
 object LDASuite {

@@ -60,10 +59,12 @@ object LDASuite {
 }
 
 
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LDASuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
+  implicit val vectorEncoder = ExpressionEncoder[Vector]()
+
   val k: Int = 5
   val vocabSize: Int = 30
   @transient var dataset: Dataset[_] = _

@@ -185,16 +186,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     assert(model.topicsMatrix.numCols === k)
     assert(!model.isDistributed)
 
-    // transform()
-    val transformed = model.transform(dataset)
-    val expectedColumns = Array("features", lda.getTopicDistributionCol)
-    expectedColumns.foreach { column =>
-      assert(transformed.columns.contains(column))
-    }
-    transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
-      val topicDistribution = r.getAs[Vector](0)
-      assert(topicDistribution.size === k)
-      assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+    testTransformer[Vector](dataset.toDF(), model,
+      "features", lda.getTopicDistributionCol) {
+      case Row(_, topicDistribution: Vector) =>
+        assert(topicDistribution.size === k)
+        assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
     }
 
     // logLikelihood, logPerplexity
