
Commit 2824d85

Add MimaExcludes and docs.
1 parent ddd8d86 commit 2824d85

File tree: 4 files changed, +18 / -5 lines


mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 2 additions & 4 deletions
@@ -400,6 +400,7 @@ object KMeans extends MLReadable[KMeans] {
 
   /** [[MLWriter]] instance for [[KMeans]] */
   private[KMeans] class KMeansWriter(instance: KMeans) extends MLWriter {
+
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveInitialModel(instance, path)
       DefaultParamsWriter.saveMetadata(instance, path, sc)
@@ -408,11 +409,8 @@ object KMeans extends MLReadable[KMeans] {
 
   private class KMeansReader extends MLReader[KMeans] {
 
-    /** Checked against metadata when loading estimator */
-    private val className = classOf[KMeans].getName
-
     override def load(path: String): KMeans = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[KMeans].getName)
       val instance = new KMeans(metadata.uid)
 
       DefaultParamsReader.getAndSetParams(instance, metadata)
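
Note: the hunk above ends at `getAndSetParams`, so the part of `KMeansReader.load` that restores `initialModel` is not shown. Below is a minimal sketch of how that tail could look, assuming it uses the `loadInitialModel` helper added in ReadWrite.scala further down and a `setInitialModel` setter on `KMeans` (the setter name is an assumption, not part of this diff):

    // Sketch only, not taken from this commit: a plausible continuation of KMeansReader.load.
    override def load(path: String): KMeans = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[KMeans].getName)
      val instance = new KMeans(metadata.uid)
      DefaultParamsReader.getAndSetParams(instance, metadata)
      // If an initial model was saved next to the metadata, set it back on the estimator.
      DefaultParamsReader.loadInitialModel[KMeansModel](path, sc)
        .foreach(instance.setInitialModel)  // hypothetical setter on KMeans
      instance
    }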

mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala

Lines changed: 10 additions & 1 deletion
@@ -280,7 +280,7 @@ private[ml] object DefaultParamsWriter {
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
    *
-   * Note: This function does not handle param `initialModel`.
+   * Note: This function does not handle the `initialModel` param; see [[saveInitialModel()]].
    *
    * @see [[saveMetadata()]] for details on what this includes.
    */
@@ -311,6 +311,9 @@ private[ml] object DefaultParamsWriter {
     metadataJson
   }
 
+  /**
+   * Save the estimator's `initialModel` to the corresponding path.
+   */
   def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
       instance: T, path: String): Unit = {
     if (instance.isDefined(instance.initialModel)) {
@@ -453,6 +456,12 @@ private[ml] object DefaultParamsReader {
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
 
+  /**
+   * Load the estimator's `initialModel` instance from the given path and return it.
+   * If the `initialModel` path does not exist, the estimator does not have the
+   * `initialModel` param set, so return None.
+   * This assumes the model implements [[MLReadable]].
+   */
   def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
     val hadoopConf = sc.hadoopConfiguration
     val initialModelPath = new Path(path, "initialModel")
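
The hunk above cuts off right after `initialModelPath` is constructed. For orientation, here is a minimal sketch of how the remainder of `loadInitialModel` might look, assuming it checks for the `initialModel` directory through the Hadoop FileSystem API and reuses `loadParamsInstance` from the same object; the actual body is not part of this commit, and the imports are those already present in ReadWrite.scala:

  def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
    val hadoopConf = sc.hadoopConfiguration
    val initialModelPath = new Path(path, "initialModel")
    val fs = initialModelPath.getFileSystem(hadoopConf)
    if (fs.exists(initialModelPath)) {
      // Delegate to loadParamsInstance, which invokes MLReadable.read on the saved model class.
      Some(loadParamsInstance[M](initialModelPath.toString, sc))
    } else {
      // No "initialModel" directory: the estimator was saved without an initial model.
      None
    }
  }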

mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala

Lines changed: 3 additions & 0 deletions
@@ -114,6 +114,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     testParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
       if (param.name == "initialModel") {
+        // The estimator's `initialModel` has the same type as the model produced by this
+        // estimator, so we can use `checkModelData` to check equality of `initialModel` too.
         checkModelData(estimator.get(param).get.asInstanceOf[M],
           estimator2.get(param).get.asInstanceOf[M])
       } else {
@@ -123,6 +125,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
 
     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
+    // The model does not extend HasInitialModel, so we don't check `initialModel` here.
     testParams.filter(_._1 != "initialModel").foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
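
The new comment relies on `checkModelData` being applicable to the estimator's `initialModel` as well as to the trained model, since both are instances of `M`. As a concrete illustration (not part of this commit), a suite exercising this path for KMeans could pass a `checkModelData` along these lines:

  // Two KMeansModels are treated as equal here if their cluster centers match; this is the
  // kind of function a suite would hand to the read/write test helper in this trait.
  def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
    assert(model.clusterCenters === model2.clusterCenters)
  }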

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
@@ -948,6 +948,9 @@ object MimaExcludes {
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
       ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
+    ) ++ Seq(
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.KMeans$"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.KMeans")
     )
   }
 