Skip to content

Commit 8c11d1a

Browse files
yinxusen authored and jkbradley committed
[SPARK-11893] Model export/import for spark.ml: TrainValidationSplit
https://issues.apache.org/jira/browse/SPARK-11893 jkbradley In order to share read/write logic with `TrainValidationSplit`, I moved `SharedReadWrite` out of `CrossValidator` into a new trait `SharedReadWrite` in the tuning package. To reduce repeated tests, I moved the complex tests from `CrossValidatorSuite` to `SharedReadWriteSuite` and created a fake validator called `MyValidator` to test the shared code. With `SharedReadWrite`, any newly added `Validator` can share the common read/write part and only needs to implement save/load for its extra params. Author: Xusen Yin <[email protected]> Author: Joseph K. Bradley <[email protected]> Closes #9971 from yinxusen/SPARK-11893.
1 parent 39f743a commit 8c11d1a

File tree

5 files changed

+310
-142
lines changed

5 files changed

+310
-142
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 14 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,27 @@ package org.apache.spark.ml.tuning
1919

2020
import com.github.fommil.netlib.F2jBLAS
2121
import org.apache.hadoop.fs.Path
22-
import org.json4s.{DefaultFormats, JObject}
23-
import org.json4s.jackson.JsonMethods._
22+
import org.json4s.DefaultFormats
2423

25-
import org.apache.spark.SparkContext
2624
import org.apache.spark.annotation.{Experimental, Since}
2725
import org.apache.spark.internal.Logging
2826
import org.apache.spark.ml._
29-
import org.apache.spark.ml.classification.OneVsRestParams
3027
import org.apache.spark.ml.evaluation.Evaluator
31-
import org.apache.spark.ml.feature.RFormulaModel
3228
import org.apache.spark.ml.param._
3329
import org.apache.spark.ml.param.shared.HasSeed
3430
import org.apache.spark.ml.util._
35-
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
3631
import org.apache.spark.mllib.util.MLUtils
3732
import org.apache.spark.sql.DataFrame
3833
import org.apache.spark.sql.types.StructType
3934

40-
4135
/**
4236
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
4337
*/
4438
private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
4539
/**
4640
* Param for number of folds for cross validation. Must be >= 2.
4741
* Default: 3
42+
*
4843
* @group param
4944
*/
5045
val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
163158

164159
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
165160

166-
SharedReadWrite.validateParams(instance)
161+
ValidatorParams.validateParams(instance)
167162

168163
override protected def saveImpl(path: String): Unit =
169-
SharedReadWrite.saveImpl(path, instance, sc)
164+
ValidatorParams.saveImpl(path, instance, sc)
170165
}
171166

