Skip to content

Commit e9c35d7

Browse files
committed
Remove getInitialModel and match cluster count criteria
1 parent 6959861 commit e9c35d7

File tree

1 file changed

+6
-12
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/clustering

1 file changed

+6
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,11 @@ class KMeans private (
165165
* IllegalArgumentException.
166166
*/
167167
def setInitialModel(model: KMeansModel): this.type = {
168-
if (model.k == k) {
169-
initialModel = Some(model)
170-
} else {
171-
throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
172-
}
168+
require(model.k==k, "mismatched cluster count")
169+
initialModel = Some(model)
173170
this
174171
}
175172

176-
/** Return the user supplied initial KMeansModel, if supplied */
177-
def getInitialModel: Option[KMeansModel] = initialModel
178-
179173
/**
180174
* Train a K-means model on the given set of points; `data` should be cached for high
181175
* performance, because this is an iterative algorithm.
@@ -514,10 +508,10 @@ object KMeans {
514508
* @param initialModel an existing set of cluster centers.
515509
*/
516510
def train(
517-
data: RDD[Vector],
518-
k: Int,
519-
maxIterations: Int,
520-
initialModel: KMeansModel): KMeansModel = {
511+
data: RDD[Vector],
512+
k: Int,
513+
maxIterations: Int,
514+
initialModel: KMeansModel): KMeansModel = {
521515
new KMeans().setK(k)
522516
.setMaxIterations(maxIterations)
523517
.setInitialModel(initialModel)

0 commit comments

Comments
 (0)