From cc13c1e46cbbf11a3495943381c28a9045b3d514 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 8 Feb 2016 11:32:08 -0800 Subject: [PATCH 01/42] add initial model to kmeans --- .../apache/spark/ml/clustering/KMeans.scala | 13 +++++-- .../ml/param/shared/SharedParamsCodeGen.scala | 25 ++++++++++++-- .../spark/ml/param/shared/sharedParams.scala | 13 +++++++ .../spark/ml/clustering/KMeansSuite.scala | 34 ++++++++++++++++++- 4 files changed, 79 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc6d5d928097..365a6eb08abe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType} * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasInitialModel[KMeansModel] { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. 
@@ -96,7 +96,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) + private[ml] val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") @@ -237,6 +237,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + /** @group setParam */ + @Since("2.0.0") + def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) + @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } @@ -248,6 +252,11 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + + if (isSet(initialModel)) { + algo.setInitialModel($(initialModel).parentModel) + } + val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 4aff749ff75a..85753d41ac86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -78,7 +78,24 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + "empty, default value is 'auto'.", Some("\"auto\""))) - val code = genSharedParams(params) + // scalastyle:off + val extras: Seq[String] = Seq( + """ + |private[ml] trait HasInitialModel[T <: Model[T]] extends Params { + | + | /** + | * Param for initial model of warm start. 
+ | * @group param + | */ + | final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start") + | + | /** @group getParam */ + | final def getInitialModel: T = $(initialModel) + |} + |""".stripMargin) + // scalastyle:on + + val code = genSharedParams(params, extras) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" val writer = new PrintWriter(file) writer.write(code) @@ -174,7 +191,7 @@ private[shared] object SharedParamsCodeGen { } /** Generates Scala source code for the input params with header. */ - private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + private def genSharedParams(params: Seq[ParamDesc[_]], extras: Seq[String] = Nil): String = { val header = """/* | * Licensed to the Apache Software Foundation (ASF) under one or more @@ -195,6 +212,7 @@ private[shared] object SharedParamsCodeGen { | |package org.apache.spark.ml.param.shared | + |import org.apache.spark.ml.Model |import org.apache.spark.ml.param._ | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -205,7 +223,8 @@ private[shared] object SharedParamsCodeGen { val footer = "// scalastyle:on\n" val traits = params.map(genHasParamTrait).mkString + val extraTraits = extras.mkString - header + traits + footer + header + traits + extraTraits + footer } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index c088c16d1b05..23d564530a73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.param.shared +import org.apache.spark.ml.Model import org.apache.spark.ml.param._ // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. 
@@ -389,4 +390,16 @@ private[ml] trait HasSolver extends Params { /** @group getParam */ final def getSolver: String = $(solver) } + +private[ml] trait HasInitialModel[T <: Model[T]] extends Params { + + /** + * Param for initial model of warm start. + * @group param + */ + final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start") + + /** @group getParam */ + final def getInitialModel: T = $(initialModel) +} // scalastyle:on diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2724e51f31aa..7b16572d4b3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, SQLContext} @@ -106,6 +106,38 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val kmeans = new KMeans() testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) } + + test("Initialize using given cluster centers") { + val points = Array( + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(1.0, 1.0, 1.0), + Vectors.dense(2.0, 2.0, 2.0), + Vectors.dense(3.0, 3.0, 3.0), + Vectors.dense(4.0, 4.0, 4.0) + ) + + // creating an initial model + val initialModel = new KMeansModel("test model", new MLlibKMeansModel(points)) + + val predictionColName = "kmeans_prediction" + val kmeans = new KMeans() + .setK(k) + .setPredictionCol(predictionColName) 
+ .setSeed(1) + .setInitialModel(initialModel) + val model = kmeans.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = + transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } } object KMeansSuite { From 36b17292296c1e2afeb8b6fb56839c32ff5f9cdc Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 8 Feb 2016 12:04:36 -0800 Subject: [PATCH 02/42] add two setters for initial model --- .../org/apache/spark/ml/clustering/KMeans.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 365a6eb08abe..aa847d0dc8a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -241,6 +241,23 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) + /** @group setParam */ + @Since("2.0.0") + def setInitialModel(value: Model[_]): this.type = { + value match { + case m: KMeansModel => set(initialModel, m) + case other => + logInfo(s"KMeansModel required but ${other.getClass.getSimpleName} found.") + this + } + } + + /** @group setParam */ + @Since("2.0.0") + def setInitialModel(clusterCenters: Array[Vector]): this.type = { + set(initialModel, new KMeansModel("initial model", new MLlibKMeansModel(clusterCenters))) + } + @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } From 125ac7667ed3b41e625d0555c60b16997c52b697 Mon Sep 17 00:00:00 2001 From: 
Xusen Yin Date: Tue, 9 Feb 2016 16:00:19 -0800 Subject: [PATCH 03/42] revert to previous codegen and add a separate sharedParams for general types --- .../ml/param/shared/SharedParamsCodeGen.scala | 25 ++------------ .../shared/sharedGeneralTypeParams.scala | 34 +++++++++++++++++++ .../spark/ml/param/shared/sharedParams.scala | 13 ------- 3 files changed, 37 insertions(+), 35 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 85753d41ac86..4aff749ff75a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -78,24 +78,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + "empty, default value is 'auto'.", Some("\"auto\""))) - // scalastyle:off - val extras: Seq[String] = Seq( - """ - |private[ml] trait HasInitialModel[T <: Model[T]] extends Params { - | - | /** - | * Param for initial model of warm start. - | * @group param - | */ - | final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start") - | - | /** @group getParam */ - | final def getInitialModel: T = $(initialModel) - |} - |""".stripMargin) - // scalastyle:on - - val code = genSharedParams(params, extras) + val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" val writer = new PrintWriter(file) writer.write(code) @@ -191,7 +174,7 @@ private[shared] object SharedParamsCodeGen { } /** Generates Scala source code for the input params with header. 
*/ - private def genSharedParams(params: Seq[ParamDesc[_]], extras: Seq[String] = Nil): String = { + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { val header = """/* | * Licensed to the Apache Software Foundation (ASF) under one or more @@ -212,7 +195,6 @@ private[shared] object SharedParamsCodeGen { | |package org.apache.spark.ml.param.shared | - |import org.apache.spark.ml.Model |import org.apache.spark.ml.param._ | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -223,8 +205,7 @@ private[shared] object SharedParamsCodeGen { val footer = "// scalastyle:on\n" val traits = params.map(genHasParamTrait).mkString - val extraTraits = extras.mkString - header + traits + extraTraits + footer + header + traits + footer } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala new file mode 100644 index 000000000000..cae80685cdac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param._ + +private[ml] trait HasInitialModel[T <: Model[T]] extends Params { + + /** + * Param for initial model to warm start. + * @group param + */ + final val initialModel: Param[T] = + new Param[T](this, "initial model", "initial model to warm start") + + /** @group getParam */ + final def getInitialModel: T = $(initialModel) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 23d564530a73..c088c16d1b05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.param.shared -import org.apache.spark.ml.Model import org.apache.spark.ml.param._ // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -390,16 +389,4 @@ private[ml] trait HasSolver extends Params { /** @group getParam */ final def getSolver: String = $(solver) } - -private[ml] trait HasInitialModel[T <: Model[T]] extends Params { - - /** - * Param for initial model of warm start. 
- * @group param - */ - final val initialModel: Param[T] = new Param[T](this, "initial model", "initial model of warm-start") - - /** @group getParam */ - final def getInitialModel: T = $(initialModel) -} // scalastyle:on From 658c4c95bee30117c22af843557c3940c815d428 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 9 Feb 2016 16:08:56 -0800 Subject: [PATCH 04/42] add model check --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index aa847d0dc8a9..d2d87bbc2e36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -245,9 +245,9 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") def setInitialModel(value: Model[_]): this.type = { value match { - case m: KMeansModel => set(initialModel, m) + case m: KMeansModel => setInitialModel(m) case other => - logInfo(s"KMeansModel required but ${other.getClass.getSimpleName} found.") + logWarning(s"KMeansModel required but ${other.getClass.getSimpleName} found.") this } } @@ -255,7 +255,7 @@ class KMeans @Since("1.5.0") ( /** @group setParam */ @Since("2.0.0") def setInitialModel(clusterCenters: Array[Vector]): this.type = { - set(initialModel, new KMeansModel("initial model", new MLlibKMeansModel(clusterCenters))) + setInitialModel(new KMeansModel("initial model", new MLlibKMeansModel(clusterCenters))) } @Since("1.5.0") @@ -271,6 +271,8 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) if (isSet(initialModel)) { + require($(initialModel).parentModel.clusterCenters.length == $(k), "mismatched cluster count") + require(rdd.first().size == $(initialModel).clusterCenters.head.size, "mismatched dimension") algo.setInitialModel($(initialModel).parentModel) } From abfe0e27b02e866c1ec19ba0066590d7af73181c Mon Sep 17 
00:00:00 2001 From: Xusen Yin Date: Tue, 9 Feb 2016 16:24:38 -0800 Subject: [PATCH 05/42] add more testsuite --- .../apache/spark/ml/clustering/KMeansSuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 7b16572d4b3e..01bcbbacfa7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) @@ -126,6 +127,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setSeed(1) .setInitialModel(initialModel) val model = kmeans.fit(dataset) + assert(model.clusterCenters.length === k) val transformed = model.transform(dataset) @@ -133,10 +135,17 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = - transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + + val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) + + // Converged initial model should lead to only a single iteration. 
+ val convergedModel = kmeans.setInitialModel(model).fit(dataset).clusterCenters + val oneIterationModel = kmeans.setInitialModel(model).setMaxIter(1).fit(dataset).clusterCenters + convergedModel.zip(oneIterationModel).foreach { case (center1, center2) => + assert(center1 ~== center2 absTol 1E-8) + } } } From b87e07e284b9462433a4c8345ea8dde314c30313 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 10 Feb 2016 14:09:17 -0800 Subject: [PATCH 06/42] add new save/load to kmeans --- .../apache/spark/ml/clustering/KMeans.scala | 53 +++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d2d87bbc2e36..cd0fc022f31f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType} * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol with HasInitialModel[KMeansModel] { + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. 
@@ -190,7 +190,8 @@ object KMeansModel extends MLReadable[KMeansModel] { @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { + extends Estimator[KMeansModel] + with KMeansParams with HasInitialModel[KMeansModel] with MLWritable { setDefault( k -> 2, @@ -285,12 +286,58 @@ class KMeans @Since("1.5.0") ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + @Since("2.0.0") + override def write: MLWriter = new KMeans.KMeansWriter(this) } @Since("1.6.0") -object KMeans extends DefaultParamsReadable[KMeans] { +object KMeans extends MLReadable[KMeans] { @Since("1.6.0") override def load(path: String): KMeans = super.load(path) + + @Since("2.0.0") + override def read: MLReader[KMeans] = new KMeansReader + + /** [[MLWriter]] instance for [[KMeans]] */ + private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter { + import org.json4s.JsonDSL._ + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = if (instance.isSet(instance.initialModel)) { + val initialModelPath = new Path(path, "initial-model").toString + instance.getInitialModel.save(initialModelPath) + instance.clear(instance.initialModel) + "hasInitialModel" -> true + } else { + "hasInitialModel" -> false + } + + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + } + } + + private class KMeansReader extends MLReader[KMeans] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeans].getName + + override def load(path: String): KMeans = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val instance = new KMeans(metadata.uid) + DefaultParamsReader.getAndSetParams(instance, metadata) + + val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] + if (hasInitialModel) { + val 
initialModelPath = new Path(path, "initial-model").toString + val initialModel = KMeansModel.load(initialModelPath) + instance.setInitialModel(initialModel) + } + + instance + } + } } From 9a4a55edeadcf53c5bf37f4ac2c610bd9e9d6b22 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 10 Feb 2016 14:49:41 -0800 Subject: [PATCH 07/42] add new model save/load for KMeansModel --- .../apache/spark/ml/clustering/KMeans.scala | 30 ++++++++++++++++--- .../spark/ml/clustering/KMeansSuite.scala | 23 +++++++------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index cd0fc022f31f..9624ad81e993 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} @@ -34,7 +35,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType} * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasInitialModel[KMeansModel] { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. 
@@ -148,12 +149,23 @@ object KMeansModel extends MLReadable[KMeansModel] { /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + import org.json4s.JsonDSL._ private case class Data(clusterCenters: Array[Vector]) override protected def saveImpl(path: String): Unit = { + val extraMetadata = if (instance.isSet(instance.initialModel)) { + val initialModelPath = new Path(path, "initial-model").toString + instance.getInitialModel.save(initialModelPath) + instance.clear(instance.initialModel) + "hasInitialModel" -> true + } else { + "hasInitialModel" -> false + } + // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + // Save model data: cluster centers val data = Data(instance.clusterCenters) val dataPath = new Path(path, "data").toString @@ -162,6 +174,7 @@ object KMeansModel extends MLReadable[KMeansModel] { } private class KMeansModelReader extends MLReader[KMeansModel] { + implicit val format = DefaultFormats /** Checked against metadata when loading model */ private val className = classOf[KMeansModel].getName @@ -175,6 +188,14 @@ object KMeansModel extends MLReadable[KMeansModel] { val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) + + val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] + if (hasInitialModel) { + val initialModelPath = new Path(path, "initial-model").toString + val initialModel = KMeansModel.load(initialModelPath) + model.set(model.initialModel, initialModel) + } + model } } @@ -190,8 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] { @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] - with KMeansParams with HasInitialModel[KMeansModel] with MLWritable { + extends 
Estimator[KMeansModel] with KMeansParams with MLWritable { setDefault( k -> 2, @@ -325,6 +345,8 @@ object KMeans extends MLReadable[KMeans] { private val className = classOf[KMeans].getName override def load(path: String): KMeans = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 01bcbbacfa7a..5e377bafe6ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -31,11 +33,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR final val k = 5 @transient var dataset: DataFrame = _ + @transient var initialModel: KMeansModel = _ override def beforeAll(): Unit = { super.beforeAll() dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + initialModel = KMeansSuite.generateKMeansModel(3, k) } test("default parameters") { @@ -109,17 +113,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("Initialize using given cluster centers") { - val points = Array( - Vectors.dense(0.0, 0.0, 0.0), - Vectors.dense(1.0, 1.0, 1.0), - Vectors.dense(2.0, 2.0, 2.0), - Vectors.dense(3.0, 3.0, 3.0), - Vectors.dense(4.0, 4.0, 4.0) - ) - - // creating an initial model - val initialModel = new KMeansModel("test model", new MLlibKMeansModel(points)) - val predictionColName = "kmeans_prediction" val kmeans = new KMeans() .setK(k) 
@@ -157,6 +150,11 @@ object KMeansSuite { sql.createDataFrame(rdd) } + def generateKMeansModel(dim: Int, k: Int): KMeansModel = { + val clusterCenters = (1 to k).map(i => Vectors.dense(Array.fill(dim)(Random.nextDouble))) + new KMeansModel("test model", new MLlibKMeansModel(clusterCenters.toArray)) + } + /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load. @@ -166,6 +164,7 @@ object KMeansSuite { "predictionCol" -> "myPrediction", "k" -> 3, "maxIter" -> 2, - "tol" -> 0.01 + "tol" -> 0.01, + "initialModel" -> generateKMeansModel(5, 3) ) } From 65f4237990f08ccaf88ad8291061881459018adc Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 10 Feb 2016 23:33:24 -0800 Subject: [PATCH 08/42] fix side effect --- .../apache/spark/ml/clustering/KMeans.scala | 40 ++++++++++++------- .../shared/sharedGeneralTypeParams.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 16 +------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9624ad81e993..0657d2eb5b83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -154,18 +154,24 @@ object KMeansModel extends MLReadable[KMeansModel] { private case class Data(clusterCenters: Array[Vector]) override protected def saveImpl(path: String): Unit = { - val extraMetadata = if (instance.isSet(instance.initialModel)) { + if (instance.isSet(instance.initialModel)) { val initialModelPath = new Path(path, "initial-model").toString - instance.getInitialModel.save(initialModelPath) + val initialModel = instance.getInitialModel + initialModel.save(initialModelPath) + + // Remove the initialModel temporarily instance.clear(instance.initialModel) - "hasInitialModel" -> true + + // Save metadata 
and Params + DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> true)) + + // Set the initialModel back to avoid making side effect on instance + instance.set(instance.initialModel, initialModel) } else { - "hasInitialModel" -> false + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) } - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) - // Save model data: cluster centers val data = Data(instance.clusterCenters) val dataPath = new Path(path, "data").toString @@ -325,17 +331,23 @@ object KMeans extends MLReadable[KMeans] { import org.json4s.JsonDSL._ override protected def saveImpl(path: String): Unit = { - val extraMetadata = if (instance.isSet(instance.initialModel)) { + if (instance.isSet(instance.initialModel)) { val initialModelPath = new Path(path, "initial-model").toString - instance.getInitialModel.save(initialModelPath) + val initialModel = instance.getInitialModel + initialModel.save(initialModelPath) + + // Remove the initialModel temporarily instance.clear(instance.initialModel) - "hasInitialModel" -> true + + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> true)) + + // Set the initialModel back to avoid making side effect on instance + instance.set(instance.initialModel, initialModel) } else { - "hasInitialModel" -> false + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) } - - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala index cae80685cdac..5dc6c3ee62b3 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala @@ -27,7 +27,7 @@ private[ml] trait HasInitialModel[T <: Model[T]] extends Params { * @group param */ final val initialModel: Param[T] = - new Param[T](this, "initial model", "initial model to warm start") + new Param[T](this, "initialModel", "initial model to warm start") /** @group getParam */ final def getInitialModel: T = $(initialModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 5e377bafe6ad..e8229ba570ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -113,26 +113,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("Initialize using given cluster centers") { - val predictionColName = "kmeans_prediction" val kmeans = new KMeans() .setK(k) - .setPredictionCol(predictionColName) .setSeed(1) .setInitialModel(initialModel) val model = kmeans.fit(dataset) - assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) - // Converged initial model should lead to only a single iteration. 
val convergedModel = kmeans.setInitialModel(model).fit(dataset).clusterCenters val oneIterationModel = kmeans.setInitialModel(model).setMaxIter(1).fit(dataset).clusterCenters @@ -165,6 +151,6 @@ object KMeansSuite { "k" -> 3, "maxIter" -> 2, "tol" -> 0.01, - "initialModel" -> generateKMeansModel(5, 3) + "initialModel" -> generateKMeansModel(3, 3) ) } From 166a6fffcfb9ec8aacdcc91ce827450fca0e79d2 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 10 Feb 2016 23:58:03 -0800 Subject: [PATCH 09/42] add hashcode and equals --- .../org/apache/spark/ml/clustering/KMeans.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 0657d2eb5b83..787b485ecd26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -136,6 +136,18 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + + override def hashCode(): Int = + this.getClass.hashCode() + uid.hashCode() + clusterCenters.map(_.hashCode()).sum + + override def equals(other: Any): Boolean = other match { + case that: KMeansModel => + this.uid == that.uid && + this.clusterCenters.length == that.clusterCenters.length && + this.clusterCenters.zip(that.clusterCenters) + .foldLeft(true) { case (indicator, (v1, v2)) => indicator && (v1 == v2) } + case _ => false + } } @Since("1.6.0") From f3f92261b8da1e24a5d013f24f17c87c19b60ec8 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Mar 2016 12:13:28 -0800 Subject: [PATCH 10/42] add } --- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c98bf1dab9f6..a1fec251f722 100644 
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -148,6 +148,7 @@ class KMeansModel private[ml] ( this.clusterCenters.zip(that.clusterCenters) .foldLeft(true) { case (indicator, (v1, v2)) => indicator && (v1 == v2) } case _ => false + } private var trainingSummary: Option[KMeansSummary] = None From 08afa4cdb7dde17e443faab1b08cc1f85d5d10d4 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 8 Mar 2016 16:57:18 -0800 Subject: [PATCH 11/42] filter initialModel out --- .../org/apache/spark/ml/clustering/KMeans.scala | 12 ------------ .../scala/org/apache/spark/ml/util/ReadWrite.scala | 3 ++- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a1fec251f722..90cd591e82c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -192,14 +192,8 @@ object KMeansModel extends MLReadable[KMeansModel] { val initialModel = instance.getInitialModel initialModel.save(initialModelPath) - // Remove the initialModel temporarily - instance.clear(instance.initialModel) - // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> true)) - - // Set the initialModel back to avoid making side effect on instance - instance.set(instance.initialModel, initialModel) } else { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) @@ -370,14 +364,8 @@ object KMeans extends MLReadable[KMeans] { val initialModel = instance.getInitialModel initialModel.save(initialModelPath) - // Remove the initialModel temporarily - instance.clear(instance.initialModel) - // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" 
-> true)) - - // Set the initialModel back to avoid making side effect on instance - instance.set(instance.initialModel, initialModel) } else { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7b2504361a6e..c83fb426b058 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -218,7 +218,8 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.extractParamMap().toSeq + .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]] val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) From 31f7b94caab2790cbbe99091acf4b614e34cf41c Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 9 Mar 2016 21:51:34 -0800 Subject: [PATCH 12/42] new equal --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 90cd591e82c3..77303550dacd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -138,15 +138,14 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) - override def hashCode(): Int = - this.getClass.hashCode() + uid.hashCode() + clusterCenters.map(_.hashCode()).sum + override def hashCode(): Int = { + (Array(this.getClass, uid) ++ 
clusterCenters) + .foldLeft(17) { case (hash, obj) => hash * 31 + obj.hashCode() } + } override def equals(other: Any): Boolean = other match { case that: KMeansModel => - this.uid == that.uid && - this.clusterCenters.length == that.clusterCenters.length && - this.clusterCenters.zip(that.clusterCenters) - .foldLeft(true) { case (indicator, (v1, v2)) => indicator && (v1 == v2) } + this.uid == that.uid && this.clusterCenters.sameElements(that.clusterCenters) case _ => false } From 7c1c8f7782c37932a4791610ce132dc66b50e62a Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 14 Mar 2016 12:35:48 -0700 Subject: [PATCH 13/42] reinse test --- .../spark/ml/clustering/KMeansSuite.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4f2b5e7271dd..bbf1fc4bad0b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -32,14 +32,12 @@ private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 + final val initialModel = KMeansSuite.generateKMeansModel(3, k, seed = 14) @transient var dataset: DataFrame = _ - @transient var initialModel: KMeansModel = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) - initialModel = KMeansSuite.generateKMeansModel(3, k) } test("default parameters") { @@ -118,13 +116,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val kmeans = new KMeans() .setK(k) .setSeed(1) - .setInitialModel(initialModel) - val model = kmeans.fit(dataset) + .setMaxIter(1000) // Set a fairly high maxIter to make sure the model is converged. 
+ val convergedModel = kmeans.fit(dataset).clusterCenters // Converged initial model should lead to only a single iteration. - val convergedModel = kmeans.setInitialModel(model).fit(dataset).clusterCenters - val oneIterationModel = kmeans.setInitialModel(model).setMaxIter(1).fit(dataset).clusterCenters - convergedModel.zip(oneIterationModel).foreach { case (center1, center2) => + val oneMoreIterationModel = + kmeans.setInitialModel(convergedModel).setMaxIter(1).fit(dataset).clusterCenters + convergedModel.zip(oneMoreIterationModel).foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } @@ -138,8 +136,9 @@ object KMeansSuite { sql.createDataFrame(rdd) } - def generateKMeansModel(dim: Int, k: Int): KMeansModel = { - val clusterCenters = (1 to k).map(i => Vectors.dense(Array.fill(dim)(Random.nextDouble))) + def generateKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = { + val clusterCenters = (1 to k) + .map(i => Vectors.dense(Array.fill(dim)(new Random(seed).nextDouble))) new KMeansModel("test model", new MLlibKMeansModel(clusterCenters.toArray)) } From 95627b5672ecda575cd27df343af28273dd92210 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 14 Mar 2016 23:29:16 -0700 Subject: [PATCH 14/42] refine KMeans --- .../apache/spark/ml/clustering/KMeans.scala | 40 +++++++++++++------ .../spark/ml/clustering/KMeansSuite.scala | 20 ++++------ 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 77303550dacd..113648bb5b5b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -86,6 +86,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), 
IntegerType) } + + override def validateParams(): Unit = { + super.validateParams() + if (isSet(initialModel)) { + val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length + require(kOfInitialModel == $(k), + s"${$(k)} cluster centers required but $kOfInitialModel found in the initial model.") + } + } } /** @@ -187,7 +196,7 @@ object KMeansModel extends MLReadable[KMeansModel] { override protected def saveImpl(path: String): Unit = { if (instance.isSet(instance.initialModel)) { - val initialModelPath = new Path(path, "initial-model").toString + val initialModelPath = new Path(path, "initialModel").toString val initialModel = instance.getInitialModel initialModel.save(initialModelPath) @@ -221,11 +230,14 @@ object KMeansModel extends MLReadable[KMeansModel] { DefaultParamsReader.getAndSetParams(model, metadata) - val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] - if (hasInitialModel) { - val initialModelPath = new Path(path, "initial-model").toString - val initialModel = KMeansModel.load(initialModelPath) - model.set(model.initialModel, initialModel) + // Try to load initial model after version 2.0.0. 
+ if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { + val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] + if (hasInitialModel) { + val initialModelPath = new Path(path, "initialModel").toString + val initialModel = KMeansModel.load(initialModelPath) + model.set(model.initialModel, initialModel) + } } model @@ -324,7 +336,6 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) if (isSet(initialModel)) { - require($(initialModel).parentModel.clusterCenters.length == $(k), "mismatched cluster count") require(rdd.first().size == $(initialModel).clusterCenters.head.size, "mismatched dimension") algo.setInitialModel($(initialModel).parentModel) } @@ -359,7 +370,7 @@ object KMeans extends MLReadable[KMeans] { override protected def saveImpl(path: String): Unit = { if (instance.isSet(instance.initialModel)) { - val initialModelPath = new Path(path, "initial-model").toString + val initialModelPath = new Path(path, "initialModel").toString val initialModel = instance.getInitialModel initialModel.save(initialModelPath) @@ -384,11 +395,14 @@ object KMeans extends MLReadable[KMeans] { val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) - val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] - if (hasInitialModel) { - val initialModelPath = new Path(path, "initial-model").toString - val initialModel = KMeansModel.load(initialModelPath) - instance.setInitialModel(initialModel) + // Try to load initial model after version 2.0.0. 
+ if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { + val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] + if (hasInitialModel) { + val initialModelPath = new Path(path, "initialModel").toString + val initialModel = KMeansModel.load(initialModelPath) + instance.setInitialModel(initialModel) + } } instance diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index bbf1fc4bad0b..5981110c51ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.clustering import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -113,18 +114,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("Initialize using given cluster centers") { - val kmeans = new KMeans() - .setK(k) - .setSeed(1) - .setMaxIter(1000) // Set a fairly high maxIter to make sure the model is converged. - val convergedModel = kmeans.fit(dataset).clusterCenters - - // Converged initial model should lead to only a single iteration. 
- val oneMoreIterationModel = - kmeans.setInitialModel(convergedModel).setMaxIter(1).fit(dataset).clusterCenters - convergedModel.zip(oneMoreIterationModel).foreach { case (center1, center2) => - assert(center1 ~== center2 absTol 1E-8) - } + val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1) + val oneIterModel = kmeans.fit(dataset) + val twoIterModel = kmeans.copy(ParamMap(ParamPair(kmeans.maxIter, 2))).fit(dataset) + val oneMoreIterModel = kmeans.setInitialModel(oneIterModel).fit(dataset) + + twoIterModel.clusterCenters.zip(oneMoreIterModel.clusterCenters) + .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } } From 58bf1cf47e42c1e5e6a6e5db996ca29f1cb17f28 Mon Sep 17 00:00:00 2001 From: yinxusen Date: Fri, 22 Apr 2016 14:14:07 -0700 Subject: [PATCH 15/42] change back to DefaultParamsWritable/Readable --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index aa0e5c3fb351..ac615ec5b55d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -257,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] { @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams with MLWritable { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -364,7 +364,7 @@ class KMeans @Since("1.5.0") ( } @Since("1.6.0") -object KMeans extends MLReadable[KMeans] { +object KMeans extends DefaultParamsReadable[KMeans] { @Since("1.6.0") override def load(path: String): KMeans = super.load(path) From b7856e18d180dd03c4388f3081de501710fc9fab Mon Sep 17 00:00:00 2001 From: yinxusen Date: Wed, 27 Apr 2016 19:00:24 -0700 
Subject: [PATCH 16/42] add initialmodel metadata to default read write --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 9 +++------ .../main/scala/org/apache/spark/ml/util/ReadWrite.scala | 7 ++++++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index ac615ec5b55d..53f98cf34df0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -201,14 +201,11 @@ object KMeansModel extends MLReadable[KMeansModel] { val initialModelPath = new Path(path, "initialModel").toString val initialModel = instance.getInitialModel initialModel.save(initialModelPath) - - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> true)) - } else { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) } + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers val data = Data(instance.clusterCenters) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index bc50f77a5ae5..50544dd5bfc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -249,11 +249,16 @@ private[ml] object DefaultParamsWriter { val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + // If the instance has an "initialModel" param and the param is defined, then the initial model + // will be saved along with the instance. 
+ val initialModelFlag = + instance.hasParam("initialModel") && instance.isDefined(instance.getParam("initialModel")) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("initialModel" -> initialModelFlag) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject From 9f5e6987351b31bdc3482e2ed26c3354dbd1b741 Mon Sep 17 00:00:00 2001 From: yinxusen Date: Wed, 27 Apr 2016 19:28:59 -0700 Subject: [PATCH 17/42] add save/load for initial model --- .../apache/spark/ml/clustering/KMeans.scala | 45 +++---------------- .../org/apache/spark/ml/util/ReadWrite.scala | 25 +++++++++++ 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 53f98cf34df0..62a2bec37c02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -197,14 +197,9 @@ object KMeansModel extends MLReadable[KMeansModel] { private case class Data(clusterCenters: Array[Vector]) override protected def saveImpl(path: String): Unit = { - if (instance.isSet(instance.initialModel)) { - val initialModelPath = new Path(path, "initialModel").toString - val initialModel = instance.getInitialModel - initialModel.save(initialModelPath) - } - // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) + DefaultParamsWriter.saveInitialModel(instance, path) // Save model data: cluster centers val data = Data(instance.clusterCenters) @@ -228,17 +223,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - - // Try to load initial model after version 2.0.0. 
- if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { - val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] - if (hasInitialModel) { - val initialModelPath = new Path(path, "initialModel").toString - val initialModel = KMeansModel.load(initialModelPath) - model.set(model.initialModel, initialModel) - } - } - + DefaultParamsReader.loadInitialModel[KMeansModel](model, path, sc) model } } @@ -374,17 +359,8 @@ object KMeans extends DefaultParamsReadable[KMeans] { import org.json4s.JsonDSL._ override protected def saveImpl(path: String): Unit = { - if (instance.isSet(instance.initialModel)) { - val initialModelPath = new Path(path, "initialModel").toString - val initialModel = instance.getInitialModel - initialModel.save(initialModelPath) - - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> true)) - } else { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc, Some("hasInitialModel" -> false)) - } + DefaultParamsWriter.saveInitialModel(instance, path) + DefaultParamsWriter.saveMetadata(instance, path, sc) } } @@ -398,18 +374,9 @@ object KMeans extends DefaultParamsReadable[KMeans] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val instance = new KMeans(metadata.uid) - DefaultParamsReader.getAndSetParams(instance, metadata) - - // Try to load initial model after version 2.0.0. 
- if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { - val hasInitialModel = (metadata.metadata \ "hasInitialModel").extract[Boolean] - if (hasInitialModel) { - val initialModelPath = new Path(path, "initialModel").toString - val initialModel = KMeansModel.load(initialModelPath) - instance.setInitialModel(initialModel) - } - } + DefaultParamsReader.getAndSetParams(instance, metadata) + DefaultParamsReader.loadInitialModel[KMeansModel](instance, path, sc) instance } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 50544dd5bfc9..bf558433e779 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -268,6 +268,17 @@ private[ml] object DefaultParamsWriter { val metadataJson: String = compact(render(metadata)) metadataJson } + + def saveInitialModel(instance: Params, path: String): Unit = { + val initialModelFlag = + instance.hasParam("initialModel") && instance.isDefined(instance.getParam("initialModel")) + if (initialModelFlag) { + val initialModelPath = new Path(path, "initialModel").toString + val initialModel = instance.getOrDefault(instance.getParam("initialModel")) + assert(initialModel.isInstanceOf[MLWritable]) + initialModel.asInstanceOf[MLWritable].save(initialModelPath) + } + } } /** @@ -396,6 +407,20 @@ private[ml] object DefaultParamsReader { val cls = Utils.classForName(metadata.className) cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) } + + def loadInitialModel[M <: Model[M]](instance: Params, path: String, sc: SparkContext): Unit = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc) + // Try to load initial model after version 2.0.0. 
+ if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { + val hasInitialModel = (metadata.metadata \ "initialModel").extract[Boolean] + if (hasInitialModel) { + val initialModelPath = new Path(path, "initialModel").toString + val initialModel = loadParamsInstance[Model[M]](initialModelPath, sc) + instance.set(instance.getParam("initialModel"), initialModel) + } + } + } } /** From 914d31991e10dd0d77f03e167f19147ff696834c Mon Sep 17 00:00:00 2001 From: yinxusen Date: Wed, 27 Apr 2016 19:52:42 -0700 Subject: [PATCH 18/42] remove validateParams --- .../org/apache/spark/ml/clustering/KMeans.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 62a2bec37c02..db427f4e98ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -82,17 +82,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } - - override def validateParams(): Unit = { - super.validateParams() - if (isSet(initialModel)) { + if (isDefined(initialModel)) { val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length require(kOfInitialModel == $(k), s"${$(k)} cluster centers required but $kOfInitialModel found in the initial model.") } + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -322,8 +318,11 @@ class KMeans @Since("1.5.0") ( .setSeed($(seed)) .setEpsilon($(tol)) - if (isSet(initialModel)) { - require(rdd.first().size == $(initialModel).clusterCenters.head.size, 
"mismatched dimension") + if (isDefined(initialModel)) { + val dimOfData = rdd.first().size + val dimOfInitialModel = $(initialModel).clusterCenters.head.size + require(dimOfData == dimOfInitialModel, + s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.") algo.setInitialModel($(initialModel).parentModel) } From c40192b0579080f4af572cf6d12bf37942c03866 Mon Sep 17 00:00:00 2001 From: yinxusen Date: Wed, 27 Apr 2016 20:32:47 -0700 Subject: [PATCH 19/42] remove useless DefaultFormats --- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index db427f4e98ee..2ba9f447be56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,7 +18,6 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path -import org.json4s.DefaultFormats import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} @@ -205,8 +204,6 @@ object KMeansModel extends MLReadable[KMeansModel] { } private class KMeansModelReader extends MLReader[KMeansModel] { - implicit val format = DefaultFormats - /** Checked against metadata when loading model */ private val className = classOf[KMeansModel].getName @@ -369,8 +366,6 @@ object KMeans extends DefaultParamsReadable[KMeans] { private val className = classOf[KMeans].getName override def load(path: String): KMeans = { - implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val instance = new KMeans(metadata.uid) From 23a78d6b7d802cf826b031e1f6b25f364816c4f8 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 31 Aug 2016 00:02:12 -0700 Subject: [PATCH 20/42] fix vector issue --- 
.../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9d11e0bcea7c..cda23825b445 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -340,7 +340,9 @@ class KMeans @Since("1.5.0") ( /** @group setParam */ @Since("2.0.0") def setInitialModel(clusterCenters: Array[Vector]): this.type = { - setInitialModel(new KMeansModel("initial model", new MLlibKMeansModel(clusterCenters))) + setInitialModel( + new KMeansModel("initial model", + new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))) } @Since("2.0.0") From d4f59d9b2331df89b2745ed6050634defeaee08d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 2 Sep 2016 14:22:41 -0700 Subject: [PATCH 21/42] multi fixes --- .../apache/spark/ml/clustering/KMeans.scala | 16 ++++----- .../org/apache/spark/ml/util/ReadWrite.scala | 18 +++++----- .../spark/ml/clustering/KMeansSuite.scala | 36 +++++++++++++++++-- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index cda23825b445..4c2a2a03697c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -259,7 +259,7 @@ object KMeansModel extends MLReadable[KMeansModel] { } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](model, path, sc) + DefaultParamsReader.loadAndSetInitialModel[KMeansModel](model, metadata, path, sc) model } } @@ -323,11 +323,11 @@ class KMeans @Since("1.5.0") ( def setSeed(value: 
Long): this.type = set(seed, value) /** @group setParam */ - @Since("2.0.0") + @Since("2.1.0") def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) /** @group setParam */ - @Since("2.0.0") + @Since("2.1.0") def setInitialModel(value: Model[_]): this.type = { value match { case m: KMeansModel => setInitialModel(m) @@ -338,7 +338,7 @@ class KMeans @Since("1.5.0") ( } /** @group setParam */ - @Since("2.0.0") + @Since("2.1.0") def setInitialModel(clusterCenters: Array[Vector]): this.type = { setInitialModel( new KMeansModel("initial model", @@ -385,7 +385,7 @@ class KMeans @Since("1.5.0") ( validateAndTransformSchema(schema) } - @Since("2.0.0") + @Since("2.1.0") override def write: MLWriter = new KMeans.KMeansWriter(this) } @@ -395,13 +395,11 @@ object KMeans extends DefaultParamsReadable[KMeans] { @Since("1.6.0") override def load(path: String): KMeans = super.load(path) - @Since("2.0.0") + @Since("2.1.0") override def read: MLReader[KMeans] = new KMeansReader /** [[MLWriter]] instance for [[KMeans]] */ private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter { - import org.json4s.JsonDSL._ - override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveInitialModel(instance, path) DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -418,7 +416,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](instance, path, sc) + DefaultParamsReader.loadAndSetInitialModel[KMeansModel](instance, metadata, path, sc) instance } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7d99a9e2093d..b697cb78b512 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -32,6 +32,7 @@ import 
org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.ml.param.shared.HasInitialModel import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.util.Utils @@ -314,7 +315,9 @@ private[ml] object DefaultParamsWriter { ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) ~ + // TODO: Figure out more robust way to detect the existing of the initialModel. ("initialModel" -> initialModelFlag) + val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -325,13 +328,10 @@ private[ml] object DefaultParamsWriter { metadataJson } - def saveInitialModel(instance: Params, path: String): Unit = { - val initialModelFlag = - instance.hasParam("initialModel") && instance.isDefined(instance.getParam("initialModel")) - if (initialModelFlag) { + def saveInitialModel[T <: HasInitialModel[_ <: MLWritable]](instance: T, path: String): Unit = { + if (instance.isDefined(instance.getParam("initialModel"))) { val initialModelPath = new Path(path, "initialModel").toString val initialModel = instance.getOrDefault(instance.getParam("initialModel")) - assert(initialModel.isInstanceOf[MLWritable]) initialModel.asInstanceOf[MLWritable].save(initialModelPath) } } @@ -464,11 +464,11 @@ private[ml] object DefaultParamsReader { cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) } - def loadInitialModel[M <: Model[M]](instance: Params, path: String, sc: SparkContext): Unit = { + def loadAndSetInitialModel[M <: Model[M]]( + instance: HasInitialModel[M], metadata: Metadata, path: String, sc: SparkContext): Unit = { implicit val format = DefaultFormats - val metadata = DefaultParamsReader.loadMetadata(path, sc) - // Try to load initial model after version 2.0.0. 
- if (metadata.sparkVersion.split("\\.").head.toInt >= 2) { + // Try to load the initial model + if (metadata.metadata \ "initialModel" != JNothing) { val hasInitialModel = (metadata.metadata \ "initialModel").extract[Boolean] if (hasInitialModel) { val initialModelPath = new Path(path, "initialModel").toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 3bcb872864e0..114f4aaa7c79 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -20,21 +20,22 @@ package org.apache.spark.ml.clustering import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Model import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.types.StructType private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - final val initialModel = KMeansSuite.generateKMeansModel(3, k, seed = 14) @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { @@ -155,9 +156,37 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR twoIterModel.clusterCenters.zip(oneMoreIterModel.clusterCenters) .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } + + 
test("Initialize using wrong model") { + val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(10) + val wrongTypeModel = new KMeansSuite.MockModel() + assert(!kmeans.isSet(kmeans.initialModel)) + + val wrongKModel = KMeansSuite.generateKMeansModel(3, k + 1) + intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongKModel).fit(dataset) + } + + val wrongDimModel = KMeansSuite.generateKMeansModel(4, k) + intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongDimModel).fit(dataset) + } + } } object KMeansSuite { + + class MockModel(override val uid: String) extends Model[MockModel] { + + def this() = this(Identifiable.randomUID("mockModel")) + + override def copy(extra: ParamMap): MockModel = throw new NotImplementedError() + + override def transform(dataset: Dataset[_]): DataFrame = throw new NotImplementedError() + + override def transformSchema(schema: StructType): StructType = throw new NotImplementedError() + } + def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = { val sc = spark.sparkContext val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) @@ -166,8 +195,9 @@ object KMeansSuite { } def generateKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = { + val rng = new Random(seed) val clusterCenters = (1 to k) - .map(i => MLlibVectors.dense(Array.fill(dim)(new Random(seed).nextDouble))) + .map(i => MLlibVectors.dense(Array.fill(dim)(rng.nextDouble))) new KMeansModel("test model", new MLlibKMeansModel(clusterCenters.toArray)) } From 47f182b88242dbc2fa198591de5099b5644f4076 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 2 Sep 2016 14:29:19 -0700 Subject: [PATCH 22/42] fix not set initialmodel --- .../test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 114f4aaa7c79..8af99326a932 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -160,7 +160,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR test("Initialize using wrong model") { val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(10) val wrongTypeModel = new KMeansSuite.MockModel() - assert(!kmeans.isSet(kmeans.initialModel)) + assert(!kmeans.setInitialModel(wrongTypeModel).isSet(kmeans.initialModel)) val wrongKModel = KMeansSuite.generateKMeansModel(3, k + 1) intercept[IllegalArgumentException] { From 78ed9a183e123f38929bf2df100c8c1cae375093 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sun, 11 Sep 2016 21:48:01 -0700 Subject: [PATCH 23/42] refine tests --- .../apache/spark/ml/clustering/KMeans.scala | 4 +-- .../org/apache/spark/ml/util/ReadWrite.scala | 10 +++---- .../spark/ml/clustering/KMeansSuite.scala | 29 +++++++++++++------ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 4c2a2a03697c..065da481fb84 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -332,8 +332,8 @@ class KMeans @Since("1.5.0") ( value match { case m: KMeansModel => setInitialModel(m) case other => - logWarning(s"KMeansModel required but ${other.getClass.getSimpleName} found.") - this + throw new IllegalArgumentException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index b697cb78b512..fdd89339c2e5 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -329,10 +329,10 @@ private[ml] object DefaultParamsWriter { } def saveInitialModel[T <: HasInitialModel[_ <: MLWritable]](instance: T, path: String): Unit = { - if (instance.isDefined(instance.getParam("initialModel"))) { + if (instance.isDefined(instance.initialModel)) { val initialModelPath = new Path(path, "initialModel").toString - val initialModel = instance.getOrDefault(instance.getParam("initialModel")) - initialModel.asInstanceOf[MLWritable].save(initialModelPath) + val initialModel = instance.getOrDefault(instance.initialModel) + initialModel.save(initialModelPath) } } } @@ -472,8 +472,8 @@ private[ml] object DefaultParamsReader { val hasInitialModel = (metadata.metadata \ "initialModel").extract[Boolean] if (hasInitialModel) { val initialModelPath = new Path(path, "initialModel").toString - val initialModel = loadParamsInstance[Model[M]](initialModelPath, sc) - instance.set(instance.getParam("initialModel"), initialModel) + val initialModel = loadParamsInstance[M](initialModelPath, sc) + instance.set(instance.initialModel, initialModel) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 8af99326a932..4b17520e9d6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -160,16 +160,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR test("Initialize using wrong model") { val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(10) val wrongTypeModel = new KMeansSuite.MockModel() - assert(!kmeans.setInitialModel(wrongTypeModel).isSet(kmeans.initialModel)) - val wrongKModel = KMeansSuite.generateKMeansModel(3, k + 1) - intercept[IllegalArgumentException] { - 
kmeans.setInitialModel(wrongKModel).fit(dataset) + withClue("The type of an initial model should only be a KMeansModel.") { + intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongTypeModel).fit(dataset) + } } - val wrongDimModel = KMeansSuite.generateKMeansModel(4, k) - intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongDimModel).fit(dataset) + val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1) + withClue("The number of clusters set in the given model should be the same with the one set" + + " in the KMeans estimator.") { + intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongKModel).fit(dataset) + } + } + + val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k) + withClue("The dimension of points in the model should be the same with the dimension of the" + + " training data.") { + intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongDimModel).fit(dataset) + } } } } @@ -194,7 +205,7 @@ object KMeansSuite { spark.createDataFrame(rdd) } - def generateKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = { + def generateRandomKMeansModel(dim: Int, k: Int, seed: Int = 42): KMeansModel = { val rng = new Random(seed) val clusterCenters = (1 to k) .map(i => MLlibVectors.dense(Array.fill(dim)(rng.nextDouble))) @@ -211,6 +222,6 @@ object KMeansSuite { "k" -> 3, "maxIter" -> 2, "tol" -> 0.01, - "initialModel" -> generateKMeansModel(3, 3) + "initialModel" -> generateRandomKMeansModel(3, 3) ) } From 03575bf0d9046a6387fe3ded4bbb2c0fe3186ac5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 14 Sep 2016 13:39:06 -0700 Subject: [PATCH 24/42] remove some setters --- .../apache/spark/ml/clustering/KMeans.scala | 21 +---------------- .../spark/ml/clustering/KMeansSuite.scala | 23 +++++-------------- 2 files changed, 7 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
index 065da481fb84..78670da131b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -89,7 +89,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe if (isDefined(initialModel)) { val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length require(kOfInitialModel == $(k), - s"${$(k)} cluster centers required but $kOfInitialModel found in the initial model.") + s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") } SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) @@ -326,25 +326,6 @@ class KMeans @Since("1.5.0") ( @Since("2.1.0") def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) - /** @group setParam */ - @Since("2.1.0") - def setInitialModel(value: Model[_]): this.type = { - value match { - case m: KMeansModel => setInitialModel(m) - case other => - throw new IllegalArgumentException( - s"KMeansModel required but ${other.getClass.getSimpleName} found.") - } - } - - /** @group setParam */ - @Since("2.1.0") - def setInitialModel(clusterCenters: Array[Vector]): this.type = { - setInitialModel( - new KMeansModel("initial model", - new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))) - } - @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4b17520e9d6c..90432a16f727 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -159,29 +159,18 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR test("Initialize using 
wrong model") { val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(10) - val wrongTypeModel = new KMeansSuite.MockModel() - - withClue("The type of an initial model should only be a KMeansModel.") { - intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongTypeModel).fit(dataset) - } - } val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1) - withClue("The number of clusters set in the given model should be the same with the one set" + - " in the KMeans estimator.") { - intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongKModel).fit(dataset) - } + val wrongKModelThrown = intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongKModel).fit(dataset) } + assert(wrongKModelThrown.getMessage.contains("mismatched cluster count")) val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k) - withClue("The dimension of points in the model should be the same with the dimension of the" + - " training data.") { - intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongDimModel).fit(dataset) - } + val wrongDimModelThrown = intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongDimModel).fit(dataset) } + assert(wrongDimModelThrown.getMessage.contains("mismatched dimension")) } } From c21ffa2117364c77d3dbba37f40762cad80a8f19 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 14 Sep 2016 13:47:57 -0700 Subject: [PATCH 25/42] change the implementation of initialModel --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 8 +++++++- .../spark/ml/param/shared/sharedGeneralTypeParams.scala | 7 +------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 78670da131b3..671e1b9887cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -80,6 +80,13 @@ 
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * Param for KMeansModel to use for warm start. + * @group param + */ + final val initialModel: Param[KMeansModel] = + new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.") + /** * Validates and transforms the input schema. * @param schema input schema @@ -219,7 +226,6 @@ object KMeansModel extends MLReadable[KMeansModel] { /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - import org.json4s.JsonDSL._ override protected def saveImpl(path: String): Unit = { // Save metadata and Params diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala index 5dc6c3ee62b3..c67380edaa60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala @@ -22,12 +22,7 @@ import org.apache.spark.ml.param._ private[ml] trait HasInitialModel[T <: Model[T]] extends Params { - /** - * Param for initial model to warm start. 
- * @group param - */ - final val initialModel: Param[T] = - new Param[T](this, "initialModel", "initial model to warm start") + def initialModel: Param[T] /** @group getParam */ final def getInitialModel: T = $(initialModel) From eb7fbbea3a68135442c5088ccc6972b6c50b8f51 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 15 Sep 2016 23:39:18 -0700 Subject: [PATCH 26/42] remove hashcode and equal check --- .../apache/spark/ml/clustering/KMeans.scala | 13 +----- .../spark/ml/util/DefaultReadWriteTest.scala | 43 ++++++++++++++++--- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 671e1b9887cc..9d4f005fa421 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -171,17 +171,6 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) - override def hashCode(): Int = { - (Array(this.getClass, uid) ++ clusterCenters) - .foldLeft(17) { case (hash, obj) => hash * 31 + obj.hashCode() } - } - - override def equals(other: Any): Boolean = other match { - case that: KMeansModel => - this.uid == that.uid && this.clusterCenters.sameElements(that.clusterCenters) - case _ => false - } - private var trainingSummary: Option[KMeansSummary] = None private[clustering] def setSummary(summary: KMeansSummary): this.type = { @@ -382,7 +371,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { @Since("1.6.0") override def load(path: String): KMeans = super.load(path) - @Since("2.1.0") + @Since("1.6.0") override def read: MLReader[KMeans] = new KMeansReader /** [[MLWriter]] instance for [[KMeans]] */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 
553b8725b30a..3ed383e8d9d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,7 +25,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType trait DefaultReadWriteTest extends TempDirectory { self: Suite => @@ -39,9 +40,11 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with MLWritable]( + def testDefaultReadWrite[T <: Params with MLWritable, M <: Model[M]]( instance: T, - testParams: Boolean = true): T = { + testParams: Boolean = true, + checkModelData: (M, M) => Unit = (m1: MyModel, m2: MyModel) => + throw new UnsupportedOperationException("Model check function needed")): T = { val uid = instance.uid val subdirName = Identifiable.randomUID("test") @@ -63,6 +66,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { case (Array(values), Array(newValues)) => assert(values === newValues, s"Values do not match on param ${p.name}.") + case (m1: M, m2: M) => + checkModelData(m1, m2) case (value, newValue) => assert(value === newValue, s"Values do not match on param ${p.name}.") } @@ -108,23 +113,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val model = estimator.fit(dataset) // Test Estimator save/load - val estimator2 = testDefaultReadWrite(estimator) + val estimator2 = testDefaultReadWrite(estimator, checkModelData = checkModelData) testParams.foreach { case (p, v) => val param = estimator.getParam(p) - assert(estimator.get(param).get === estimator2.get(param).get) + 
val paramVal = estimator.get(param).get + val paramVal2 = estimator2.get(param).get + paramVal match { + case _: M => + checkModelData(paramVal.asInstanceOf[M], paramVal2.asInstanceOf[M]) + case other => + assert(estimator.get(param).get === estimator2.get(param).get) + } } // Test Model save/load - val model2 = testDefaultReadWrite(model) + val model2 = testDefaultReadWrite(model, checkModelData = checkModelData) testParams.foreach { case (p, v) => val param = model.getParam(p) - assert(model.get(param).get === model2.get(param).get) + val paramVal = estimator.get(param).get + val paramVal2 = estimator2.get(param).get + paramVal match { + case _: M => + checkModelData(paramVal.asInstanceOf[M], paramVal2.asInstanceOf[M]) + case other => + assert(estimator.get(param).get === estimator2.get(param).get) + } } checkModelData(model, model2) } } +class MyModel extends Model[MyModel] { + override val uid: String = Identifiable.randomUID("MyModel") + + override def transform(dataset: Dataset[_]): DataFrame = dataset.asInstanceOf[DataFrame] + + override def transformSchema(schema: StructType): StructType = schema + + override def copy(extra: ParamMap): MyModel = this +} + class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") From f6e024a276f466455549dbc25c43162d92829468 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 16 Sep 2016 15:01:59 -0700 Subject: [PATCH 27/42] fix errors --- .../spark/ml/util/DefaultReadWriteTest.scala | 67 ++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 3ed383e8d9d9..479749b973bf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,7 
+22,7 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset} @@ -42,9 +42,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => */ def testDefaultReadWrite[T <: Params with MLWritable, M <: Model[M]]( instance: T, - testParams: Boolean = true, - checkModelData: (M, M) => Unit = (m1: MyModel, m2: MyModel) => - throw new UnsupportedOperationException("Model check function needed")): T = { + testParams: Boolean = true): T = { val uid = instance.uid val subdirName = Identifiable.randomUID("test") @@ -66,8 +64,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { case (Array(values), Array(newValues)) => assert(values === newValues, s"Values do not match on param ${p.name}.") - case (m1: M, m2: M) => - checkModelData(m1, m2) case (value, newValue) => assert(value === newValue, s"Values do not match on param ${p.name}.") } @@ -83,6 +79,31 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => another } + /** + * Compare Params with complex types that could not compare with [[===]]. + * + * @param stage A pipeline stage contains these params. + * @param stage2 Another pipeline stage to compare. + * @param testParams Params to compare. + * @param testFunctions Functions to compare complex type params. 
+ */ + def compareParamsWithComplexTypes( + stage: PipelineStage, + stage2: PipelineStage, + testParams: Map[String, Any], + testFunctions: Map[String, (Any, Any) => Unit]): Unit = { + testParams.foreach { case (p, v) => + val param = stage.getParam(p) + val paramVal = stage.get(param).get + val paramVal2 = stage2.get(param).get + if (testFunctions.contains(p)) { + testFunctions(p)(paramVal, paramVal2) + } else { + assert(paramVal === paramVal2) + } + } + } + /** * Default test for Estimator, Model pairs: * - Explicitly set Params, and train model @@ -112,33 +133,19 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } val model = estimator.fit(dataset) - // Test Estimator save/load - val estimator2 = testDefaultReadWrite(estimator, checkModelData = checkModelData) - testParams.foreach { case (p, v) => - val param = estimator.getParam(p) - val paramVal = estimator.get(param).get - val paramVal2 = estimator2.get(param).get - paramVal match { - case _: M => - checkModelData(paramVal.asInstanceOf[M], paramVal2.asInstanceOf[M]) - case other => - assert(estimator.get(param).get === estimator2.get(param).get) - } + val testFunctions = if (testParams.contains("initialModel")) { + Map(("initialModel", checkModelData.asInstanceOf[(Any, Any) => Unit])) + } else { + Map.empty[String, (Any, Any) => Unit] } + // Test Estimator save/load + val estimator2 = testDefaultReadWrite(estimator, testParams = false) + compareParamsWithComplexTypes(estimator, estimator2, testParams, testFunctions) + // Test Model save/load - val model2 = testDefaultReadWrite(model, checkModelData = checkModelData) - testParams.foreach { case (p, v) => - val param = model.getParam(p) - val paramVal = estimator.get(param).get - val paramVal2 = estimator2.get(param).get - paramVal match { - case _: M => - checkModelData(paramVal.asInstanceOf[M], paramVal2.asInstanceOf[M]) - case other => - assert(estimator.get(param).get === estimator2.get(param).get) - } - } + val model2 = 
testDefaultReadWrite(model, testParams = false) + compareParamsWithComplexTypes(model, model2, testParams, testFunctions) checkModelData(model, model2) } From 92cf83dae222a4e7cce0d201c81a26d70f51b161 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 26 Sep 2016 23:58:22 -0700 Subject: [PATCH 28/42] fix errors --- .../org/apache/spark/ml/clustering/KMeans.scala | 11 ++++++----- .../apache/spark/ml/clustering/KMeansSuite.scala | 11 ----------- .../spark/ml/util/DefaultReadWriteTest.scala | 16 +++------------- 3 files changed, 9 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9d4f005fa421..0e2851f2ef15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -93,11 +93,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - if (isDefined(initialModel)) { - val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length - require(kOfInitialModel == $(k), - s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") - } SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } @@ -340,10 +335,16 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) if (isDefined(initialModel)) { + // Check the equal of number of dimension val dimOfData = rdd.first().size val dimOfInitialModel = $(initialModel).clusterCenters.head.size require(dimOfData == dimOfInitialModel, s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.") + + // Check the equal of number of clusters + val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length + require(kOfInitialModel == $(k), 
+ s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") algo.setInitialModel($(initialModel).parentModel) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 90432a16f727..3401e6652860 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -176,17 +176,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR object KMeansSuite { - class MockModel(override val uid: String) extends Model[MockModel] { - - def this() = this(Identifiable.randomUID("mockModel")) - - override def copy(extra: ParamMap): MockModel = throw new NotImplementedError() - - override def transform(dataset: Dataset[_]): DataFrame = throw new NotImplementedError() - - override def transformSchema(schema: StructType): StructType = throw new NotImplementedError() - } - def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = { val sc = spark.sparkContext val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 479749b973bf..dc16009b638b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,8 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Dataset trait DefaultReadWriteTest extends TempDirectory { self: 
Suite => @@ -40,7 +39,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with MLWritable, M <: Model[M]]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -133,6 +132,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } val model = estimator.fit(dataset) + // TODO: Change the test function if the type of initialModel isn't the same with type M. val testFunctions = if (testParams.contains("initialModel")) { Map(("initialModel", checkModelData.asInstanceOf[(Any, Any) => Unit])) } else { @@ -151,16 +151,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyModel extends Model[MyModel] { - override val uid: String = Identifiable.randomUID("MyModel") - - override def transform(dataset: Dataset[_]): DataFrame = dataset.asInstanceOf[DataFrame] - - override def transformSchema(schema: StructType): StructType = schema - - override def copy(extra: ParamMap): MyModel = this -} - class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") From 95bf12f352f085587ff7772ffcd4ccdf9f7f084b Mon Sep 17 00:00:00 2001 From: lapanda Date: Wed, 5 Oct 2016 00:23:42 -0700 Subject: [PATCH 29/42] add TODO with JIRA --- .../org/apache/spark/ml/clustering/KMeans.scala | 2 ++ .../scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +- .../apache/spark/ml/clustering/KMeansSuite.scala | 7 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 14 ++++---------- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 0e2851f2ef15..d8b88408322a 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -369,6 +369,8 @@ class KMeans @Since("1.5.0") ( @Since("1.6.0") object KMeans extends DefaultParamsReadable[KMeans] { + // TODO: [SPARK-17784]: Add a fromCenters method + @Since("1.6.0") override def load(path: String): KMeans = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index fdd89339c2e5..8ec4ea433c59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -315,7 +315,7 @@ private[ml] object DefaultParamsWriter { ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) ~ - // TODO: Figure out more robust way to detect the existing of the initialModel. + // TODO: [SPARK-17785] Figure out a more robust way to detect the existing of the initialModel ("initialModel" -> initialModelFlag) val metadata = extraMetadata match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 3401e6652860..24f1553fc2f8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -20,16 +20,14 @@ package org.apache.spark.ml.clustering import scala.util.Random import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.Model import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import 
org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.types.StructType private[clustering] case class TestRow(features: Vector) @@ -144,7 +142,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.clusterCenters === model2.clusterCenters) } val kmeans = new KMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData, + Map("initialModel" -> (checkModelData _).asInstanceOf[(Any, Any) => Unit])) } test("Initialize using given cluster centers") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dc16009b638b..944db0c34b05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -125,27 +125,21 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => estimator: E, dataset: Dataset[_], testParams: Map[String, Any], - checkModelData: (M, M) => Unit): Unit = { + checkModelData: (M, M) => Unit, + checkParamsFunctions: Map[String, (Any, Any) => Unit] = Map.empty): Unit = { // Set some Params to make sure set Params are serialized. testParams.foreach { case (p, v) => estimator.set(estimator.getParam(p), v) } val model = estimator.fit(dataset) - // TODO: Change the test function if the type of initialModel isn't the same with type M. 
- val testFunctions = if (testParams.contains("initialModel")) { - Map(("initialModel", checkModelData.asInstanceOf[(Any, Any) => Unit])) - } else { - Map.empty[String, (Any, Any) => Unit] - } - // Test Estimator save/load val estimator2 = testDefaultReadWrite(estimator, testParams = false) - compareParamsWithComplexTypes(estimator, estimator2, testParams, testFunctions) + compareParamsWithComplexTypes(estimator, estimator2, testParams, checkParamsFunctions) // Test Model save/load val model2 = testDefaultReadWrite(model, testParams = false) - compareParamsWithComplexTypes(model, model2, testParams, checkParamsFunctions) checkModelData(model, model2) } From 7fc69189b1e501681a5a8f6065697482e5e52584 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 6 Oct 2016 01:58:13 -0700 Subject: [PATCH 30/42] add inferring K from initial model --- .../apache/spark/ml/clustering/KMeans.scala | 9 +++---- .../spark/ml/clustering/KMeansSuite.scala | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d8b88408322a..28683a36c277 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -314,7 +314,9 @@ class KMeans @Since("1.5.0") ( /** @group setParam */ @Since("2.1.0") - def setInitialModel(value: KMeansModel): this.type = set(initialModel, value) + def setInitialModel(value: KMeansModel): this.type = { + set(k, value.parentModel.clusterCenters.length).set(initialModel, value) + } @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { @@ -340,11 +342,6 @@ class KMeans @Since("1.5.0") ( val dimOfInitialModel = $(initialModel).clusterCenters.head.size require(dimOfData == dimOfInitialModel, s"mismatched dimension, $dimOfData in data
while $dimOfInitialModel in the initial model.") - - // Check the equal of number of clusters - val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length - require(kOfInitialModel == $(k), - s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") algo.setInitialModel($(initialModel).parentModel) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 24f1553fc2f8..a7e072bfa8ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -33,13 +33,14 @@ private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - final val k = 5 + final val k: Int = 5 + final val dim: Int = 3 @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, dim, k) } test("default parameters") { @@ -146,7 +147,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR Map("initialModel" -> (checkModelData _).asInstanceOf[(Any, Any) => Unit])) } - test("Initialize using given cluster centers") { + test("Initialize using a trained model") { val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1) val oneIterModel = kmeans.fit(dataset) val twoIterModel = kmeans.copy(ParamMap(ParamPair(kmeans.maxIter, 2))).fit(dataset) @@ -156,21 +157,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } - test("Initialize using wrong model") { - val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(10) - - val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1) - val 
wrongKModelThrown = intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongKModel).fit(dataset) - } - assert(wrongKModelThrown.getMessage.contains("mismatched cluster count")) - + test("Initialize using a wrong model") { + val kmeans = new KMeans().setSeed(1).setMaxIter(10) val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k) val wrongDimModelThrown = intercept[IllegalArgumentException] { kmeans.setInitialModel(wrongDimModel).fit(dataset) } assert(wrongDimModelThrown.getMessage.contains("mismatched dimension")) } + + test("Infer K from an initial model") { + val kmeans = new KMeans().setK(k) + val testNewK = 10 + val randomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK) + assert(kmeans.setInitialModel(randomModel).getK === testNewK) + } } object KMeansSuite { From 5fbb132bc9b11e598086783b77c60be8fb669095 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 7 Oct 2016 12:31:42 -0700 Subject: [PATCH 31/42] add more assert --- .../apache/spark/ml/clustering/KMeans.scala | 16 +++++++++++++++- .../spark/ml/clustering/KMeansSuite.scala | 19 +++++++++++++++---- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 28683a36c277..681ffbf07225 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -315,7 +315,15 @@ class KMeans @Since("1.5.0") ( /** @group setParam */ @Since("2.1.0") def setInitialModel(value: KMeansModel): this.type = { - set(k, value.parentModel.clusterCenters.length).set(initialModel, value) + val kOfInitialModel = value.parentModel.clusterCenters.length + if (isSet(k)) { + require(kOfInitialModel == $(k), + s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") + } else { + set(k, kOfInitialModel) + logWarning(s"Param K is set to 
$kOfInitialModel by the initialModel.") + } + set(initialModel, value) } @Since("2.0.0") @@ -342,6 +350,12 @@ class KMeans @Since("1.5.0") ( val dimOfInitialModel = $(initialModel).clusterCenters.head.size require(dimOfData == dimOfInitialModel, s"mismatched dimension, $dimOfData in data while $dimOfInitialModel in the initial model.") + + // Check the equal of number of clusters + val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length + require(kOfInitialModel == $(k), + s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") + algo.setInitialModel($(initialModel).parentModel) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index a7e072bfa8ca..c599fedf0ef4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -157,8 +157,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } - test("Initialize using a wrong model") { - val kmeans = new KMeans().setSeed(1).setMaxIter(10) + test("Initialize using a model with wrong dimension of cluster centers") { + val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1) + val wrongDimModel = KMeansSuite.generateRandomKMeansModel(4, k) val wrongDimModelThrown = intercept[IllegalArgumentException] { kmeans.setInitialModel(wrongDimModel).fit(dataset) @@ -166,12 +167,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(wrongDimModelThrown.getMessage.contains("mismatched dimension")) } - test("Infer K from an initial model") { - val kmeans = new KMeans().setK(k) + test("Infer K from an initial model if K is unset") { + val kmeans = new KMeans() val testNewK = 10 val randomModel = 
KMeansSuite.generateRandomKMeansModel(dim, testNewK) assert(kmeans.setInitialModel(randomModel).getK === testNewK) } + + test("Initialize using a model with wrong K if K is set") { + val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1) + + val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1) + val wrongKModelThrown = intercept[IllegalArgumentException] { + kmeans.setInitialModel(wrongKModel).fit(dataset) + } + assert(wrongKModelThrown.getMessage.contains("mismatched cluster count")) + } } object KMeansSuite { From 4bba7c16c9684cf4b125d0f4a72cd9640e94e62c Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 14 Oct 2016 14:54:04 -0700 Subject: [PATCH 32/42] fix logics of test --- .../org/apache/spark/ml/clustering/KMeans.scala | 13 +++++++++---- .../apache/spark/ml/clustering/KMeansSuite.scala | 14 ++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 681ffbf07225..97fa68c8f83d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -82,6 +82,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe /** * Param for KMeansModel to use for warm start. + * The setting of initialModel takes precedence of param K. * @group param */ final val initialModel: Param[KMeansModel] = @@ -317,8 +318,11 @@ class KMeans @Since("1.5.0") ( def setInitialModel(value: KMeansModel): this.type = { val kOfInitialModel = value.parentModel.clusterCenters.length if (isSet(k)) { - require(kOfInitialModel == $(k), - s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") + if ($(k) != kOfInitialModel) { + set(k, kOfInitialModel) + logWarning(s"Param K is set to $kOfInitialModel by the initialModel." 
+ + s" Previous value is ${$(k)}.") + } } else { set(k, kOfInitialModel) logWarning(s"Param K is set to $kOfInitialModel by the initialModel.") @@ -354,7 +358,8 @@ class KMeans @Since("1.5.0") ( // Check the equal of number of clusters val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length require(kOfInitialModel == $(k), - s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found.") + s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found" + + s" in the initial model.") algo.setInitialModel($(initialModel).parentModel) } @@ -398,7 +403,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { private class KMeansReader extends MLReader[KMeans] { - /** Checked against metadata when loading model */ + /** Checked against metadata when loading estimator */ private val className = classOf[KMeans].getName override def load(path: String): KMeans = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c599fedf0ef4..4188e1709d76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -153,6 +153,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val twoIterModel = kmeans.copy(ParamMap(ParamPair(kmeans.maxIter, 2))).fit(dataset) val oneMoreIterModel = kmeans.setInitialModel(oneIterModel).fit(dataset) + assert(oneMoreIterModel.getK === k) + twoIterModel.clusterCenters.zip(oneMoreIterModel.clusterCenters) .foreach { case (center1, center2) => assert(center1 ~== center2 absTol 1E-8) } } @@ -167,19 +169,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(wrongDimModelThrown.getMessage.contains("mismatched dimension")) } - test("Infer K from an initial model if K is unset") { - val kmeans = new KMeans() + 
test("Infer K from an initial model") { + val kmeans = new KMeans().setK(5) val testNewK = 10 val randomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK) assert(kmeans.setInitialModel(randomModel).getK === testNewK) } - test("Initialize using a model with wrong K if K is set") { - val kmeans = new KMeans().setK(k).setSeed(1).setMaxIter(1) + test("Reset K after setting initial model") { + val kmeans = new KMeans().setSeed(1).setMaxIter(1) - val wrongKModel = KMeansSuite.generateRandomKMeansModel(3, k + 1) + val wrongKModel = KMeansSuite.generateRandomKMeansModel(dim, k) val wrongKModelThrown = intercept[IllegalArgumentException] { - kmeans.setInitialModel(wrongKModel).fit(dataset) + kmeans.setInitialModel(wrongKModel).setK(k + 1).fit(dataset) } assert(wrongKModelThrown.getMessage.contains("mismatched cluster count")) } From 127ca06d7b949e8c5d8d3134ff0b9206ba484524 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 14 Oct 2016 14:59:57 -0700 Subject: [PATCH 33/42] add new test of using different initial model --- .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4188e1709d76..b161d6ebc525 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -174,6 +174,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val testNewK = 10 val randomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK) assert(kmeans.setInitialModel(randomModel).getK === testNewK) + + val differentKRandomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK + 1) + assert(kmeans.setInitialModel(differentKRandomModel).getK === testNewK + 1) } test("Reset K after setting initial model") { From 0e93fda11f9e99ab53146527bffb99928a7df639 
Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 14 Oct 2016 15:37:25 -0700 Subject: [PATCH 34/42] get rid of metadata.hasInitialModel --- .../apache/spark/ml/clustering/KMeans.scala | 15 +++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 25 +++++-------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 97fa68c8f83d..dd5b708a55c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.util.{Failure, Success} + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -250,7 +252,12 @@ object KMeansModel extends MLReadable[KMeansModel] { } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - DefaultParamsReader.loadAndSetInitialModel[KMeansModel](model, metadata, path, sc) + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { + case Success(v) => + model.set(model.initialModel, v) + case Failure(e) => // Unable to load initial model + } + model } } @@ -411,7 +418,11 @@ object KMeans extends DefaultParamsReadable[KMeans] { val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) - DefaultParamsReader.loadAndSetInitialModel[KMeansModel](instance, metadata, path, sc) + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { + case Success(v) => + instance.setInitialModel(v) + case Failure(e) => // Fail to load initial model + } instance } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 8ec4ea433c59..5ccf1ab2630c 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.io.IOException +import scala.util.Try + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.{DefaultFormats, JObject} @@ -306,17 +308,11 @@ private[ml] object DefaultParamsWriter { val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) - // If the instance has an "initialModel" param and the param is defined, then the initial model - // will be saved along with the instance. - val initialModelFlag = - instance.hasParam("initialModel") && instance.isDefined(instance.getParam("initialModel")) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - // TODO: [SPARK-17785] Figure out a more robust way to detect the existing of the initialModel - ("initialModel" -> initialModelFlag) + ("paramMap" -> jsonParams) val metadata = extraMetadata match { case Some(jObject) => @@ -464,18 +460,9 @@ private[ml] object DefaultParamsReader { cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) } - def loadAndSetInitialModel[M <: Model[M]]( - instance: HasInitialModel[M], metadata: Metadata, path: String, sc: SparkContext): Unit = { - implicit val format = DefaultFormats - // Try to load the initial model - if (metadata.metadata \ "initialModel" != JNothing) { - val hasInitialModel = (metadata.metadata \ "initialModel").extract[Boolean] - if (hasInitialModel) { - val initialModelPath = new Path(path, "initialModel").toString - val initialModel = loadParamsInstance[M](initialModelPath, sc) - instance.set(instance.initialModel, initialModel) - } - } + def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Try[M] = { + val initialModelPath = new Path(path, 
"initialModel").toString + Try(loadParamsInstance[M](initialModelPath, sc)) } } From 261fcfa4e96fe901f1866eb56ac87917935bf28b Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sun, 16 Oct 2016 21:18:20 -0700 Subject: [PATCH 35/42] fix nits and errors --- .../apache/spark/ml/clustering/KMeans.scala | 23 +++++++++++-------- .../spark/ml/clustering/KMeansSuite.scala | 18 +++++++-------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dd5b708a55c1..6c6cb55e17bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -84,7 +84,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe /** * Param for KMeansModel to use for warm start. - * The setting of initialModel takes precedence of param K. + * Whenever initialModel is set, the initialModel k will override the param k. * @group param */ final val initialModel: Param[KMeansModel] = @@ -298,7 +298,14 @@ class KMeans @Since("1.5.0") ( /** @group setParam */ @Since("1.5.0") - def setK(value: Int): this.type = set(k, value) + def setK(value: Int): this.type = { + if (isSet(initialModel)) { + logWarning("initialModel is set, so k will be ignored. Clear initialModel first.") + this + } else { + set(k, value) + } + } /** @group expertSetParam */ @Since("1.5.0") @@ -326,9 +333,10 @@ class KMeans @Since("1.5.0") ( val kOfInitialModel = value.parentModel.clusterCenters.length if (isSet(k)) { if ($(k) != kOfInitialModel) { + val previousK = $(k) set(k, kOfInitialModel) logWarning(s"Param K is set to $kOfInitialModel by the initialModel." 
+ - s" Previous value is ${$(k)}.") + s" Previous value is $previousK.") } } else { set(k, kOfInitialModel) @@ -356,7 +364,7 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) if (isDefined(initialModel)) { - // Check the equal of number of dimension + // Check the equal of dimension val dimOfData = rdd.first().size val dimOfInitialModel = $(initialModel).clusterCenters.head.size require(dimOfData == dimOfInitialModel, @@ -418,11 +426,8 @@ object KMeans extends DefaultParamsReadable[KMeans] { val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { - case Success(v) => - instance.setInitialModel(v) - case Failure(e) => // Fail to load initial model - } + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) + .foreach(v => instance.setInitialModel(v)) instance } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index b161d6ebc525..2c1b9a3285ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -174,19 +174,17 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val testNewK = 10 val randomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK) assert(kmeans.setInitialModel(randomModel).getK === testNewK) - - val differentKRandomModel = KMeansSuite.generateRandomKMeansModel(dim, testNewK + 1) - assert(kmeans.setInitialModel(differentKRandomModel).getK === testNewK + 1) } - test("Reset K after setting initial model") { - val kmeans = new KMeans().setSeed(1).setMaxIter(1) + test("Ignore k if initialModel is set") { + val kmeans = new KMeans() - val wrongKModel = KMeansSuite.generateRandomKMeansModel(dim, k) - val wrongKModelThrown = intercept[IllegalArgumentException] { - 
kmeans.setInitialModel(wrongKModel).setK(k + 1).fit(dataset) - } - assert(wrongKModelThrown.getMessage.contains("mismatched cluster count")) + val m1 = KMeansSuite.generateRandomKMeansModel(dim, k) + // ignore k if initialModel is set + assert(kmeans.setInitialModel(m1).setK(k - 1).getK === k) + kmeans.clear(kmeans.initialModel) + // k is not ignored after initialModel is cleared + assert(kmeans.setK(k - 1).getK === k - 1) } } From b3ea01a630f275de37a521f8faa6cb5f5efaae43 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sun, 16 Oct 2016 22:35:31 -0700 Subject: [PATCH 36/42] fix small errors --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 11 ++++------- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6c6cb55e17bb..523947455438 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -252,11 +252,8 @@ object KMeansModel extends MLReadable[KMeansModel] { } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { - case Success(v) => - model.set(model.initialModel, v) - case Failure(e) => // Unable to load initial model - } + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) + .foreach(v => model.set(model.initialModel, v)) model } @@ -364,13 +361,13 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) if (isDefined(initialModel)) { - // Check the equal of dimension + // Check that the feature dimensions are equal val dimOfData = rdd.first().size val dimOfInitialModel = $(initialModel).clusterCenters.head.size require(dimOfData == dimOfInitialModel, s"mismatched dimension, $dimOfData in data while 
$dimOfInitialModel in the initial model.") - // Check the equal of number of clusters + // Check that the number of clusters are equal val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length require(kOfInitialModel == $(k), s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found" + diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2c1b9a3285ea..c683a5e34209 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -179,9 +179,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR test("Ignore k if initialModel is set") { val kmeans = new KMeans() - val m1 = KMeansSuite.generateRandomKMeansModel(dim, k) + val randomModel = KMeansSuite.generateRandomKMeansModel(dim, k) // ignore k if initialModel is set - assert(kmeans.setInitialModel(m1).setK(k - 1).getK === k) + assert(kmeans.setInitialModel(randomModel).setK(k - 1).getK === k) kmeans.clear(kmeans.initialModel) // k is not ignored after initialModel is cleared assert(kmeans.setK(k - 1).getK === k - 1) } } From 47de1feb2633fc9e978b20386461977d54cdddac Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 17 Oct 2016 17:46:34 -0700 Subject: [PATCH 37/42] fix load model exception --- .../org/apache/spark/ml/clustering/KMeans.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 523947455438..1d1bc6a98532 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.clustering import scala.util.{Failure, Success} import
org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} @@ -84,7 +85,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe /** * Param for KMeansModel to use for warm start. - * Whenever initialModel is set, the initialModel k will override the param k. + * Whenever initialModel is set, the initialModel k will override the param k, while other params + * remain unchanged. * @group param */ final val initialModel: Param[KMeansModel] = @@ -252,8 +254,10 @@ object KMeansModel extends MLReadable[KMeansModel] { } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) - .foreach(v => model.set(model.initialModel, v)) + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { + case Success(v) => model.set(model.initialModel, v) + case Failure(e) => if (!e.isInstanceOf[InvalidInputException]) throw e + } model } @@ -423,8 +427,10 @@ object KMeans extends DefaultParamsReadable[KMeans] { val instance = new KMeans(metadata.uid) DefaultParamsReader.getAndSetParams(instance, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) - .foreach(v => instance.setInitialModel(v)) + DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { + case Success(v) => instance.setInitialModel(v) + case Failure(e) => if (!e.isInstanceOf[InvalidInputException]) throw e + } instance } } From 2c9cd51b5cf925734966576540823ccbfdc422df Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 18 Oct 2016 15:09:05 -0700 Subject: [PATCH 38/42] add more comments --- .../org/apache/spark/ml/clustering/KMeans.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1d1bc6a98532..45db0ed8c954 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -45,6 +45,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe /** * The number of clusters to create (k). Must be > 1. Default: 2. + * The param k will be overwrote by the param initialModel if the latter is set. * @group param */ @Since("1.5.0") @@ -59,6 +60,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * Param for the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * The param initMode will be ignored if the param initialModel is set. * @group expertParam */ @Since("1.5.0") @@ -85,8 +87,10 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe /** * Param for KMeansModel to use for warm start. - * Whenever initialModel is set, the initialModel k will override the param k, while other params - * remain unchanged. + * Whenever initialModel is set: + * 1. the initialModel k will override the param k; + * 2. the param initMode is ignored; + * 3. other params are remain untouched. 
* @group param */ final val initialModel: Param[KMeansModel] = @@ -256,7 +260,8 @@ object KMeansModel extends MLReadable[KMeansModel] { DefaultParamsReader.getAndSetParams(model, metadata) DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { case Success(v) => model.set(model.initialModel, v) - case Failure(e) => if (!e.isInstanceOf[InvalidInputException]) throw e + case Failure(_: InvalidInputException) => // initialModel doesn't exist, do nothing + case Failure(e) => throw e } model @@ -429,7 +434,8 @@ object KMeans extends DefaultParamsReadable[KMeans] { DefaultParamsReader.getAndSetParams(instance, metadata) DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { case Success(v) => instance.setInitialModel(v) - case Failure(e) => if (!e.isInstanceOf[InvalidInputException]) throw e + case Failure(_: InvalidInputException) => // initialModel doesn't exist, do nothing + case Failure(e) => throw e } instance } From e5299723e75ebcd31d8ef6411f95bb684f9b5a77 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 18 Oct 2016 18:14:35 -0700 Subject: [PATCH 39/42] eliminate possible initialModels of the direct initialModel --- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 8 +++++++- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 12 ++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 5ccf1ab2630c..91f6287e19c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -324,10 +324,16 @@ private[ml] object DefaultParamsWriter { metadataJson } - def saveInitialModel[T <: HasInitialModel[_ <: MLWritable]](instance: T, path: String): Unit = { + def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]]( + instance: T, path: String): Unit = { if (instance.isDefined(instance.initialModel)) 
{ val initialModelPath = new Path(path, "initialModel").toString val initialModel = instance.getOrDefault(instance.initialModel) + // When saving, only keep the direct initialModel by eliminating possible initialModels of the + // direct initialModel, to avoid unnecessary deep recursion of initialModel. + if (initialModel.hasParam("initialModel")) { + initialModel.clear(initialModel.getParam("initialModel")) + } initialModel.save(initialModelPath) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c683a5e34209..a80a468032c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -186,6 +186,18 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR // k is not ignored after initialModel is cleared assert(kmeans.setK(k - 1).getK === k - 1) } + + test("Eliminate possible initialModels of the direct initialModel") { + val randomModel = KMeansSuite.generateRandomKMeansModel(dim, k) + val kmeans = new KMeans().setK(k).setMaxIter(1).setInitialModel(randomModel) + val firstLevelModel = kmeans.fit(dataset) + val secondLevelModel = kmeans.setInitialModel(firstLevelModel).fit(dataset) + assert(secondLevelModel.getInitialModel + .isSet(secondLevelModel.getInitialModel.getParam("initialModel"))) + val savedThenLoadedModel = testDefaultReadWrite(secondLevelModel, testParams = false) + assert(!savedThenLoadedModel.getInitialModel + .isSet(savedThenLoadedModel.getInitialModel.getParam("initialModel"))) + } } object KMeansSuite { From 7046913e4464207f3a866ed874d5b6fd5d8f3b91 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Nov 2016 22:09:36 -0800 Subject: [PATCH 40/42] first update KMeans --- .../apache/spark/ml/clustering/KMeans.scala | 66 +++++++++++-------- .../spark/mllib/clustering/KMeans.scala | 5 +- 2 files changed, 43 
insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a6919e767c53..8f14dc3a8bb0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -38,10 +38,28 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} /** - * Common params for KMeans and KMeansModel + * Params for KMeans */ -private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol with HasInitialModel[KMeansModel] { + +private[clustering] trait KMeansParams extends KMeansModelParams with HasInitialModel[KMeansModel] { + /** + * Param for KMeansModel to use for warm start. + * Whenever initialModel is set: + * 1. the initialModel k will override the param k; + * 2. the param initMode is set to initialModel and manually set is ignored; + * 3. other params are untouched. + * @group param + */ + final val initialModel: Param[KMeansModel] = + new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.") + +} + +/** + * Params for KMeansModel + */ +private[clustering] trait KMeansModelParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -86,16 +104,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) - /** - * Param for KMeansModel to use for warm start. - * Whenever initialModel is set: - * 1. the initialModel k will override the param k; - * 2. the param initMode is ignored; - * 3. other params are remain untouched. 
- * @group param - */ - final val initialModel: Param[KMeansModel] = - new Param[KMeansModel](this, "initialModel", "A KMeansModel for warm start.") /** * Validates and transforms the input schema. @@ -119,7 +127,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private[ml] val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + extends Model[KMeansModel] with KMeansModelParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -139,7 +147,8 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val predictUDF = udf((vector: Vector) => predict(vector)) + val tmpParent: MLlibKMeansModel = parentModel + val predictUDF = udf((vector: Vector) => tmpParent.predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -148,8 +157,6 @@ class KMeansModel private[ml] ( validateAndTransformSchema(schema) } - private[clustering] def predict(features: Vector): Int = parentModel.predict(features) - @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) @@ -225,7 +232,6 @@ object KMeansModel extends MLReadable[KMeansModel] { override protected def saveImpl(path: String): Unit = { // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) - DefaultParamsWriter.saveInitialModel(instance, path) // Save model data: cluster centers val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => @@ -237,6 +243,7 @@ object KMeansModel extends MLReadable[KMeansModel] { } private class KMeansModelReader extends MLReader[KMeansModel] { + /** Checked against metadata when loading model */ private val className = classOf[KMeansModel].getName @@ -260,11 +267,6 @@ object KMeansModel 
extends MLReadable[KMeansModel] { } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) - DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match { - case Success(v) => model.set(model.initialModel, v) - case Failure(_: InvalidInputException) => // initialModel doesn't exist, do nothing - case Failure(e) => throw e - } model } @@ -317,7 +319,15 @@ class KMeans @Since("1.5.0") ( /** @group expertSetParam */ @Since("1.5.0") - def setInitMode(value: String): this.type = set(initMode, value) + def setInitMode(value: String): this.type = { + if (isSet(initialModel)) { + logWarning(s"initialModel is set, so initMode will be ignored. Clear initialModel first.") + } + if (value == MLlibKMeans.K_MEANS_INITIAL_MODEL) { + logWarning(s"initMode of $value is not supported here, please use setInitialModel.") + } + set(initMode, value) + } /** @group expertSetParam */ @Since("1.5.0") @@ -350,6 +360,7 @@ class KMeans @Since("1.5.0") ( set(k, kOfInitialModel) logWarning(s"Param K is set to $kOfInitialModel by the initialModel.") } + set(initMode, "initialModel") set(initialModel, value) } @@ -380,9 +391,10 @@ class KMeans @Since("1.5.0") ( // Check that the number of clusters are equal val kOfInitialModel = $(initialModel).parentModel.clusterCenters.length - require(kOfInitialModel == $(k), - s"mismatched cluster count, ${$(k)} cluster centers required but $kOfInitialModel found" + - s" in the initial model.") + if (kOfInitialModel != $(k)) { + logWarning(s"mismatched cluster count, ${$(k)} cluster centers required but" + + s" $kOfInitialModel found in the initial model.") + } algo.setInitialModel($(initialModel).parentModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ed9c064879d0..d141624a57e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -17,8 +17,9 @@ package org.apache.spark.mllib.clustering -import scala.collection.mutable.ArrayBuffer +import com.esotericsoftware.kryo.serializers.VersionFieldSerializer.Since +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging @@ -414,6 +415,8 @@ object KMeans { val RANDOM = "random" @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" + @Since("2.1.0") + val K_MEANS_INITIAL_MODEL = "initialModel" /** * Trains a k-means model using the given set of parameters. From 8516a2c6d8875cceee49c19f8d70fb71bd2b9225 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Nov 2016 22:35:47 -0800 Subject: [PATCH 41/42] update test --- .../apache/spark/mllib/clustering/KMeans.scala | 4 ++-- .../spark/ml/clustering/KMeansSuite.scala | 12 ------------ .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++++++++------- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index d141624a57e1..60ebab39186e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -17,9 +17,8 @@ package org.apache.spark.mllib.clustering -import com.esotericsoftware.kryo.serializers.VersionFieldSerializer.Since - import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging @@ -592,6 +591,7 @@ object KMeans { initMode match { case KMeans.RANDOM => true case KMeans.K_MEANS_PARALLEL => true + case KMeans.K_MEANS_INITIAL_MODEL => true case _ => false } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4de92ac1ea66..0783c9953902 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -193,18 +193,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR // k is not ignored after initialModel is cleared assert(kmeans.setK(k - 1).getK === k - 1) } - - test("Eliminate possible initialModels of the direct initialModel") { - val randomModel = KMeansSuite.generateRandomKMeansModel(dim, k) - val kmeans = new KMeans().setK(k).setMaxIter(1).setInitialModel(randomModel) - val firstLevelModel = kmeans.fit(dataset) - val secondLevelModel = kmeans.setInitialModel(firstLevelModel).fit(dataset) - assert(secondLevelModel.getInitialModel - .isSet(secondLevelModel.getInitialModel.getParam("initialModel"))) - val savedThenLoadedModel = testDefaultReadWrite(secondLevelModel, testParams = false) - assert(!savedThenLoadedModel.getInitialModel - .isSet(savedThenLoadedModel.getInitialModel.getParam("initialModel"))) - } } object KMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 944db0c34b05..5f1ef837b527 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -92,13 +92,16 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => testParams: Map[String, Any], testFunctions: Map[String, (Any, Any) => Unit]): Unit = { testParams.foreach { case (p, v) => - val param = stage.getParam(p) - val paramVal = stage.get(param).get - val paramVal2 = stage2.get(param).get - if (testFunctions.contains(p)) { - testFunctions(p)(paramVal, paramVal2) - } else { - assert(paramVal === paramVal2) + if (stage.hasParam(p)) { + 
assert(stage2.hasParam(p)) + val param = stage.getParam(p) + val paramVal = stage.get(param).get + val paramVal2 = stage2.get(param).get + if (testFunctions.contains(p)) { + testFunctions(p)(paramVal, paramVal2) + } else { + assert(paramVal === paramVal2) + } } } } From 6f169ebf8c0c832010d2dbd8f971cfabff7870f2 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 8 Nov 2016 15:01:04 -0800 Subject: [PATCH 42/42] fix mima test --- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 8f14dc3a8bb0..fe357accdc56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType} * Params for KMeans */ -private[clustering] trait KMeansParams extends KMeansModelParams with HasInitialModel[KMeansModel] { +private[clustering] trait KMeansInitialModelParams extends HasInitialModel[KMeansModel] { /** * Param for KMeansModel to use for warm start. 
* Whenever initialModel is set: @@ -58,7 +58,7 @@ private[clustering] trait KMeansParams extends KMeansModelParams with HasInitial /** * Params for KMeansModel */ -private[clustering] trait KMeansModelParams extends Params with HasMaxIter with HasFeaturesCol +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol with HasTol { /** @@ -127,7 +127,7 @@ private[clustering] trait KMeansModelParams extends Params with HasMaxIter with class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private[ml] val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansModelParams with MLWritable { + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -283,7 +283,8 @@ object KMeansModel extends MLReadable[KMeansModel] { @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { + extends Estimator[KMeansModel] + with KMeansParams with KMeansInitialModelParams with DefaultParamsWritable { setDefault( k -> 2,