1717
1818package org .apache .spark .ml .tuning
1919
20- import java .util .{List => JList }
20+ import java .util .{List => JList , Locale }
2121
2222import scala .collection .JavaConverters ._
2323import scala .concurrent .Future
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
3131import org .apache .spark .ml .{Estimator , Model }
3232import org .apache .spark .ml .evaluation .Evaluator
3333import org .apache .spark .ml .param .{IntParam , ParamMap , ParamValidators }
34- import org .apache .spark .ml .param .shared .HasParallelism
34+ import org .apache .spark .ml .param .shared .{ HasCollectSubModels , HasParallelism }
3535import org .apache .spark .ml .util ._
3636import org .apache .spark .mllib .util .MLUtils
3737import org .apache .spark .sql .{DataFrame , Dataset }
@@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
6767@ Since (" 1.2.0" )
6868class CrossValidator @ Since (" 1.2.0" ) (@ Since (" 1.4.0" ) override val uid : String )
6969 extends Estimator [CrossValidatorModel ]
70- with CrossValidatorParams with HasParallelism with MLWritable with Logging {
70+ with CrossValidatorParams with HasParallelism with HasCollectSubModels
71+ with MLWritable with Logging {
7172
7273 @ Since (" 1.2.0" )
7374 def this () = this (Identifiable .randomUID(" cv" ))
@@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
101102 @ Since (" 2.3.0" )
102103 def setParallelism (value : Int ): this .type = set(parallelism, value)
103104
/**
 * Sets whether submodels are collected when fitting. If set, the submodels can be
 * retrieved from the returned model.
 *
 * Note: If this param is set, then when you save the returned model you can set the option
 * "persistSubModels" to "true" before saving, in order to save these submodels as well.
 * See the documentation of
 * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter}
 * for more information.
 *
 * @group expertSetParam
 */
@Since("2.3.0")
def setCollectSubModels(value: Boolean): this.type = {
  set(collectSubModels, value)
}
119+
104120 @ Since (" 2.0.0" )
105121 override def fit (dataset : Dataset [_]): CrossValidatorModel = {
106122 val schema = dataset.schema
@@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
117133 instr.logParams(numFolds, seed, parallelism)
118134 logTuningParams(instr)
119135
136+ val collectSubModelsParam = $(collectSubModels)
137+
138+ var subModels : Option [Array [Array [Model [_]]]] = if (collectSubModelsParam) {
139+ Some (Array .fill($(numFolds))(Array .fill[Model [_]](epm.length)(null )))
140+ } else None
141+
120142 // Compute metrics for each model over each split
121143 val splits = MLUtils .kFold(dataset.toDF.rdd, $(numFolds), $(seed))
122144 val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
@@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
125147 logDebug(s " Train split $splitIndex with multiple sets of parameters. " )
126148
127149 // Fit models in a Future for training in parallel
128- val modelFutures = epm.map { paramMap =>
150+ val modelFutures = epm.zipWithIndex. map { case ( paramMap, paramIndex) =>
129151 Future [Model [_]] {
130- val model = est.fit(trainingDataset, paramMap)
131- model.asInstanceOf [Model [_]]
152+ val model = est.fit(trainingDataset, paramMap).asInstanceOf [Model [_]]
153+
154+ if (collectSubModelsParam) {
155+ subModels.get(splitIndex)(paramIndex) = model
156+ }
157+ model
132158 } (executionContext)
133159 }
134160
@@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
160186 logInfo(s " Best cross-validation metric: $bestMetric. " )
161187 val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf [Model [_]]
162188 instr.logSuccess(bestModel)
163- copyValues(new CrossValidatorModel (uid, bestModel, metrics).setParent(this ))
189+ copyValues(new CrossValidatorModel (uid, bestModel, metrics)
190+ .setSubModels(subModels).setParent(this ))
164191 }
165192
166193 @ Since (" 1.4.0" )
@@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] (
244271 this (uid, bestModel, avgMetrics.asScala.toArray)
245272 }
246273
// Submodels attached to this model; stays None until setSubModels supplies a value.
private var _subModels: Option[Array[Array[Model[_]]]] = None

/** Stores the given (optional) submodels on this model and returns this model. */
private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]])
  : CrossValidatorModel = {
  this._subModels = subModels
  this
}
281+
/**
 * @return submodels represented as a two-dimensional array. The index of the outer array is
 *         the fold index, and the index of the inner array corresponds to the ordering of
 *         estimatorParamMaps
 * @throws IllegalArgumentException if subModels are not available. To retrieve subModels,
 *         make sure to set collectSubModels to true before fitting.
 */
@Since("2.3.0")
def subModels: Array[Array[Model[_]]] = {
  val unavailableMessage = "subModels not available, To retrieve subModels, make sure " +
    "to set collectSubModels to true before fitting."
  require(_subModels.isDefined, unavailableMessage)
  _subModels.get
}
295+
/** Reports whether submodels have been attached to this model. */
@Since("2.3.0")
def hasSubModels: Boolean = _subModels.nonEmpty
298+
247299 @ Since (" 2.0.0" )
248300 override def transform (dataset : Dataset [_]): DataFrame = {
249301 transformSchema(dataset.schema, logging = true )
@@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] (
260312 val copied = new CrossValidatorModel (
261313 uid,
262314 bestModel.copy(extra).asInstanceOf [Model [_]],
263- avgMetrics.clone())
315+ avgMetrics.clone()
316+ ).setSubModels(CrossValidatorModel .copySubModels(_subModels))
264317 copyValues(copied, extra).setParent(parent)
265318 }
266319
/** Creates a new CrossValidatorModelWriter for this model instance. */
@Since("1.6.0")
override def write: CrossValidatorModel.CrossValidatorModelWriter =
  new CrossValidatorModel.CrossValidatorModelWriter(this)
