@@ -59,12 +59,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
5959 /**
6060 * Param for the initialization algorithm. This can be either "random" to choose random points as
6161 * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
62- * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
62+ * (Bahmani et al., Scalable K-Means++, VLDB 2012), or "initialModel" to use a user provided
63+ * initial model for warm start. Default: k-means||.
64+ * If this was set as "initialModel", users must specify the initial model by `setInitialModel`,
65+ * otherwise, throws IllegalArgumentException.
6366 * @group expertParam
6467 */
6568 @ Since (" 1.5.0" )
6669 final val initMode = new Param [String ](this , " initMode" , " The initialization algorithm. " +
67- " Supported options: 'random' and 'k-means||'." ,
70+ " Supported options: 'random', 'k-means||' and 'initialModel '." ,
6871 (value : String ) => MLlibKMeans .validateInitMode(value))
6972
7073 /** @group expertGetParam */
@@ -103,7 +106,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
103106@ Since (" 1.5.0" )
104107class KMeansModel private [ml] (
105108 @ Since (" 1.5.0" ) override val uid : String ,
106- private val parentModel : MLlibKMeansModel )
109+ private [clustering] val parentModel : MLlibKMeansModel )
107110 extends Model [KMeansModel ] with KMeansParams with MLWritable {
108111
109112 @ Since (" 1.5.0" )
@@ -123,7 +126,8 @@ class KMeansModel private[ml] (
123126 @ Since (" 2.0.0" )
124127 override def transform (dataset : Dataset [_]): DataFrame = {
125128 transformSchema(dataset.schema, logging = true )
126- val predictUDF = udf((vector : Vector ) => predict(vector))
129+ val tmpParent : MLlibKMeansModel = parentModel
130+ val predictUDF = udf((vector : Vector ) => tmpParent.predict(vector))
127131 dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
128132 }
129133
@@ -132,8 +136,6 @@ class KMeansModel private[ml] (
132136 validateAndTransformSchema(schema)
133137 }
134138
135- private [clustering] def predict (features : Vector ): Int = parentModel.predict(features)
136-
137139 @ Since (" 2.0.0" )
138140 def clusterCenters : Array [Vector ] = parentModel.clusterCenters.map(_.asML)
139141
@@ -253,7 +255,18 @@ object KMeansModel extends MLReadable[KMeansModel] {
253255@ Since (" 1.5.0" )
254256class KMeans @ Since (" 1.5.0" ) (
255257 @ Since (" 1.5.0" ) override val uid : String )
256- extends Estimator [KMeansModel ] with KMeansParams with DefaultParamsWritable {
258+ extends Estimator [KMeansModel ]
259+ with KMeansParams with HasInitialModel [KMeansModel ] with MLWritable {
260+
261+ /**
262+ * A KMeansModel to use for warm start.
263+ * Note the cluster count of initial model must be equal with [[k ]],
264+ * otherwise, throws IllegalArgumentException.
265+ * @group param
266+ */
267+ @ Since (" 2.2.0" )
268+ final val initialModel : Param [KMeansModel ] =
269+ new Param [KMeansModel ](this , " initialModel" , " A KMeansModel to use for warm start." )
257270
258271 setDefault(
259272 k -> 2 ,
@@ -300,6 +313,10 @@ class KMeans @Since("1.5.0") (
300313 @ Since (" 1.5.0" )
301314 def setSeed (value : Long ): this .type = set(seed, value)
302315
316+ /** @group setParam */
317+ @ Since (" 2.2.0" )
318+ def setInitialModel (value : KMeansModel ): this .type = set(initialModel, value)
319+
303320 @ Since (" 2.0.0" )
304321 override def fit (dataset : Dataset [_]): KMeansModel = {
305322 transformSchema(dataset.schema, logging = true )
@@ -322,6 +339,18 @@ class KMeans @Since("1.5.0") (
322339 .setMaxIterations($(maxIter))
323340 .setSeed($(seed))
324341 .setEpsilon($(tol))
342+
343+ if ($(initMode) == MLlibKMeans .K_MEANS_INITIAL_MODEL && isSet(initialModel)) {
344+ // Check that the feature dimensions are equal
345+ val numFeatures = instances.first().size
346+ val dimOfInitialModel = $(initialModel).clusterCenters.head.size
347+ require(numFeatures == dimOfInitialModel,
348+ s " The number of features in training dataset is $numFeatures, " +
349+ s " which mismatched with dimension of initial model: $dimOfInitialModel. " )
350+
351+ algo.setInitialModel($(initialModel).parentModel)
352+ }
353+
325354 val parentModel = algo.run(instances, Option (instr))
326355 val model = copyValues(new KMeansModel (uid, parentModel).setParent(this ))
327356 val summary = new KMeansSummary (
@@ -337,15 +366,63 @@ class KMeans @Since("1.5.0") (
337366
338367 @ Since (" 1.5.0" )
339368 override def transformSchema (schema : StructType ): StructType = {
369+ if ($(initMode) == MLlibKMeans .K_MEANS_INITIAL_MODEL ) {
370+ if (isSet(initialModel)) {
371+ val initialModelK = $(initialModel).parentModel.k
372+ if (initialModelK != $(k)) {
373+ throw new IllegalArgumentException (" The initial model's cluster count = " +
374+ s " $initialModelK, mismatched with k = $k. " )
375+ }
376+ } else {
377+ throw new IllegalArgumentException (" Users must set param initialModel if you choose " +
378+ " 'initialModel' as the initialization algorithm." )
379+ }
380+ } else {
381+ if (isSet(initialModel)) {
382+ logWarning(s " Param initialModel will take no effect when initMode is $initMode. " )
383+ }
384+ }
340385 validateAndTransformSchema(schema)
341386 }
387+
388+ @ Since (" 2.2.0" )
389+ override def write : MLWriter = new KMeans .KMeansWriter (this )
342390}
343391
344392@ Since (" 1.6.0" )
345- object KMeans extends DefaultParamsReadable [KMeans ] {
393+ object KMeans extends MLReadable [KMeans ] {
346394
347395 @ Since (" 1.6.0" )
348396 override def load (path : String ): KMeans = super .load(path)
397+
398+ @ Since (" 2.2.0" )
399+ override def read : MLReader [KMeans ] = new KMeansReader
400+
401+ /** [[MLWriter ]] instance for [[KMeans ]] */
402+ private [KMeans ] class KMeansWriter (instance : KMeans ) extends MLWriter {
403+ override protected def saveImpl (path : String ): Unit = {
404+ DefaultParamsWriter .saveInitialModel(instance, path)
405+ DefaultParamsWriter .saveMetadata(instance, path, sc)
406+ }
407+ }
408+
409+ private class KMeansReader extends MLReader [KMeans ] {
410+
411+ /** Checked against metadata when loading estimator */
412+ private val className = classOf [KMeans ].getName
413+
414+ override def load (path : String ): KMeans = {
415+ val metadata = DefaultParamsReader .loadMetadata(path, sc, className)
416+ val instance = new KMeans (metadata.uid)
417+
418+ DefaultParamsReader .getAndSetParams(instance, metadata)
419+ DefaultParamsReader .loadInitialModel[KMeansModel ](path, sc) match {
420+ case Some (m) => instance.setInitialModel(m)
421+ case None => // initialModel doesn't exist, do nothing
422+ }
423+ instance
424+ }
425+ }
349426}
350427
351428/**
0 commit comments