109 changes: 98 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -37,9 +37,9 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

/**
* Common params for KMeans and KMeansModel
* Common params for KMeans and KMeansModel.
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
private[clustering] trait KMeansModelParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol {
Member:
Now that KMeansModel mixes in KMeansModelParams, does that mean we can no longer get the initialModel information at the model level? Also, why does the model need to mix in the seed?

Contributor:
Yeah, we decided in the previous discussion to not store the initial model in the produced model, for several reasons, including model serialization.

Member:
Fair enough.
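For context, here is a minimal warm-start usage sketch of the API this PR adds. It is illustrative only, not taken from the PR's tests; `train1` and `train2` are assumed DataFrames with a "features" vector column.

import org.apache.spark.ml.clustering.KMeans

// Fit a first model, then reuse its cluster centers to warm-start a second run.
val firstModel = new KMeans().setK(4).setSeed(1L).fit(train1)

val warmStarted = new KMeans()
  .setK(4)
  .setInitMode("initialModel")   // the new init mode this PR introduces
  .setInitialModel(firstModel)   // warm start from firstModel's centers
  .fit(train2)

// As discussed above, the fitted model does not retain the initial model; it only
// carries the shared KMeansModelParams (k, featuresCol, seed, tol, ...).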


/**
@@ -59,12 +59,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
/**
* Param for the initialization algorithm. This can be either "random" to choose random points as
* initial cluster centers, or "k-means||" to use a parallel variant of k-means++
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
* (Bahmani et al., Scalable K-Means++, VLDB 2012), or "initialModel" to use a user provided
* initial model for warm start. Default: k-means||.
* If this is set to "initialModel", users must specify the initial model via `setInitialModel`;
* otherwise an IllegalArgumentException is thrown.
* @group expertParam
*/
@Since("1.5.0")
final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " +
"Supported options: 'random' and 'k-means||'.",
"Supported options: 'random', 'k-means||' and 'initialModel'.",
(value: String) => MLlibKMeans.validateInitMode(value))

/** @group expertGetParam */
@@ -95,6 +98,22 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
}
}

/**
* Common params for KMeans.
*/
private[clustering] trait KMeansParams extends KMeansModelParams with HasInitialModel[KMeansModel] {

/**
* A KMeansModel to use for warm start.
* Note that the cluster count of the initial model must equal [[k]];
* otherwise an IllegalArgumentException is thrown.
* @group param
*/
@Since("2.2.0")
final val initialModel: Param[KMeansModel] =
new Param[KMeansModel](this, "initialModel", "A KMeansModel to use for warm start.")
}

/**
* Model fitted by KMeans.
*
@@ -103,8 +122,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansParams with MLWritable {
private[clustering] val parentModel: MLlibKMeansModel)
extends Model[KMeansModel] with KMeansModelParams with MLWritable {

@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -123,7 +142,8 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
val localParent: MLlibKMeansModel = parentModel
val predictUDF = udf((vector: Vector) => localParent.predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

@@ -132,8 +152,6 @@ class KMeansModel private[ml] (
validateAndTransformSchema(schema)
}

private[clustering] def predict(features: Vector): Int = parentModel.predict(features)

@Since("2.0.0")
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)

@@ -253,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.5.0")
class KMeans @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
extends Estimator[KMeansModel] with KMeansParams with MLWritable {

setDefault(
k -> 2,
@@ -300,6 +318,10 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.2.0")
def setInitialModel(value: KMeansModel): this.type = set(initialModel, value)
Member:
How about

def setInitialModel(value: KMeansModel): this.type = {
  if (getK != value.getK) {
    logWarning(s"Param k is set to ${value.getK} to match the initial model's cluster count.")
    set(k, value.getK)
  }
  set(initMode, MLlibKMeans.K_MEANS_INITIAL_MODEL) // We may log, but I don't really care for this one.
  set(initialModel, value)
}

yanboliang (Contributor, author) commented on Mar 8, 2017:
I don't think we can override a param inside any set*** function, since the ML pipeline API supports other ways of setting params, such as:

def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
    copy(paramMap).fit(dataset)
  }

Users should get the same model regardless of how the param was set. I think the only place to override a param is at the start of the fit function.
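A minimal sketch of this consistency concern (illustrative only; `firstModel` and `dataset` are assumed to exist, and the point is just that the ParamMap path never calls the setter):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.param.ParamMap

val km = new KMeans().setK(4).setSeed(1L)

// ParamMap path: Estimator.fit(dataset, paramMap) does copy(paramMap).fit(dataset),
// so setInitialModel is never invoked and any param-overwriting logic inside it is skipped.
val m1 = km.fit(dataset, ParamMap(km.initialModel -> firstModel))

// Setter path: overwriting k or initMode inside setInitialModel would only take effect
// here, so the two paths could end up fitting with different effective params.
val m2 = km.setInitialModel(firstModel).fit(dataset)

// Overriding params at the start of fit (rather than inside the setter) keeps both
// paths consistent.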

Member:
Can you elaborate on this? I don't fully understand why we can't overwrite the setting in the set method. Thanks.


@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)
@@ -322,6 +344,18 @@ class KMeans @Since("1.5.0") (
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))

if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && isSet(initialModel)) {
// Check that the feature dimensions are equal
val numFeatures = instances.first().size
val dimOfInitialModel = $(initialModel).clusterCenters.head.size
require(numFeatures == dimOfInitialModel,
s"The number of features in training dataset is $numFeatures," +
s" which mismatched with dimension of initial model: $dimOfInitialModel.")

algo.setInitialModel($(initialModel).parentModel)
}

val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
@@ -335,17 +369,70 @@ class KMeans @Since("1.5.0") (
model
}

/**
* Check validity for interactions between parameters.
*/
private def assertInitialModelValid(): Unit = {
Member:
I think with the overwriting above, the only thing we need to check would be

if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL && !isSet(initialModel)) {
  throw new IllegalArgumentException("Users must set param initialModel when choosing " +
           "'initialModel' as the initialization.")
}

We can just have it in the body of the fit method.

if ($(initMode) == MLlibKMeans.K_MEANS_INITIAL_MODEL) {
if (isSet(initialModel)) {
val initialModelK = $(initialModel).parentModel.k
if (initialModelK != $(k)) {
Member:
I don't think this check is needed if we overwrite k when initialModel is set.

throw new IllegalArgumentException("The initial model's cluster count = " +
s"$initialModelK, mismatched with k = $k.")
}
} else {
throw new IllegalArgumentException("Users must set param initialModel if you choose " +
"'initialModel' as the initialization algorithm.")
}
} else {
if (isSet(initialModel)) {
Member:
Also, this is not needed if we do the overwriting work in setInitialModel.

logWarning(s"Param initialModel will take no effect when initMode is $initMode.")
}
}
}

@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
assertInitialModelValid()
Member:
Why is this not checked in fit?

sethah (Contributor) commented on Mar 8, 2017:
transformSchema will be called in the fit method.

Member:
transformSchema is called in the transform method, and model.transform is called when computing the summary. I think we should fail earlier instead of checking at the end. Also, it's only implicitly checked when computing the summary; we should check it explicitly.

If the checking logic is small, I'd put that checking code in the fit method.
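For illustration, the fail-fast variant being suggested would look roughly like this; a sketch in the style of the review's own snippets, not the PR's actual code (which calls assertInitialModelValid() from transformSchema instead):

override def fit(dataset: Dataset[_]): KMeansModel = {
  // Surface invalid initMode/initialModel combinations before any distributed work
  // or summary computation starts, instead of relying on transformSchema.
  assertInitialModelValid()
  transformSchema(dataset.schema, logging = true)
  // ... rest of fit unchanged ...
}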

validateAndTransformSchema(schema)
}

@Since("2.2.0")
override def write: MLWriter = new KMeans.KMeansWriter(this)
}

@Since("1.6.0")
object KMeans extends DefaultParamsReadable[KMeans] {
object KMeans extends MLReadable[KMeans] {

@Since("1.6.0")
override def load(path: String): KMeans = super.load(path)

@Since("2.2.0")
override def read: MLReader[KMeans] = new KMeansReader

/** [[MLWriter]] instance for [[KMeans]] */
private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveInitialModel(instance, path)
DefaultParamsWriter.saveMetadata(instance, path, sc)
}
Contributor (author):
I was trying to move saveInitialModel into saveMetadata to make it more succinct. We can do this for MLWriter, but it's hard for MLReader[T], since we would need to explicitly pass the type of initialModel as well, which means refactoring MLReader[T] to MLReader[T, M]. However, lots of estimators/transformers will never use initialModel, so the extra type parameter [M] doesn't make sense.
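For illustration, the refactor being described would amount to something like the following hypothetical reader signature (not part of the PR), which is why the separate saveInitialModel/loadInitialModel helpers were kept instead:

import org.apache.spark.ml.Model
import org.apache.spark.ml.util.MLReader

// Hypothetical: a reader parameterized by both the estimator type T and the
// initial-model type M, so generic metadata loading could also restore initialModel.
abstract class MLReaderWithInitialModel[T, M <: Model[M]] extends MLReader[T] {
  // Every estimator/transformer would have to name M, even the majority that never
  // use an initialModel, which is the awkwardness described above.
}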

}

private class KMeansReader extends MLReader[KMeans] {

override def load(path: String): KMeans = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[KMeans].getName)
val instance = new KMeans(metadata.uid)

DefaultParamsReader.getAndSetParams(instance, metadata)
DefaultParamsReader.loadInitialModel[KMeansModel](path, sc) match {
Contributor:
This can be done as:

 DefaultParamsReader.loadInitialModel[KMeansModel](path, sc).foreach(instance.setInitialModel)

I think it's nicer, but I'm not sure if there is a universal preference for side effects with options in Spark, so I'll leave it to you to decide.

Contributor (author):
Yeah, your suggestion would work well, but I prefer my way, since it's clearer for developers to understand what happens.

case Some(m) => instance.setInitialModel(m)
case None => // initialModel doesn't exist, do nothing
}
instance
}
}
}

29 changes: 29 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/shared/HasInitialModel.scala
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.param.shared

import org.apache.spark.ml.Model
import org.apache.spark.ml.param._

private[ml] trait HasInitialModel[T <: Model[T]] extends Params {

def initialModel: Param[T]

/** @group getParam */
final def getInitialModel: T = $(initialModel)
}
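For illustration, another estimator's params trait could reuse this shared trait in the same way. The following is hypothetical and not part of this PR; GaussianMixtureModel is only used as an example model type:

package org.apache.spark.ml.clustering

import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.ml.param.shared.HasInitialModel

// A hypothetical params trait that would warm-start GaussianMixture the same way.
private[clustering] trait GaussianMixtureWarmStartParams
  extends Params with HasInitialModel[GaussianMixtureModel] {

  final val initialModel: Param[GaussianMixtureModel] =
    new Param[GaussianMixtureModel](this, "initialModel",
      "A GaussianMixtureModel to use for warm start.")
}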
40 changes: 39 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.param.shared.HasInitialModel
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
@@ -279,6 +280,8 @@ private[ml] object DefaultParamsWriter {
* Helper for [[saveMetadata()]] which extracts the JSON to save.
* This is useful for ensemble models which need to save metadata for many sub-models.
*
* Note: This function does not handle param `initialModel`, see [[saveInitialModel()]].
*
* @see [[saveMetadata()]] for details on what this includes.
*/
def getMetadataToSave(
@@ -288,7 +291,8 @@
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val params = instance.extractParamMap().toSeq
.filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
@@ -306,6 +310,23 @@
val metadataJson: String = compact(render(metadata))
metadataJson
}

/**
* Save an estimator's `initialModel` to the corresponding path.
*/
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
instance: T, path: String): Unit = {
if (instance.isDefined(instance.initialModel)) {
val initialModelPath = new Path(path, "initialModel").toString
val initialModel = instance.getOrDefault(instance.initialModel)
// When saving, keep only the direct initialModel, dropping any nested initialModel it may
// itself carry, to avoid unnecessarily deep recursion of initialModel.
if (initialModel.hasParam("initialModel")) {
initialModel.clear(initialModel.getParam("initialModel"))
}
initialModel.save(initialModelPath)
}
}
}

/**
@@ -434,6 +455,23 @@ private[ml] object DefaultParamsReader {
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}

/**
* Load an estimator's `initialModel` instance from the given path and return it.
* If the `initialModel` path does not exist, the estimator either has no `initialModel`
* param or never set it, so return None.
* This assumes the model implements [[MLReadable]].
*/
def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
val hadoopConf = sc.hadoopConfiguration
val initialModelPath = new Path(path, "initialModel")
val fs = initialModelPath.getFileSystem(hadoopConf)
if (fs.exists(initialModelPath)) {
Some(loadParamsInstance[M](initialModelPath.toString, sc))
} else {
None
}
}
}

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -418,6 +418,8 @@ object KMeans {
val RANDOM = "random"
@Since("0.8.0")
val K_MEANS_PARALLEL = "k-means||"
@Since("2.2.0")
val K_MEANS_INITIAL_MODEL = "initialModel"
Contributor:
It can be private I think. That, or we should update the valid options for the setInitializationMode doc. But I think it's best to make it private.


/**
* Trains a k-means model using the given set of parameters.
@@ -593,6 +595,7 @@
initMode match {
case KMeans.RANDOM => true
case KMeans.K_MEANS_PARALLEL => true
case KMeans.K_MEANS_INITIAL_MODEL => true
case _ => false
}
}