Skip to content

Commit 60c8ce2

Browse files
committed
Ignore the runs parameter when an initial model is set; update the initialModel test suite
1 parent 582e6d9 commit 60c8ce2

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,8 @@ class KMeans private (
167167
*/
168168
def setInitialModel(model: KMeansModel): this.type = {
169169
require(model.k == k, "mismatched cluster count")
170-
require(runs == 1, "can only run once with given initial model")
170+
this.setRuns(1)
171+
logWarning("Ignoring runs; one run is allowed when initialModel is given.")
171172
initialModel = Some(model)
172173
this
173174
}

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 8 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -287,28 +287,17 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
287287
Vectors.dense(1.0, 1.0)
288288
)
289289
val rdd = sc.parallelize(points, 3)
290+
// creating an initial model
291+
val initialModel = new KMeansModel(Array(points(0), points(2)))
290292

291-
val m1 = new KMeansModel(Array(points(0), points(2)))
292-
val m2 = new KMeansModel(Array(points(1), points(3)))
293-
294-
val modelM1 = new KMeans()
295-
.setK(2)
296-
.setMaxIterations(1)
297-
.setInitialModel(m1)
298-
.run(rdd)
299-
val modelM2 = new KMeans()
293+
val returnModel = new KMeans()
300294
.setK(2)
301-
.setMaxIterations(1)
302-
.setInitialModel(m2)
295+
.setMaxIterations(0)
296+
.setInitialModel(initialModel)
303297
.run(rdd)
304-
305-
val predicts1 = modelM1.predict(rdd).collect()
306-
val predicts2 = modelM2.predict(rdd).collect()
307-
308-
assert(predicts1(0) === predicts1(1))
309-
assert(predicts1(2) === predicts1(3))
310-
assert(predicts2(0) === predicts2(1))
311-
assert(predicts2(2) === predicts2(3))
298+
// comparing the returned model and the initial model
299+
assert(returnModel.clusterCenters(0) == initialModel.clusterCenters(0))
300+
assert(returnModel.clusterCenters(1) == initialModel.clusterCenters(1))
312301
}
313302

314303
}

0 commit comments

Comments
 (0)