-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-10780][ML] Support initial model for KMeans. #17117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ddd8d86
2824d85
bbad291
4226149
7d842e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,9 +37,9 @@ import org.apache.spark.storage.StorageLevel | |
| import org.apache.spark.util.VersionUtils.majorVersion | ||
|
|
||
| /** | ||
| * Common params for KMeans and KMeansModel | ||
| * Common params for KMeans and KMeansModel. | ||
| */ | ||
| private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol | ||
| private[clustering] trait KMeansModelParams extends Params with HasMaxIter with HasFeaturesCol | ||
| with HasSeed with HasPredictionCol with HasTol { | ||
|
|
||
| /** | ||
|
|
@@ -59,12 +59,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
| /** | ||
| * Param for the initialization algorithm. This can be either "random" to choose random points as | ||
| * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ | ||
| * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. | ||
| * (Bahmani et al., Scalable K-Means++, VLDB 2012), or "initialModel" to use a user provided | ||
| * initial model for warm start. Default: k-means||. | ||
| * If this was set as "initialModel", users must specify the initial model by `setInitialModel`, | ||
| * otherwise, throws IllegalArgumentException. | ||
| * @group expertParam | ||
| */ | ||
| @Since("1.5.0") | ||
| final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " + | ||
| "Supported options: 'random' and 'k-means||'.", | ||
| "Supported options: 'random', 'k-means||' and 'initialModel'.", | ||
| (value: String) => MLlibKMeans.validateInitMode(value)) | ||
|
|
||
| /** @group expertGetParam */ | ||
|
|
@@ -95,6 +98,22 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Common params for KMeans. | ||
| */ | ||
| private[clustering] trait KMeansParams extends KMeansModelParams with HasInitialModel[KMeansModel] { | ||
|
|
||
| /** | ||
| * A KMeansModel to use for warm start. | ||
| * Note the cluster count of initial model must be equal with [[k]], | ||
| * otherwise, throws IllegalArgumentException. | ||
| * @group param | ||
| */ | ||
| @Since("2.2.0") | ||
| final val initialModel: Param[KMeansModel] = | ||
| new Param[KMeansModel](this, "initialModel", "A KMeansModel to use for warm start.") | ||
| } | ||
|
|
||
| /** | ||
| * Model fitted by KMeans. | ||
| * | ||
|
|
@@ -103,8 +122,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe | |
| @Since("1.5.0") | ||
| class KMeansModel private[ml] ( | ||
| @Since("1.5.0") override val uid: String, | ||
| private val parentModel: MLlibKMeansModel) | ||
| extends Model[KMeansModel] with KMeansParams with MLWritable { | ||
| private[clustering] val parentModel: MLlibKMeansModel) | ||
| extends Model[KMeansModel] with KMeansModelParams with MLWritable { | ||
|
|
||
| @Since("1.5.0") | ||
| override def copy(extra: ParamMap): KMeansModel = { | ||
|
|
@@ -123,7 +142,8 @@ class KMeansModel private[ml] ( | |
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| val predictUDF = udf((vector: Vector) => predict(vector)) | ||
| val localParent: MLlibKMeansModel = parentModel | ||
| val predictUDF = udf((vector: Vector) => localParent.predict(vector)) | ||
| dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) | ||
| } | ||
|
|
||
|
|
@@ -132,8 +152,6 @@ class KMeansModel private[ml] ( | |
| validateAndTransformSchema(schema) | ||
| } | ||
|
|
||
| private[clustering] def predict(features: Vector): Int = parentModel.predict(features) | ||
|
|
||
| @Since("2.0.0") | ||
| def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) | ||
|
|
||
|
|
@@ -253,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] { | |
| @Since("1.5.0") | ||
| class KMeans @Since("1.5.0") ( | ||
| @Since("1.5.0") override val uid: String) | ||
| extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { | ||
| extends Estimator[KMeansModel] with KMeansParams with MLWritable { | ||
|
|
||
| setDefault( | ||
| k -> 2, | ||
|
|
@@ -300,6 +318,10 @@ class KMeans @Since("1.5.0") ( | |
| @Since("1.5.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about def setInitialModel(value: KMeansModel): this.type = {
if (getK ~= value.getK) {
log the warning
set(k, value)
}
set(initMode, MLlibKMeans.K_MEANS_INITIAL_MODEL) // We may log, but I don't really care for this one.
set(initialModel, value)
}
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can not override param in any Users should get the same model regardless of the way to set param. I think the only way to override param is in the start of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you elaborate this? I don't fully understand why we can not overwrite setting in set method. Thanks. |
||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): KMeansModel = { | ||
| transformSchema(dataset.schema, logging = true) | ||
|
|
@@ -322,6 +344,18 @@ class KMeans @Since("1.5.0") ( | |
| .setMaxIterations($(maxIter)) | ||
| .setSeed($(seed)) | ||
| .setEpsilon($(tol)) | ||
|
|
||
| if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && isSet(initialModel)) { | ||
| // Check that the feature dimensions are equal | ||
| val numFeatures = instances.first().size | ||
| val dimOfInitialModel = $(initialModel).clusterCenters.head.size | ||
| require(numFeatures == dimOfInitialModel, | ||
| s"The number of features in training dataset is $numFeatures," + | ||
| s" which mismatched with dimension of initial model: $dimOfInitialModel.") | ||
|
|
||
| algo.setInitialModel($(initialModel).parentModel) | ||
| } | ||
|
|
||
| val parentModel = algo.run(instances, Option(instr)) | ||
| val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) | ||
| val summary = new KMeansSummary( | ||
|
|
@@ -335,17 +369,70 @@ class KMeans @Since("1.5.0") ( | |
| model | ||
| } | ||
|
|
||
| /** | ||
| * Check validity for interactions between parameters. | ||
| */ | ||
| private def assertInitialModelValid(): Unit = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think with overwriting above, the only thing we need to check will be if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && !isSet(initialModel)) {
throw new IllegalArgumentException("Users must set param initialModel if you choose " +
"'initialModel' as the initialization.")
}we can just have it in the body of |
||
| if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL) { | ||
| if (isSet(initialModel)) { | ||
| val initialModelK = $(initialModel).parentModel.k | ||
| if (initialModelK != $(k)) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this check is needed if we overwrite |
||
| throw new IllegalArgumentException("The initial model's cluster count = " + | ||
| s"$initialModelK, mismatched with k = $k.") | ||
| } | ||
| } else { | ||
| throw new IllegalArgumentException("Users must set param initialModel if you choose " + | ||
| "'initialModel' as the initialization algorithm.") | ||
| } | ||
| } else { | ||
| if (isSet(initialModel)) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, this is not needed if we do the overwriting work in |
||
| logWarning(s"Param initialModel will take no effect when initMode is $initMode.") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| @Since("1.5.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| assertInitialModelValid() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this is not checked in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If we have small logic in checking, I'll have those checking code in |
||
| validateAndTransformSchema(schema) | ||
| } | ||
|
|
||
| @Since("2.2.0") | ||
| override def write: MLWriter = new KMeans.KMeansWriter(this) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| object KMeans extends DefaultParamsReadable[KMeans] { | ||
| object KMeans extends MLReadable[KMeans] { | ||
|
|
||
| @Since("1.6.0") | ||
| override def load(path: String): KMeans = super.load(path) | ||
|
|
||
| @Since("2.2.0") | ||
| override def read: MLReader[KMeans] = new KMeansReader | ||
|
|
||
| /** [[MLWriter]] instance for [[KMeans]] */ | ||
| private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter { | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| DefaultParamsWriter.saveInitialModel(instance, path) | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| } | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was trying to move |
||
| } | ||
|
|
||
| private class KMeansReader extends MLReader[KMeans] { | ||
|
|
||
| override def load(path: String): KMeans = { | ||
| val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[KMeans].getName) | ||
| val instance = new KMeans(metadata.uid) | ||
|
|
||
| DefaultParamsReader.getAndSetParams(instance, metadata) | ||
| DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be done as: DefaultParamsReader.loadInitialModel[KMeansModel](path, sc).foreach(instance.setInitialModel)I think it's nicer, but I'm not sure if there is a universal preference for side effects with options in Spark, so I'll leave it to you to decide.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, your suggestion can work well, but I'm more prefer to my way, since it's more clear for developer to understand what happened. |
||
| case Some(m) => instance.setInitialModel(m) | ||
| case None => // initialModel doesn't exist, do nothing | ||
| } | ||
| instance | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.ml.param.shared | ||
|
|
||
| import org.apache.spark.ml.Model | ||
| import org.apache.spark.ml.param._ | ||
|
|
||
| private[ml] trait HasInitialModel[T <: Model[T]] extends Params { | ||
|
|
||
| def initialModel: Param[T] | ||
|
|
||
| /** @group getParam */ | ||
| final def getInitialModel: T = $(initialModel) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -418,6 +418,8 @@ object KMeans { | |
| val RANDOM = "random" | ||
| @Since("0.8.0") | ||
| val K_MEANS_PARALLEL = "k-means||" | ||
| @Since("2.2.0") | ||
| val K_MEANS_INITIAL_MODEL = "initialModel" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can be private I think. That, or we should update the valid options for the |
||
|
|
||
| /** | ||
| * Trains a k-means model using the given set of parameters. | ||
|
|
@@ -593,6 +595,7 @@ object KMeans { | |
| initMode match { | ||
| case KMeans.RANDOM => true | ||
| case KMeans.K_MEANS_PARALLEL => true | ||
| case KMeans.K_MEANS_INITIAL_MODEL => true | ||
| case _ => false | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now,
KMeansModelmixesKMeansModelParams, does it mean in the model level, we can not get the information of theinitiModel? Also, in the model, why do we need to mix the seed in?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we decided in the previous discussion to not store the initial model in the produced model, for several reasons, including model serialization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair enough.