Commit cc13c1e

add initial model to kmeans
1 parent 140ddef commit cc13c1e

4 files changed, +79 -6 lines changed


mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 11 additions & 2 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType}
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasInitialModel[KMeansModel] {

   /**
    * Set the number of clusters to create (k). Must be > 1. Default: 2.
@@ -96,7 +96,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel)
+    private[ml] val parentModel: MLlibKMeansModel)
   extends Model[KMeansModel] with KMeansParams with MLWritable {

   @Since("1.5.0")
@@ -237,6 +237,10 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)

+  /** @group setParam */
+  @Since("2.0.0")
+  def setInitialModel(value: KMeansModel): this.type = set(initialModel, value)
+
   @Since("1.5.0")
   override def fit(dataset: DataFrame): KMeansModel = {
     val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
@@ -248,6 +252,11 @@ class KMeans @Since("1.5.0") (
       .setMaxIterations($(maxIter))
       .setSeed($(seed))
       .setEpsilon($(tol))
+
+    if (isSet(initialModel)) {
+      algo.setInitialModel($(initialModel).parentModel)
+    }
+
     val parentModel = algo.run(rdd)
     val model = new KMeansModel(uid, parentModel)
     copyValues(model)
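
Taken together, these changes let a caller warm-start ml.KMeans from the cluster centers of an existing KMeansModel instead of the default k-means|| initialization. A minimal usage sketch, assuming this patch is applied (trainingDF, a DataFrame with a "features" vector column, is invented for illustration and not part of this commit):

import org.apache.spark.ml.clustering.KMeans

// trainingDF: hypothetical DataFrame with a "features" vector column.
// Cheap first pass: few iterations, default initialization.
val coarse = new KMeans()
  .setK(3)
  .setMaxIter(5)
  .fit(trainingDF)

// Second pass: resume from the coarse centers rather than re-initializing.
val refined = new KMeans()
  .setK(3)
  .setMaxIter(50)
  .setInitialModel(coarse) // setter added by this commit
  .fit(trainingDF)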

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 22 additions & 3 deletions
@@ -78,7 +78,24 @@ private[shared] object SharedParamsCodeGen {
     ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
       "empty, default value is 'auto'.", Some("\"auto\"")))

-    val code = genSharedParams(params)
+    // scalastyle:off
+    val extras: Seq[String] = Seq(
+      """
+        |private[ml] trait HasInitialModel[T <: Model[T]] extends Params {
+        |
+        |  /**
+        |   * Param for initial model of warm start.
+        |   * @group param
+        |   */
+        |  final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start")
+        |
+        |  /** @group getParam */
+        |  final def getInitialModel: T = $(initialModel)
+        |}
+        |""".stripMargin)
+    // scalastyle:on
+
+    val code = genSharedParams(params, extras)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
     val writer = new PrintWriter(file)
     writer.write(code)
@@ -174,7 +191,7 @@ private[shared] object SharedParamsCodeGen {
   }

   /** Generates Scala source code for the input params with header. */
-  private def genSharedParams(params: Seq[ParamDesc[_]]): String = {
+  private def genSharedParams(params: Seq[ParamDesc[_]], extras: Seq[String] = Nil): String = {
     val header =
       """/*
        | * Licensed to the Apache Software Foundation (ASF) under one or more
@@ -195,6 +212,7 @@ private[shared] object SharedParamsCodeGen {
        |
        |package org.apache.spark.ml.param.shared
        |
+       |import org.apache.spark.ml.Model
        |import org.apache.spark.ml.param._
        |
        |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
@@ -205,7 +223,8 @@ private[shared] object SharedParamsCodeGen {
     val footer = "// scalastyle:on\n"

     val traits = params.map(genHasParamTrait).mkString
+    val extraTraits = extras.mkString

-    header + traits + footer
+    header + traits + extraTraits + footer
   }
 }
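
HasInitialModel is injected as a verbatim string rather than as another ParamDesc, presumably because the ParamDesc-driven generation emits traits for concrete param types, while this trait is generic in the model type T. A minimal sketch of the append-extras pattern, with invented names (not the actual generator):

object CodeGenSketch {
  // Hand-written trait source ("extras") is concatenated after the
  // generated traits, inside the same header/footer, before writing the file.
  def genSource(generated: Seq[String], extras: Seq[String] = Nil): String = {
    val header = "// license, package, imports\n"
    val footer = "// scalastyle:on\n"
    header + generated.mkString + extras.mkString + footer
  }

  def main(args: Array[String]): Unit = {
    val traits = Seq("trait HasFoo { /* generated */ }\n")
    val extras = Seq("trait HasInitialModel[T] { /* hand-written */ }\n")
    print(genSource(traits, extras))
  }
}

Upstream Spark's generator documents regenerating sharedParams.scala by running build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" from the project root, which is presumably how the generated file below was produced.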

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 13 additions & 0 deletions
@@ -17,6 +17,7 @@

 package org.apache.spark.ml.param.shared

+import org.apache.spark.ml.Model
 import org.apache.spark.ml.param._

 // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
@@ -389,4 +390,16 @@ private[ml] trait HasSolver extends Params {
   /** @group getParam */
   final def getSolver: String = $(solver)
 }
+
+private[ml] trait HasInitialModel[T <: Model[T]] extends Params {
+
+  /**
+   * Param for initial model of warm start.
+   * @group param
+   */
+  final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start")
+
+  /** @group getParam */
+  final def getInitialModel: T = $(initialModel)
+}
 // scalastyle:on
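
Because HasInitialModel is generic in the model type, nothing about it is specific to clustering: any estimator whose model type satisfies T <: Model[T] could mix it in exactly as KMeansParams does above. A hypothetical sketch (this commit only wires up KMeans; the classification example is invented for illustration):

package org.apache.spark.ml.classification

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.HasInitialModel

// LogisticRegressionModel extends Model[LogisticRegressionModel], so it
// satisfies the T <: Model[T] bound; the initialModel param and the
// getInitialModel getter are inherited with T fixed to that model type.
private[classification] trait LogisticRegressionWarmStartParams
  extends Params with HasInitialModel[LogisticRegressionModel]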

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 33 additions & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -106,6 +106,38 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     val kmeans = new KMeans()
     testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
   }
+
+  test("Initialize using given cluster centers") {
+    val points = Array(
+      Vectors.dense(0.0, 0.0, 0.0),
+      Vectors.dense(1.0, 1.0, 1.0),
+      Vectors.dense(2.0, 2.0, 2.0),
+      Vectors.dense(3.0, 3.0, 3.0),
+      Vectors.dense(4.0, 4.0, 4.0)
+    )
+
+    // creating an initial model
+    val initialModel = new KMeansModel("test model", new MLlibKMeansModel(points))
+
+    val predictionColName = "kmeans_prediction"
+    val kmeans = new KMeans()
+      .setK(k)
+      .setPredictionCol(predictionColName)
+      .setSeed(1)
+      .setInitialModel(initialModel)
+    val model = kmeans.fit(dataset)
+    assert(model.clusterCenters.length === k)
+
+    val transformed = model.transform(dataset)
+    val expectedColumns = Array("features", predictionColName)
+    expectedColumns.foreach { column =>
+      assert(transformed.columns.contains(column))
+    }
+    val clusters =
+      transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
+    assert(clusters.size === k)
+    assert(clusters === Set(0, 1, 2, 3, 4))
+  }
 }

 object KMeansSuite {
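
A note on fixtures: the k and dataset referenced in the new test are the suite's shared values, defined outside this diff. The last two assertions together imply k = 5 here, one cluster per supplied center, so the test checks that every center from the initial model ends up owning at least one point after fitting.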