269324}
270325
271326@ Since (" 1.6.0" )
272327object CrossValidatorModel extends MLReadable [CrossValidatorModel ] {
273328
/**
 * Copies every submodel with an empty ParamMap, keeping the nested array layout.
 * Returns None when no submodels are given.
 */
private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]])
  : Option[Array[Array[Model[_]]]] = {
  for (allFolds <- subModels) yield {
    allFolds.map { foldModels =>
      foldModels.map(model => model.copy(ParamMap.empty).asInstanceOf[Model[_]])
    }
  }
}
333+
/** Returns a new CrossValidatorModelReader. */
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = {
  new CrossValidatorModelReader
}
276336
/** Loads a CrossValidatorModel from `path` by delegating to the inherited `load`. */
@Since("1.6.0")
override def load(path: String): CrossValidatorModel = {
  super.load(path)
}
279339
/**
 * Writer for CrossValidatorModel.
 * @param instance CrossValidatorModel instance used to construct the writer
 *
 * CrossValidatorModelWriter supports an option "persistSubModels", with possible values
 * "true" or "false". If you set the collectSubModels Param before fitting, then you can
 * set "persistSubModels" to "true" in order to persist the subModels. By default,
 * "persistSubModels" will be "true" when subModels are available and "false" otherwise.
 * If subModels are not available, then setting "persistSubModels" to "true" will cause
 * an exception; the check is made before anything is written to the target path.
 */
@Since("2.3.0")
final class CrossValidatorModelWriter private[tuning] (
    instance: CrossValidatorModel) extends MLWriter {

  ValidatorParams.validateParams(instance)

  /**
   * Saves the metadata, the best model and, when requested, every submodel under `path`.
   *
   * Submodels are written to `path/subModels/fold<splitIndex>/<paramIndex>`.
   *
   * @param path directory the model is written to
   * @throws IllegalArgumentException if the "persistSubModels" option is neither "true" nor
   *         "false", or if it is "true" while the instance holds no submodels
   */
  override protected def saveImpl(path: String): Unit = {
    // Option keys are looked up in lower case; the default follows submodel availability.
    val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
      if (instance.hasSubModels) "true" else "false")

    require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
      s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " +
      "values are \"true\" or \"false\"")
    val persistSubModels = persistSubModelsParam.toBoolean

    // Fail fast: reject an unsatisfiable request before any file is written, so a failure
    // cannot leave behind metadata that claims submodels were persisted.
    require(!persistSubModels || instance.hasSubModels,
      "When persisting tuning models, you can only set " +
      "persistSubModels to true if the tuning was done with collectSubModels set to true. " +
      "To save the sub-models, try rerunning fitting with collectSubModels set to true.")

    import org.json4s.JsonDSL._
    val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
      ("persistSubModels" -> persistSubModels)
    ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
    val bestModelPath = new Path(path, "bestModel").toString
    instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
    if (persistSubModels) {
      val subModelsPath = new Path(path, "subModels")
      // Read the accessor once; it re-checks availability on every call.
      val subModels = instance.subModels
      for (splitIndex <- 0 until instance.getNumFolds) {
        val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
        for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
          val modelPath = new Path(splitPath, paramIndex.toString).toString
          subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
        }
      }
    }
  }
}
293387
@@ -301,11 +395,30 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
301395
302396 val (metadata, estimator, evaluator, estimatorParamMaps) =
303397 ValidatorParams .loadImpl(path, sc, className)
398+ val numFolds = (metadata.params \ " numFolds" ).extract[Int ]
304399 val bestModelPath = new Path (path, " bestModel" ).toString
305400 val bestModel = DefaultParamsReader .loadParamsInstance[Model [_]](bestModelPath, sc)
306401 val avgMetrics = (metadata.metadata \ " avgMetrics" ).extract[Seq [Double ]].toArray
402+ val persistSubModels = (metadata.metadata \ " persistSubModels" )
403+ .extractOrElse[Boolean ](false )
404+
405+ val subModels : Option [Array [Array [Model [_]]]] = if (persistSubModels) {
406+ val subModelsPath = new Path (path, " subModels" )
407+ val _subModels = Array .fill(numFolds)(Array .fill[Model [_]](
408+ estimatorParamMaps.length)(null ))
409+ for (splitIndex <- 0 until numFolds) {
410+ val splitPath = new Path (subModelsPath, s " fold ${splitIndex.toString}" )
411+ for (paramIndex <- 0 until estimatorParamMaps.length) {
412+ val modelPath = new Path (splitPath, paramIndex.toString).toString
413+ _subModels(splitIndex)(paramIndex) =
414+ DefaultParamsReader .loadParamsInstance(modelPath, sc)
415+ }
416+ }
417+ Some (_subModels)
418+ } else None
307419
308420 val model = new CrossValidatorModel (metadata.uid, bestModel, avgMetrics)
421+ .setSubModels(subModels)
309422 model.set(model.estimator, estimator)
310423 .set(model.evaluator, evaluator)
311424 .set(model.estimatorParamMaps, estimatorParamMaps)
0 commit comments