Skip to content

Commit 36b1729

Browse files
committed
add two setters for initial model
1 parent cc13c1e commit 36b1729

File tree

1 file changed

+17
-0
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/clustering

1 file changed

+17
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,23 @@ class KMeans @Since("1.5.0") (
241241
@Since("2.0.0")
242242
def setInitialModel(value: KMeansModel): this.type = set(initialModel, value)
243243

244+
/** @group setParam */
245+
@Since("2.0.0")
246+
def setInitialModel(value: Model[_]): this.type = {
247+
value match {
248+
case m: KMeansModel => set(initialModel, m)
249+
case other =>
250+
logInfo(s"KMeansModel required but ${other.getClass.getSimpleName} found.")
251+
this
252+
}
253+
}
254+
255+
/** @group setParam */
256+
@Since("2.0.0")
257+
def setInitialModel(clusterCenters: Array[Vector]): this.type = {
258+
set(initialModel, new KMeansModel("initial model", new MLlibKMeansModel(clusterCenters)))
259+
}
260+
244261
@Since("1.5.0")
245262
override def fit(dataset: DataFrame): KMeansModel = {
246263
val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }

0 commit comments

Comments
 (0)