diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e02b532ca8a9..a82bdb4d6baf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -37,9 +37,9 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
- * Common params for KMeans and KMeansModel
+ * Common params for KMeans and KMeansModel.
  */
-private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
+private[clustering] trait KMeansModelParams extends Params with HasMaxIter with HasFeaturesCol
   with HasSeed with HasPredictionCol with HasTol {
 
   /**
@@ -59,12 +59,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   /**
    * Param for the initialization algorithm. This can be either "random" to choose random points as
    * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
-   * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
+   * (Bahmani et al., Scalable K-Means++, VLDB 2012), or "initialModel" to use a user-provided
+   * initial model for warm start. Default: k-means||.
+   * If this is set to "initialModel", users must specify the initial model via `setInitialModel`;
+   * otherwise an IllegalArgumentException is thrown.
    * @group expertParam
    */
   @Since("1.5.0")
   final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " +
-    "Supported options: 'random' and 'k-means||'.",
+    "Supported options: 'random', 'k-means||' and 'initialModel'.",
     (value: String) => MLlibKMeans.validateInitMode(value))
 
   /** @group expertGetParam */
@@ -95,6 +98,22 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   }
 }
 
+/**
+ * Common params for KMeans.
+ */
+private[clustering] trait KMeansParams extends KMeansModelParams with HasInitialModel[KMeansModel] {
+
+  /**
+   * A KMeansModel to use for warm start.
+   * Note that the cluster count of the initial model must equal [[k]];
+   * otherwise an IllegalArgumentException is thrown.
+   * @group param
+   */
+  @Since("2.2.0")
+  final val initialModel: Param[KMeansModel] =
+    new Param[KMeansModel](this, "initialModel", "A KMeansModel to use for warm start.")
+}
+
 /**
  * Model fitted by KMeans.
  *
@@ -103,8 +122,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Since("1.5.0")
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel)
-  extends Model[KMeansModel] with KMeansParams with MLWritable {
+    private[clustering] val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansModelParams with MLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -123,7 +142,8 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val predictUDF = udf((vector: Vector) => predict(vector))
+    val localParent: MLlibKMeansModel = parentModel
+    val predictUDF = udf((vector: Vector) => localParent.predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
@@ -132,8 +152,6 @@ class KMeansModel private[ml] (
     validateAndTransformSchema(schema)
   }
 
-  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
-
   @Since("2.0.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
 
@@ -253,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
 @Since("1.5.0")
 class KMeans @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
-  extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
+  extends Estimator[KMeansModel] with KMeansParams with MLWritable {
 
   setDefault(
     k -> 2,
@@ -300,6 +318,10 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setInitialModel(value: KMeansModel): this.type = set(initialModel, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
     transformSchema(dataset.schema, logging = true)
@@ -322,6 +344,18 @@ class KMeans @Since("1.5.0") (
       .setMaxIterations($(maxIter))
       .setSeed($(seed))
      .setEpsilon($(tol))
+
+    if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && isSet(initialModel)) {
+      // Check that the feature dimensions are equal.
+      val numFeatures = instances.first().size
+      val dimOfInitialModel = $(initialModel).clusterCenters.head.size
+      require(numFeatures == dimOfInitialModel,
+        s"The number of features in the training dataset is $numFeatures," +
+          s" which does not match the dimension of the initial model: $dimOfInitialModel.")
+
+      algo.setInitialModel($(initialModel).parentModel)
+    }
+
     val parentModel = algo.run(instances, Option(instr))
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
     val summary = new KMeansSummary(
@@ -335,17 +369,70 @@ class KMeans @Since("1.5.0") (
     model
   }
 
+  /**
+   * Check the validity of interactions between parameters.
+   */
+  private def assertInitialModelValid(): Unit = {
+    if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL) {
+      if (isSet(initialModel)) {
+        val initialModelK = $(initialModel).parentModel.k
+        if (initialModelK != $(k)) {
+          throw new IllegalArgumentException("The initial model's cluster count = " +
+            s"$initialModelK, which does not match k = ${$(k)}.")
+        }
+      } else {
+        throw new IllegalArgumentException("Users must set param initialModel when choosing " +
+          "'initialModel' as the initialization algorithm.")
+      }
+    } else {
+      if (isSet(initialModel)) {
+        logWarning(s"Param initialModel will take no effect when initMode is ${$(initMode)}.")
+      }
+    }
+  }
+
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
+    assertInitialModelValid()
     validateAndTransformSchema(schema)
   }
+
+  @Since("2.2.0")
+  override def write: MLWriter = new KMeans.KMeansWriter(this)
 }
 
 @Since("1.6.0")
-object KMeans extends DefaultParamsReadable[KMeans] {
+object KMeans extends MLReadable[KMeans] {
 
   @Since("1.6.0")
   override def load(path: String): KMeans = super.load(path)
+
+  @Since("2.2.0")
+  override def read: MLReader[KMeans] = new KMeansReader
+
+  /** [[MLWriter]] instance for [[KMeans]] */
+  private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveInitialModel(instance, path)
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
+
+  private class KMeansReader extends MLReader[KMeans] {
+
+    override def load(path: String): KMeans = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[KMeans].getName)
+      val instance = new KMeans(metadata.uid)
+
+      DefaultParamsReader.getAndSetParams(instance, metadata)
+      DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match {
+        case Some(m) => instance.setInitialModel(m)
+        case None => // initialModel doesn't exist, do nothing
      }
+      instance
+    }
+  }
 }
 
 /**
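For reference, a minimal sketch of the warm-start workflow this change enables (the `batch1`/`batch2` DataFrames are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.clustering.KMeans

// Fit an initial model on an early batch of data.
val initial = new KMeans().setK(4).setSeed(1L).fit(batch1)

// Warm-start a second run from the first model's cluster centers.
// initMode must be "initialModel" and k must match the initial model's
// cluster count; otherwise fit() throws IllegalArgumentException.
val refined = new KMeans()
  .setK(4)
  .setInitMode("initialModel")
  .setInitialModel(initial)
  .fit(batch2)
```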
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
new file mode 100644
index 000000000000..c67380edaa60
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param._
+
+private[ml] trait HasInitialModel[T <: Model[T]] extends Params {
+
+  def initialModel: Param[T]
+
+  /** @group getParam */
+  final def getInitialModel: T = $(initialModel)
+}
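The trait deliberately leaves `initialModel` abstract so each algorithm can declare the Param with its own concrete model type; `KMeansParams` above is the canonical instance. A sketch of the pattern (the `WarmStartParams` name is made up for illustration, and the trait must live inside the `org.apache.spark.ml` package to see `HasInitialModel`):

```scala
import org.apache.spark.ml.clustering.KMeansModel
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.shared.HasInitialModel

// Bind T to the concrete model type and supply the Param;
// HasInitialModel contributes the getInitialModel getter.
private[ml] trait WarmStartParams extends HasInitialModel[KMeansModel] {
  final val initialModel: Param[KMeansModel] =
    new Param[KMeansModel](this, "initialModel", "A model to use for warm start.")
}
```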
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 09bddcdb810b..6da531a22c09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasInitialModel
 import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
@@ -279,6 +280,8 @@ private[ml] object DefaultParamsWriter {
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
    *
+   * Note: This function does not handle param `initialModel`; see [[saveInitialModel()]].
+   *
    * @see [[saveMetadata()]] for details on what this includes.
    */
   def getMetadataToSave(
@@ -288,7 +291,8 @@ private[ml] object DefaultParamsWriter {
       paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.extractParamMap().toSeq
+      .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
     val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
     }.toList))
@@ -306,6 +310,23 @@ private[ml] object DefaultParamsWriter {
     val metadataJson: String = compact(render(metadata))
     metadataJson
   }
+
+  /**
+   * Save the estimator's `initialModel` under the given path.
+   */
+  def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
+      instance: T, path: String): Unit = {
+    if (instance.isDefined(instance.initialModel)) {
+      val initialModelPath = new Path(path, "initialModel").toString
+      val initialModel = instance.getOrDefault(instance.initialModel)
+      // When saving, keep only the direct initialModel: clear any initialModel that the
+      // direct initialModel itself may have, to avoid unnecessarily deep recursion.
+      if (initialModel.hasParam("initialModel")) {
+        initialModel.clear(initialModel.getParam("initialModel"))
+      }
+      initialModel.save(initialModelPath)
+    }
+  }
 }
 
 /**
@@ -434,6 +455,23 @@ private[ml] object DefaultParamsReader {
     val cls = Utils.classForName(metadata.className)
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
+
+  /**
+   * Load the estimator's `initialModel` instance from the given path and return it.
+   * If the `initialModel` path does not exist, the estimator either does not have or
+   * did not set the param `initialModel`, so return None.
+   * This assumes the model implements [[MLReadable]].
+   */
+  def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
+    val hadoopConf = sc.hadoopConfiguration
+    val initialModelPath = new Path(path, "initialModel")
+    val fs = initialModelPath.getFileSystem(hadoopConf)
+    if (fs.exists(initialModelPath)) {
+      Some(loadParamsInstance[M](initialModelPath.toString, sc))
+    } else {
+      None
+    }
+  }
 }
 
 /**
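A minimal sketch of the resulting persistence behavior, assuming a `kmeans` estimator with an initial model already set (the path is illustrative): the initial model is written to an `initialModel/` subdirectory next to the estimator's metadata, and is re-attached on load if that subdirectory exists.

```scala
// Save: KMeansWriter persists metadata plus the initialModel subdirectory.
kmeans.write.save("/tmp/kmeans-warm")

// Load: KMeansReader restores the params and, if initialModel/ exists,
// calls setInitialModel with the loaded model.
val restored = org.apache.spark.ml.clustering.KMeans.load("/tmp/kmeans-warm")
assert(restored.getInitialModel.clusterCenters
  .sameElements(kmeans.getInitialModel.clusterCenters))
```

Note that `saveInitialModel` clears any nested `initialModel` param on the model being saved, so persistence never recurses more than one level deep.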
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index fa72b72e2d92..49f86d3ea490 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -418,6 +418,8 @@ object KMeans {
   val RANDOM = "random"
   @Since("0.8.0")
   val K_MEANS_PARALLEL = "k-means||"
+  @Since("2.2.0")
+  val K_MEANS_INITIAL_MODEL = "initialModel"
 
   /**
    * Trains a k-means model using the given set of parameters.
@@ -593,6 +595,7 @@ object KMeans {
     initMode match {
       case KMeans.RANDOM => true
      case KMeans.K_MEANS_PARALLEL => true
+      case KMeans.K_MEANS_INITIAL_MODEL => true
       case _ => false
     }
   }
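At the `spark.mllib` level this change only teaches `validateInitMode` about the new constant; the RDD-based `KMeans` already supports warm start directly through its existing `setInitialModel`, which the `spark.ml` wrapper delegates to in `fit()`. A sketch against the low-level API (the `data: RDD[Vector]` input and `oldModel` are illustrative):

```scala
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}

// Illustrative warm start at the RDD level; oldModel is an existing
// mllib KMeansModel whose centers seed the next run.
val warmStarted = new MLlibKMeans()
  .setK(oldModel.k)
  .setInitialModel(oldModel)
  .run(data)
```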
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d108..ccde24dd42fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
@@ -31,13 +33,17 @@ private[clustering] case class TestRow(features: Vector)
 
 class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  import testImplicits._
+
   final val k = 5
   @transient var dataset: Dataset[_] = _
+  @transient var rData: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
 
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    rData = GaussianMixtureSuite.rData.map(Tuple1.apply).toDF("features")
   }
 
   test("default parameters") {
@@ -152,6 +158,35 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     val kmeans = new KMeans()
     testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
   }
+
+  test("training with initial model") {
+    val kmeans = new KMeans().setK(2).setSeed(1)
+    val model1 = kmeans.fit(rData)
+    val model2 = kmeans.setInitMode("initialModel").setInitialModel(model1).fit(rData)
+    model2.clusterCenters.zip(model1.clusterCenters)
+      .foreach { case (center2, center1) => assert(center2 ~== center1 absTol 1E-8) }
+  }
+
+  test("training with initial model, error cases") {
+    val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1)
+
+    // Set initMode to 'initialModel' but do not specify an initial model.
+    intercept[IllegalArgumentException] {
+      kmeans.setInitMode("initialModel").fit(dataset)
+    }
+
+    // The training dataset's dimension does not match the initial model's.
+    val modelWithDiffDim = KMeansSuite.generateRandomKMeansModel(4, k)
+    intercept[IllegalArgumentException] {
+      kmeans.setInitMode("initialModel").setInitialModel(modelWithDiffDim).fit(dataset)
+    }
+
+    // The initial model's cluster count does not match param k.
+    val initialModel = KMeansSuite.generateRandomKMeansModel(3, k + 1)
+    intercept[IllegalArgumentException] {
+      kmeans.setInitMode("initialModel").setInitialModel(initialModel).fit(dataset)
+    }
+  }
 }
 
 object KMeansSuite {
@@ -173,6 +208,13 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateRandomKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = {
+    val rng = new Random(seed)
+    val clusterCenters = (1 to k)
+      .map(i => MLlibVectors.dense(Array.fill(dim)(rng.nextDouble)))
+    new KMeansModel(Identifiable.randomUID("kmeans"), new MLlibKMeansModel(clusterCenters.toArray))
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
@@ -182,6 +224,7 @@ object KMeansSuite {
     "predictionCol" -> "myPrediction",
     "k" -> 3,
     "maxIter" -> 2,
-    "tol" -> 0.01
+    "tol" -> 0.01,
+    "initialModel" -> generateRandomKMeansModel(3, 3)
   )
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b30a..543296db3898 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -59,15 +59,17 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     assert(newInstance.uid === instance.uid)
     if (testParams) {
       instance.params.foreach { p =>
-        if (instance.isDefined(p)) {
-          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
-            case (Array(values), Array(newValues)) =>
-              assert(values === newValues, s"Values do not match on param ${p.name}.")
-            case (value, newValue) =>
-              assert(value === newValue, s"Values do not match on param ${p.name}.")
+        if (p.name != "initialModel") {
+          if (instance.isDefined(p)) {
+            (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+              case (Array(values), Array(newValues)) =>
+                assert(values === newValues, s"Values do not match on param ${p.name}.")
+              case (value, newValue) =>
+                assert(value === newValue, s"Values do not match on param ${p.name}.")
+            }
+          } else {
+            assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
           }
-        } else {
-          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
         }
       }
     }
@@ -111,12 +113,20 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val estimator2 = testDefaultReadWrite(estimator)
     testParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
-      assert(estimator.get(param).get === estimator2.get(param).get)
+      if (param.name == "initialModel") {
+        // The estimator's `initialModel` has the same type as the model produced by the estimator,
+        // so we can use `checkModelData` to check equality of `initialModel` as well.
+        checkModelData(estimator.get(param).get.asInstanceOf[M],
+          estimator2.get(param).get.asInstanceOf[M])
+      } else {
+        assert(estimator.get(param).get === estimator2.get(param).get)
+      }
     }
 
     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    // Model does not extend HasInitialModel, so we don't check it.
+    testParams.filter(_._1 != "initialModel").foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 511686fb4f37..64fdc2ad8958 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -948,6 +948,15 @@ object MimaExcludes {
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
+    ) ++ Seq(
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.KMeans$"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.KMeans"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.KMeansModel"),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.clustering.KMeansModelParams.org$apache$spark$ml$clustering$KMeansModelParams$_setter_$initMode_="),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.clustering.KMeansModelParams.org$apache$spark$ml$clustering$KMeansModelParams$_setter_$initSteps_="),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.clustering.KMeansModelParams.org$apache$spark$ml$clustering$KMeansModelParams$_setter_$k_="),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.initialModel"),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.getInitialModel")
     )