Skip to content

Commit ddd8d86

Browse files
committed
Support initial model for KMeans.
1 parent a350bc1 commit ddd8d86

File tree

6 files changed

+210
-22
lines changed

6 files changed

+210
-22
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
5959
/**
6060
* Param for the initialization algorithm. This can be either "random" to choose random points as
6161
* initial cluster centers, or "k-means||" to use a parallel variant of k-means++
62-
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
62+
* (Bahmani et al., Scalable K-Means++, VLDB 2012), or "initialModel" to use a user provided
63+
* initial model for warm start. Default: k-means||.
64+
* If this is set to "initialModel", users must specify the initial model via `setInitialModel`,
65+
* otherwise an IllegalArgumentException is thrown.
6366
* @group expertParam
6467
*/
6568
@Since("1.5.0")
6669
final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " +
67-
"Supported options: 'random' and 'k-means||'.",
70+
"Supported options: 'random', 'k-means||' and 'initialModel'.",
6871
(value: String) => MLlibKMeans.validateInitMode(value))
6972

7073
/** @group expertGetParam */
@@ -103,7 +106,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
103106
@Since("1.5.0")
104107
class KMeansModel private[ml] (
105108
@Since("1.5.0") override val uid: String,
106-
private val parentModel: MLlibKMeansModel)
109+
private[clustering] val parentModel: MLlibKMeansModel)
107110
extends Model[KMeansModel] with KMeansParams with MLWritable {
108111

109112
@Since("1.5.0")
@@ -123,7 +126,8 @@ class KMeansModel private[ml] (
123126
@Since("2.0.0")
124127
override def transform(dataset: Dataset[_]): DataFrame = {
125128
transformSchema(dataset.schema, logging = true)
126-
val predictUDF = udf((vector: Vector) => predict(vector))
129+
val tmpParent: MLlibKMeansModel = parentModel
130+
val predictUDF = udf((vector: Vector) => tmpParent.predict(vector))
127131
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
128132
}
129133

@@ -132,8 +136,6 @@ class KMeansModel private[ml] (
132136
validateAndTransformSchema(schema)
133137
}
134138

135-
private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
136-
137139
@Since("2.0.0")
138140
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
139141

@@ -253,7 +255,18 @@ object KMeansModel extends MLReadable[KMeansModel] {
253255
@Since("1.5.0")
254256
class KMeans @Since("1.5.0") (
255257
@Since("1.5.0") override val uid: String)
256-
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
258+
extends Estimator[KMeansModel]
259+
with KMeansParams with HasInitialModel[KMeansModel] with MLWritable {
260+
261+
/**
262+
* A KMeansModel to use for warm start.
263+
* Note: the cluster count of the initial model must be equal to [[k]],
264+
* otherwise an IllegalArgumentException is thrown.
265+
* @group param
266+
*/
267+
@Since("2.2.0")
268+
final val initialModel: Param[KMeansModel] =
269+
new Param[KMeansModel](this, "initialModel", "A KMeansModel to use for warm start.")
257270

258271
setDefault(
259272
k -> 2,
@@ -300,6 +313,10 @@ class KMeans @Since("1.5.0") (
300313
@Since("1.5.0")
301314
def setSeed(value: Long): this.type = set(seed, value)
302315

316+
/** @group setParam */
317+
@Since("2.2.0")
318+
def setInitialModel(value: KMeansModel): this.type = set(initialModel, value)
319+
303320
@Since("2.0.0")
304321
override def fit(dataset: Dataset[_]): KMeansModel = {
305322
transformSchema(dataset.schema, logging = true)
@@ -322,6 +339,18 @@ class KMeans @Since("1.5.0") (
322339
.setMaxIterations($(maxIter))
323340
.setSeed($(seed))
324341
.setEpsilon($(tol))
342+
343+
if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && isSet(initialModel)) {
344+
// Check that the feature dimensions are equal
345+
val numFeatures = instances.first().size
346+
val dimOfInitialModel = $(initialModel).clusterCenters.head.size
347+
require(numFeatures == dimOfInitialModel,
348+
s"The number of features in training dataset is $numFeatures," +
349+
s" which mismatched with dimension of initial model: $dimOfInitialModel.")
350+
351+
algo.setInitialModel($(initialModel).parentModel)
352+
}
353+
325354
val parentModel = algo.run(instances, Option(instr))
326355
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
327356
val summary = new KMeansSummary(
@@ -337,15 +366,63 @@ class KMeans @Since("1.5.0") (
337366

338367
/**
 * Checks the warm-start params, then validates and transforms the schema.
 *
 * When `initMode` is "initialModel", the param `initialModel` must be set and the cluster count
 * of that model must equal `k`. For any other `initMode`, a set `initialModel` only triggers a
 * warning.
 *
 * @param schema input schema
 * @return the result of `validateAndTransformSchema(schema)`
 * @throws IllegalArgumentException if `initMode` is "initialModel" and `initialModel` is not
 *                                  set, or its cluster count differs from `k`
 */
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
  if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL) {
    if (isSet(initialModel)) {
      val initialModelK = $(initialModel).parentModel.k
      if (initialModelK != $(k)) {
        // Report the configured value via $(k); a bare `$k` would interpolate the Param
        // object itself instead of the number of clusters.
        throw new IllegalArgumentException("The initial model's cluster count = " +
          s"$initialModelK, mismatched with k = ${$(k)}.")
      }
    } else {
      throw new IllegalArgumentException("Users must set param initialModel if you choose " +
        "'initialModel' as the initialization algorithm.")
    }
  } else {
    if (isSet(initialModel)) {
      // Same reasoning as above: print the current initMode value, not the Param object.
      logWarning(s"Param initialModel will take no effect when initMode is ${$(initMode)}.")
    }
  }
  validateAndTransformSchema(schema)
}
387+
388+
@Since("2.2.0")
389+
override def write: MLWriter = new KMeans.KMeansWriter(this)
342390
}
343391

344392
@Since("1.6.0")
345-
object KMeans extends DefaultParamsReadable[KMeans] {
393+
object KMeans extends MLReadable[KMeans] {
346394

347395
@Since("1.6.0")
348396
override def load(path: String): KMeans = super.load(path)
397+
398+
@Since("2.2.0")
399+
override def read: MLReader[KMeans] = new KMeansReader
400+
401+
/**
 * [[MLWriter]] for a [[KMeans]] estimator.
 *
 * Saving is delegated to [[DefaultParamsWriter]]: the initial model is written by
 * `saveInitialModel` and the estimator metadata by `saveMetadata`, in that order.
 */
private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter {

  /** Writes the estimator's initial model and then its metadata under `path`. */
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveInitialModel(instance, path)
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}
408+
409+
/** [[MLReader]] that rebuilds a [[KMeans]] estimator, including its optional initial model. */
private class KMeansReader extends MLReader[KMeans] {

  /** Class name the stored metadata is checked against when loading the estimator. */
  private val className = classOf[KMeans].getName

  /**
   * Loads the metadata at `path`, creates an estimator with the stored uid, restores its
   * params and, when an initial model was saved alongside, sets it on the estimator.
   */
  override def load(path: String): KMeans = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
    val estimator = new KMeans(metadata.uid)
    DefaultParamsReader.getAndSetParams(estimator, metadata)

    // If no initial model was stored, the Option is empty and the param stays unset.
    DefaultParamsReader.loadInitialModel[KMeansModel](path, sc)
      .foreach(m => estimator.setInitialModel(m))
    estimator
  }
}
349426
}
350427

351428
/**
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.param.shared
19+
20+
import org.apache.spark.ml.Model
21+
import org.apache.spark.ml.param._
22+
23+
/**
 * Trait for components that expose an `initialModel` param holding a model of type `T`.
 *
 * @tparam T type of the model stored in the param
 */
private[ml] trait HasInitialModel[T <: Model[T]] extends Params {

  /** The param holding the initial model; declared by the implementing class. */
  def initialModel: Param[T]

  /** @group getParam */
  final def getInitialModel: T = getOrDefault(initialModel)
}

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
3232
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
3333
import org.apache.spark.ml.feature.RFormulaModel
3434
import org.apache.spark.ml.param.{ParamPair, Params}
35+
import org.apache.spark.ml.param.shared.HasInitialModel
3536
import org.apache.spark.ml.tuning.ValidatorParams
3637
import org.apache.spark.sql.{SparkSession, SQLContext}
3738
import org.apache.spark.util.Utils
@@ -279,6 +280,8 @@ private[ml] object DefaultParamsWriter {
279280
* Helper for [[saveMetadata()]] which extracts the JSON to save.
280281
* This is useful for ensemble models which need to save metadata for many sub-models.
281282
*
283+
* Note: This function does not handle param `initialModel`.
284+
*
282285
* @see [[saveMetadata()]] for details on what this includes.
283286
*/
284287
def getMetadataToSave(
@@ -288,7 +291,8 @@ private[ml] object DefaultParamsWriter {
288291
paramMap: Option[JValue] = None): String = {
289292
val uid = instance.uid
290293
val cls = instance.getClass.getName
291-
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
294+
val params = instance.extractParamMap().toSeq
295+
.filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
292296
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
293297
p.name -> parse(p.jsonEncode(v))
294298
}.toList))
@@ -306,6 +310,20 @@ private[ml] object DefaultParamsWriter {
306310
val metadataJson: String = compact(render(metadata))
307311
metadataJson
308312
}
313+
314+
/**
 * Saves the `initialModel` of `instance` to the "initialModel" child of `path`.
 * Does nothing when the param is not defined on `instance`.
 *
 * Note: if the model being saved itself has a param named "initialModel", that param is
 * cleared on the model object before saving, so the model held by `instance` is modified.
 *
 * @param instance component whose initial model should be saved
 * @param path     parent directory; the model is stored under its "initialModel" child
 */
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
    instance: T, path: String): Unit = {
  if (instance.isDefined(instance.initialModel)) {
    val target = new Path(path, "initialModel").toString
    val model = instance.getOrDefault(instance.initialModel)
    // Keep only the direct initial model: drop any initialModel nested inside it so that
    // saving does not recurse through a chain of initial models.
    if (model.hasParam("initialModel")) model.clear(model.getParam("initialModel"))
    model.save(target)
  }
}
309327
}
310328

311329
/**
@@ -434,6 +452,17 @@ private[ml] object DefaultParamsReader {
434452
val cls = Utils.classForName(metadata.className)
435453
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
436454
}
455+
456+
/**
 * Loads the model stored under the "initialModel" child of `path`, if that location exists.
 *
 * @param path parent directory that may contain an "initialModel" child
 * @param sc   SparkContext providing the Hadoop configuration used to resolve the file system
 * @return `Some(model)` when the location exists, `None` otherwise
 */
def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
  val modelPath = new Path(path, "initialModel")
  val found = modelPath.getFileSystem(sc.hadoopConfiguration).exists(modelPath)
  if (found) Some(loadParamsInstance[M](modelPath.toString, sc)) else None
}
437466
}
438467

439468
/**

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ object KMeans {
418418
val RANDOM = "random"
419419
@Since("0.8.0")
420420
val K_MEANS_PARALLEL = "k-means||"
421+
@Since("2.2.0")
422+
val K_MEANS_INITIAL_MODEL = "initialModel"
421423

422424
/**
423425
* Trains a k-means model using the given set of parameters.
@@ -593,6 +595,7 @@ object KMeans {
593595
initMode match {
594596
case KMeans.RANDOM => true
595597
case KMeans.K_MEANS_PARALLEL => true
598+
case KMeans.K_MEANS_INITIAL_MODEL => true
596599
case _ => false
597600
}
598601
}

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,28 @@ import scala.util.Random
2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.ml.linalg.{Vector, Vectors}
2424
import org.apache.spark.ml.param.ParamMap
25-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
26-
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
25+
import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
26+
import org.apache.spark.ml.util.TestingUtils._
27+
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
28+
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
2729
import org.apache.spark.mllib.util.MLlibTestSparkContext
2830
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
2931

3032
private[clustering] case class TestRow(features: Vector)
3133

3234
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3335

36+
import testImplicits._
37+
3438
final val k = 5
3539
@transient var dataset: Dataset[_] = _
40+
@transient var rData: Dataset[_] = _
3641

3742
override def beforeAll(): Unit = {
3843
super.beforeAll()
3944

4045
dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
46+
rData = GaussianMixtureSuite.rData.map(GaussianMixtureSuite.FeatureData).toDF()
4147
}
4248

4349
test("default parameters") {
@@ -152,6 +158,35 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
152158
val kmeans = new KMeans()
153159
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
154160
}
161+
162+
test("training with initial model") {
  val estimator = new KMeans().setK(2).setSeed(1)
  val coldModel = estimator.fit(rData)
  // Refit on the same data, starting from the centers of the first model.
  val warmModel = estimator.setInitMode("initialModel").setInitialModel(coldModel).fit(rData)
  for ((warm, cold) <- warmModel.clusterCenters.zip(coldModel.clusterCenters)) {
    assert(warm ~== cold absTol 1E-8)
  }
}
169+
170+
test("training with initial model, error cases") {
  val estimator = new KMeans().setK(k).setSeed(1).setMaxIter(1)

  // initMode is 'initialModel' but no initial model has been provided.
  intercept[IllegalArgumentException] {
    estimator.setInitMode("initialModel").fit(dataset)
  }

  // Initial model whose centers have a different dimension than the training data.
  val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k)
  intercept[IllegalArgumentException] {
    estimator.setInitMode("initialModel").setInitialModel(wrongDimModel).fit(dataset)
  }

  // Initial model whose cluster count differs from param k.
  val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1)
  intercept[IllegalArgumentException] {
    estimator.setInitMode("initialModel").setInitialModel(wrongKModel).fit(dataset)
  }
}
155190
}
156191

157192
object KMeansSuite {
@@ -173,6 +208,13 @@ object KMeansSuite {
173208
spark.createDataFrame(rdd)
174209
}
175210

211+
/**
 * Builds a [[KMeansModel]] with `k` cluster centers of dimension `dim`, each coordinate drawn
 * from a `Random` seeded with `seed`. The model gets a fresh random uid.
 */
def generateRandomKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = {
  val rng = new Random(seed)
  val centers = Array.fill(k)(MLlibVectors.dense(Array.fill(dim)(rng.nextDouble)))
  new KMeansModel(Identifiable.randomUID("kmeans"), new MLlibKMeansModel(centers))
}
217+
176218
/**
177219
* Mapping from all Params to valid settings which differ from the defaults.
178220
* This is useful for tests which need to exercise all Params, such as save/load.
@@ -182,6 +224,7 @@ object KMeansSuite {
182224
"predictionCol" -> "myPrediction",
183225
"k" -> 3,
184226
"maxIter" -> 2,
185-
"tol" -> 0.01
227+
"tol" -> 0.01,
228+
"initialModel" -> generateRandomKMeansModel(3, 3)
186229
)
187230
}

0 commit comments

Comments
 (0)