Skip to content

Commit 3f5fc8e

Browse files
committed
Modify the test case and add a check that only one run is allowed with an initial model
1 parent cd5dc5c commit 3f5fc8e

File tree

2 files changed

+31
-43
lines changed

2 files changed

+31
-43
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,14 @@ class KMeans private (
160160
// random or k-means|| initializationMode
161161
private var initialModel: Option[KMeansModel] = None
162162

163-
/** Set the initial starting point, bypassing the random initialization or k-means||
164-
* The condition (model.k == this.k) must be met; failure will result in an
165-
* IllegalArgumentException.
166-
*/
163+
/**
164+
* Set the initial starting point, bypassing the random initialization or k-means||.
165+
* The condition model.k == this.k must be met, and only one run is allowed;
166+
* failure in either case will result in an IllegalArgumentException.
167+
*/
167168
def setInitialModel(model: KMeansModel): this.type = {
168-
require(model.k==k, "mismatched cluster count")
169+
require(model.k == k, "mismatched cluster count")
170+
require(runs == 1, "can only run once with given initial model")
169171
initialModel = Some(model)
170172
this
171173
}
@@ -499,25 +501,6 @@ object KMeans {
499501
train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
500502
}
501503

502-
/**
503-
* Trains a k-means model using the given set of parameters and initial cluster centers
504-
*
505-
* @param data training points stored as `RDD[Vector]`
506-
* @param k number of clusters
507-
* @param maxIterations max number of iterations
508-
* @param initialModel an existing set of cluster centers.
509-
*/
510-
def train(
511-
data: RDD[Vector],
512-
k: Int,
513-
maxIterations: Int,
514-
initialModel: KMeansModel): KMeansModel = {
515-
new KMeans().setK(k)
516-
.setMaxIterations(maxIterations)
517-
.setInitialModel(initialModel)
518-
.run(data)
519-
}
520-
521504
/**
522505
* Returns the index of the closest center to the given point, as well as the squared distance.
523506
*/

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -282,28 +282,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
282282
test("Initialize using given cluster centers") {
283283
val points = Seq(
284284
Vectors.dense(0.0, 0.0),
285-
Vectors.dense(0.0, 0.1),
286-
Vectors.dense(0.1, 0.0),
287-
Vectors.dense(9.0, 0.0),
288-
Vectors.dense(9.0, 0.2),
289-
Vectors.dense(9.2, 0.0)
285+
Vectors.dense(1.0, 0.0),
286+
Vectors.dense(0.0, 1.0),
287+
Vectors.dense(1.0, 1.0)
290288
)
291289
val rdd = sc.parallelize(points, 3)
292-
val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1)
293-
294-
val tempDir = Utils.createTempDir()
295-
val path = tempDir.toURI.toString
296-
model.save(sc, path)
297-
val loadedModel = KMeansModel.load(sc, path)
298-
299-
val newModel = KMeans.train(rdd, k = 2, maxIterations = 2, initialModel = loadedModel)
300-
val predicts = newModel.predict(rdd).collect()
301290

302-
assert(predicts(0) === predicts(1))
303-
assert(predicts(0) === predicts(2))
304-
assert(predicts(3) === predicts(4))
305-
assert(predicts(3) === predicts(5))
306-
assert(predicts(0) != predicts(3))
291+
val m1 = new KMeansModel(Array(points(0), points(2)))
292+
val m2 = new KMeansModel(Array(points(1), points(3)))
293+
294+
val modelM1 = new KMeans()
295+
.setK(2)
296+
.setMaxIterations(1)
297+
.setInitialModel(m1)
298+
.run(rdd)
299+
val modelM2 = new KMeans()
300+
.setK(2)
301+
.setMaxIterations(1)
302+
.setInitialModel(m2)
303+
.run(rdd)
304+
305+
val predicts1 = modelM1.predict(rdd).collect()
306+
val predicts2 = modelM2.predict(rdd).collect()
307+
308+
assert(predicts1(0) === predicts1(1))
309+
assert(predicts1(2) === predicts1(3))
310+
assert(predicts2(0) === predicts2(1))
311+
assert(predicts2(2) === predicts2(3))
307312
}
308313

309314
}

0 commit comments

Comments
 (0)