172167
private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,132 +170,18 @@ object CrossValidator extends MLReadable[CrossValidator] {
175170
private val className = classOf[CrossValidator].getName
176171

177172
override def load(path: String): CrossValidator = {
178-
val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
179-
SharedReadWrite.load(path, sc, className)
173+
implicit val format = DefaultFormats
174+
175+
val (metadata, estimator, evaluator, estimatorParamMaps) =
176+
ValidatorParams.loadImpl(path, sc, className)
177+
val numFolds = (metadata.params \ "numFolds").extract[Int]
180178
new CrossValidator(metadata.uid)
181179
.setEstimator(estimator)
182180
.setEvaluator(evaluator)
183181
.setEstimatorParamMaps(estimatorParamMaps)
184182
.setNumFolds(numFolds)
185183
}
186184
}
187-
188-
private object CrossValidatorReader {
189-
/**
190-
* Examine the given estimator (which may be a compound estimator) and extract a mapping
191-
* from UIDs to corresponding [[Params]] instances.
192-
*/
193-
def getUidMap(instance: Params): Map[String, Params] = {
194-
val uidList = getUidMapImpl(instance)
195-
val uidMap = uidList.toMap
196-
if (uidList.size != uidMap.size) {
197-
throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
198-
s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
199-
}
200-
uidMap
201-
}
202-
203-
def getUidMapImpl(instance: Params): List[(String, Params)] = {
204-
val subStages: Array[Params] = instance match {
205-
case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
206-
case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
207-
case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
208-
case ovr: OneVsRestParams =>
209-
// TODO: SPARK-11892: This case may require special handling.
210-
throw new UnsupportedOperationException("CrossValidator write will fail because it" +
211-
" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
212-
case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
213-
case _: Params => Array()
214-
}
215-
val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
216-
List((instance.uid, instance)) ++ subStageMaps
217-
}
218-
}
219-
220-
private[tuning] object SharedReadWrite {
221-
222-
/**
223-
* Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
224-
* This does not check [[CrossValidator.estimatorParamMaps]].
225-
*/
226-
def validateParams(instance: ValidatorParams): Unit = {
227-
def checkElement(elem: Params, name: String): Unit = elem match {
228-
case stage: MLWritable => // good
229-
case other =>
230-
throw new UnsupportedOperationException("CrossValidator write will fail " +
231-
s" because it contains $name which does not implement Writable." +
232-
s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
233-
}
234-
checkElement(instance.getEvaluator, "evaluator")
235-
checkElement(instance.getEstimator, "estimator")
236-
// Check to make sure all Params apply to this estimator. Throw an error if any do not.
237-
// Extraneous Params would cause problems when loading the estimatorParamMaps.
238-
val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
239-
instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
240-
pMap.toSeq.foreach { case ParamPair(p, v) =>
241-
require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
242-
s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
243-
s" Evaluator. An extraneous Param was found: $p")
244-
}
245-
}
246-
}
247-
248-
private[tuning] def saveImpl(
249-
path: String,
250-
instance: CrossValidatorParams,
251-
sc: SparkContext,
252-
extraMetadata: Option[JObject] = None): Unit = {
253-
import org.json4s.JsonDSL._
254-
255-
val estimatorParamMapsJson = compact(render(
256-
instance.getEstimatorParamMaps.map { case paramMap =>
257-
paramMap.toSeq.map { case ParamPair(p, v) =>
258-
Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
259-
}
260-
}.toSeq
261-
))
262-
val jsonParams = List(
263-
"numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
264-
"estimatorParamMaps" -> parse(estimatorParamMapsJson)
265-
)
266-
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
267-
268-
val evaluatorPath = new Path(path, "evaluator").toString
269-
instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
270-
val estimatorPath = new Path(path, "estimator").toString
271-
instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
272-
}
273-
274-
private[tuning] def load[M <: Model[M]](
275-
path: String,
276-
sc: SparkContext,
277-
expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
278-
279-
val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
280-
281-
implicit val format = DefaultFormats
282-
val evaluatorPath = new Path(path, "evaluator").toString
283-
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
284-
val estimatorPath = new Path(path, "estimator").toString
285-
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
286-
287-
val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
288-
289-
val numFolds = (metadata.params \ "numFolds").extract[Int]
290-
val estimatorParamMaps: Array[ParamMap] =
291-
(metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
292-
pMap =>
293-
val paramPairs = pMap.map { case pInfo: Map[String, String] =>
294-
val est = uidToParams(pInfo("parent"))
295-
val param = est.getParam(pInfo("name"))
296-
val value = param.jsonDecode(pInfo("value"))
297-
param -> value
298-
}
299-
ParamMap(paramPairs: _*)
300-
}.toArray
301-
(metadata, estimator, evaluator, estimatorParamMaps, numFolds)
302-
}
303-
}
304185
}
305186

