Skip to content

Commit 3f6296f

Browse files
FlytxtRnDjkbradley
authored andcommitted
[SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as param
This allows Kmeans to be initialized using an existing set of cluster centers provided as a KMeansModel object. This mode of initialization performs a single run. Author: FlytxtRnD <[email protected]> Closes apache#6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits: 94b56df [FlytxtRnD] style correction ef95ee2 [FlytxtRnD] style correction c446c58 [FlytxtRnD] documentation and numRuns warning change 06d13ef [FlytxtRnD] numRuns corrected d12336e [FlytxtRnD] numRuns variable modifications 07f8554 [FlytxtRnD] remove setRuns from setIntialModel e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 242ead1 [FlytxtRnD] corrected == to === in assert 714acb5 [FlytxtRnD] added numRuns 60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed 582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 3f5fc8e [FlytxtRnD] test case modified and one runs condition added cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 'upstream/master' into Kmeans-8018 e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria 6959861 [FlytxtRnD] Accept initial cluster centers in KMeans
1 parent 4692769 commit 3f6296f

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

docs/mllib-clustering.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
3333
a given dataset, the algorithm returns the best clustering result).
3434
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
3535
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
36+
* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.
3637

3738
**Examples**
3839

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,21 @@ class KMeans private (
156156
this
157157
}
158158

159+
// Initial cluster centers can be provided as a KMeansModel object rather than using the
160+
// random or k-means|| initializationMode
161+
private var initialModel: Option[KMeansModel] = None
162+
163+
/**
164+
* Set the initial starting point, bypassing the random initialization or k-means||
165+
* The condition model.k == this.k must be met, failure results
166+
* in an IllegalArgumentException.
167+
*/
168+
def setInitialModel(model: KMeansModel): this.type = {
169+
require(model.k == k, "mismatched cluster count")
170+
initialModel = Some(model)
171+
this
172+
}
173+
159174
/**
160175
* Train a K-means model on the given set of points; `data` should be cached for high
161176
* performance, because this is an iterative algorithm.
@@ -193,20 +208,34 @@ class KMeans private (
193208

194209
val initStartTime = System.nanoTime()
195210

196-
val centers = if (initializationMode == KMeans.RANDOM) {
197-
initRandom(data)
211+
// Only one run is allowed when initialModel is given
212+
val numRuns = if (initialModel.nonEmpty) {
213+
if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
214+
1
198215
} else {
199-
initKMeansParallel(data)
216+
runs
200217
}
201218

219+
val centers = initialModel match {
220+
case Some(kMeansCenters) => {
221+
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
222+
}
223+
case None => {
224+
if (initializationMode == KMeans.RANDOM) {
225+
initRandom(data)
226+
} else {
227+
initKMeansParallel(data)
228+
}
229+
}
230+
}
202231
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
203232
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
204233
" seconds.")
205234

206-
val active = Array.fill(runs)(true)
207-
val costs = Array.fill(runs)(0.0)
235+
val active = Array.fill(numRuns)(true)
236+
val costs = Array.fill(numRuns)(0.0)
208237

209-
var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
238+
var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
210239
var iteration = 0
211240

212241
val iterationStartTime = System.nanoTime()

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
278278
}
279279
}
280280
}
281+
282+
test("Initialize using given cluster centers") {
283+
val points = Seq(
284+
Vectors.dense(0.0, 0.0),
285+
Vectors.dense(1.0, 0.0),
286+
Vectors.dense(0.0, 1.0),
287+
Vectors.dense(1.0, 1.0)
288+
)
289+
val rdd = sc.parallelize(points, 3)
290+
// creating an initial model
291+
val initialModel = new KMeansModel(Array(points(0), points(2)))
292+
293+
val returnModel = new KMeans()
294+
.setK(2)
295+
.setMaxIterations(0)
296+
.setInitialModel(initialModel)
297+
.run(rdd)
298+
// comparing the returned model and the initial model
299+
assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
300+
assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
301+
}
302+
281303
}
282304

283305
object KMeansSuite extends SparkFunSuite {

0 commit comments

Comments
 (0)