306187
/**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
346227
@Since("1.6.0")
347228
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
348229

349-
import CrossValidator.SharedReadWrite
350-
351230
@Since("1.6.0")
352231
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
353232

@@ -357,12 +236,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
357236
private[CrossValidatorModel]
358237
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
359238

360-
SharedReadWrite.validateParams(instance)
239+
ValidatorParams.validateParams(instance)
361240

362241
override protected def saveImpl(path: String): Unit = {
363242
import org.json4s.JsonDSL._
364243
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
365-
SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
244+
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
366245
val bestModelPath = new Path(path, "bestModel").toString
367246
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
368247
}
@@ -376,8 +255,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
376255
override def load(path: String): CrossValidatorModel = {
377256
implicit val format = DefaultFormats
378257

379-
val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
380-
SharedReadWrite.load(path, sc, className)
258+
val (metadata, estimator, evaluator, estimatorParamMaps) =
259+
ValidatorParams.loadImpl(path, sc, className)
260+
val numFolds = (metadata.params \ "numFolds").extract[Int]
381261
val bestModelPath = new Path(path, "bestModel").toString
382262
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
383263
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
package org.apache.spark.ml.tuning
1919

20+
import org.apache.hadoop.fs.Path
21+
import org.json4s.DefaultFormats
22+
2023
import org.apache.spark.annotation.{Experimental, Since}
2124
import org.apache.spark.internal.Logging
2225
import org.apache.spark.ml.{Estimator, Model}
2326
import org.apache.spark.ml.evaluation.Evaluator
2427
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
25-
import org.apache.spark.ml.util.Identifiable
28+
import org.apache.spark.ml.util._
2629
import org.apache.spark.sql.DataFrame
2730
import org.apache.spark.sql.types.StructType
2831

@@ -33,6 +36,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
3336
/**
3437
* Param for ratio between train and validation data. Must be between 0 and 1.
3538
* Default: 0.75
39+
*
3640
* @group param
3741
*/
3842
val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +59,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
5559
@Experimental
5660
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
5761
extends Estimator[TrainValidationSplitModel]
58-
with TrainValidationSplitParams with Logging {
62+
with TrainValidationSplitParams with MLWritable with Logging {
5963

6064
@Since("1.5.0")
6165
def this() = this(Identifiable.randomUID("tvs"))
@@ -130,6 +134,47 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
130134
}
131135
copied
132136
}
137+
138+
@Since("2.0.0")
139+
override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
140+
}
141+
142+
@Since("2.0.0")
143+
object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
144+
145+
@Since("2.0.0")
146+
override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader
147+
148+
@Since("2.0.0")
149+
override def load(path: String): TrainValidationSplit = super.load(path)
150+
151+
private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit)
152+
extends MLWriter {
153+
154+
ValidatorParams.validateParams(instance)
155+
156+
override protected def saveImpl(path: String): Unit =
157+
ValidatorParams.saveImpl(path, instance, sc)
158+
}
159+
160+
private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
161+
162+
/** Checked against metadata when loading model */
163+
private val className = classOf[TrainValidationSplit].getName
164+
165+
override def load(path: String): TrainValidationSplit = {
166+
implicit val format = DefaultFormats
167+
168+
val (metadata, estimator, evaluator, estimatorParamMaps) =
169+
ValidatorParams.loadImpl(path, sc, className)
170+
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
171+
new TrainValidationSplit(metadata.uid)
172+
.setEstimator(estimator)
173+
.setEvaluator(evaluator)
174+
.setEstimatorParamMaps(estimatorParamMaps)
175+
.setTrainRatio(trainRatio)
176+
}
177+
}
133178
}
134179

135180
/**
@@ -146,7 +191,7 @@ class TrainValidationSplitModel private[ml] (
146191
@Since("1.5.0") override val uid: String,
147192
@Since("1.5.0") val bestModel: Model[_],
148193
@Since("1.5.0") val validationMetrics: Array[Double])
149-
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
194+
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
150195

151196
@Since("1.5.0")
152197
override def transform(dataset: DataFrame): DataFrame = {
@@ -167,4 +212,53 @@ class TrainValidationSplitModel private[ml] (
167212
validationMetrics.clone())
168213
copyValues(copied, extra)
169214
}
215+
216+
@Since("2.0.0")
217+
override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
218+
}
219+
220+
@Since("2.0.0")
221+
object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
222+
223+
@Since("2.0.0")
224+
override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader
225+
226+
@Since("2.0.0")
227+
override def load(path: String): TrainValidationSplitModel = super.load(path)
228+
229+
private[TrainValidationSplitModel]
230+
class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
231+
232+
ValidatorParams.validateParams(instance)
233+
234+
override protected def saveImpl(path: String): Unit = {
235+
import org.json4s.JsonDSL._
236+
val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq
237+
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
238+
val bestModelPath = new Path(path, "bestModel").toString
239+
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
240+
}
241+
}
242+
243+
private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] {
244+
245+
/** Checked against metadata when loading model */
246+
private val className = classOf[TrainValidationSplitModel].getName
247+
248+
override def load(path: String): TrainValidationSplitModel = {
249+
implicit val format = DefaultFormats
250+
251+
val (metadata, estimator, evaluator, estimatorParamMaps) =
252+
ValidatorParams.loadImpl(path, sc, className)
253+
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
254+
val bestModelPath = new Path(path, "bestModel").toString
255+
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
256+
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
257+
val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
258+
tvs.set(tvs.estimator, estimator)
259+
.set(tvs.evaluator, evaluator)
260+
.set(tvs.estimatorParamMaps, estimatorParamMaps)
261+
.set(tvs.trainRatio, trainRatio)
262+
}
263+
}
170264
}

0 commit comments

Comments
 (0)