Skip to content

Commit 7743980

Browse files
WeichenXu123jkbradley
authored andcommitted
[SPARK-21087][ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala
## What changes were proposed in this pull request? We add a parameter whether to collect the full model list when CrossValidator/TrainValidationSplit training (Default is NOT), avoid the change cause OOM) - Add a method in CrossValidatorModel/TrainValidationSplitModel, allow user to get the model list - CrossValidatorModelWriter add a “option”, allow user to control whether to persist the model list to disk (will persist by default). - Note: when persisting the model list, use indices as the sub-model path ## How was this patch tested? Test cases added. Author: WeichenXu <[email protected]> Closes #19208 from WeichenXu123/expose-model-list.
1 parent b009722 commit 7743980

File tree

8 files changed

+388
-29
lines changed

8 files changed

+388
-29
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ private[shared] object SharedParamsCodeGen {
8383
"all instance weights as 1.0"),
8484
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
8585
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
86-
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
86+
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
87+
ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " +
88+
"sub-model will be available after fitting. If set to true, then all sub-models will be " +
89+
"available. Warning: For large models, collecting all sub-models can cause OOMs on the " +
90+
"Spark driver.",
91+
Some("false"), isExpertParam = true)
92+
)
8793

8894
val code = genSharedParams(params)
8995
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,4 +468,21 @@ trait HasAggregationDepth extends Params {
468468
/** @group expertGetParam */
469469
final def getAggregationDepth: Int = $(aggregationDepth)
470470
}
471+
472+
/**
473+
* Trait for shared param collectSubModels (default: false).
474+
*/
475+
private[ml] trait HasCollectSubModels extends Params {
476+
477+
/**
478+
* Param for whether to collect a list of sub-models trained during tuning.
479+
* @group expertParam
480+
*/
481+
final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning")
482+
483+
setDefault(collectSubModels, false)
484+
485+
/** @group expertGetParam */
486+
final def getCollectSubModels: Boolean = $(collectSubModels)
487+
}
471488
// scalastyle:on

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 125 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.tuning
1919

20-
import java.util.{List => JList}
20+
import java.util.{List => JList, Locale}
2121

2222
import scala.collection.JavaConverters._
2323
import scala.concurrent.Future
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
3131
import org.apache.spark.ml.{Estimator, Model}
3232
import org.apache.spark.ml.evaluation.Evaluator
3333
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
34-
import org.apache.spark.ml.param.shared.HasParallelism
34+
import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism}
3535
import org.apache.spark.ml.util._
3636
import org.apache.spark.mllib.util.MLUtils
3737
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
6767
@Since("1.2.0")
6868
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
6969
extends Estimator[CrossValidatorModel]
70-
with CrossValidatorParams with HasParallelism with MLWritable with Logging {
70+
with CrossValidatorParams with HasParallelism with HasCollectSubModels
71+
with MLWritable with Logging {
7172

7273
@Since("1.2.0")
7374
def this() = this(Identifiable.randomUID("cv"))
@@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
101102
@Since("2.3.0")
102103
def setParallelism(value: Int): this.type = set(parallelism, value)
103104

105+
/**
106+
* Whether to collect submodels when fitting. If set, we can get submodels from
107+
* the returned model.
108+
*
109+
* Note: If set this param, when you save the returned model, you can set an option
110+
* "persistSubModels" to be "true" before saving, in order to save these submodels.
111+
* You can check documents of
112+
* {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter}
113+
* for more information.
114+
*
115+
* @group expertSetParam
116+
*/
117+
@Since("2.3.0")
118+
def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)
119+
104120
@Since("2.0.0")
105121
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
106122
val schema = dataset.schema
@@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
117133
instr.logParams(numFolds, seed, parallelism)
118134
logTuningParams(instr)
119135

136+
val collectSubModelsParam = $(collectSubModels)
137+
138+
var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) {
139+
Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)))
140+
} else None
141+
120142
// Compute metrics for each model over each split
121143
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
122144
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
@@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
125147
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
126148

127149
// Fit models in a Future for training in parallel
128-
val modelFutures = epm.map { paramMap =>
150+
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
129151
Future[Model[_]] {
130-
val model = est.fit(trainingDataset, paramMap)
131-
model.asInstanceOf[Model[_]]
152+
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
153+
154+
if (collectSubModelsParam) {
155+
subModels.get(splitIndex)(paramIndex) = model
156+
}
157+
model
132158
} (executionContext)
133159
}
134160

@@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
160186
logInfo(s"Best cross-validation metric: $bestMetric.")
161187
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
162188
instr.logSuccess(bestModel)
163-
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
189+
copyValues(new CrossValidatorModel(uid, bestModel, metrics)
190+
.setSubModels(subModels).setParent(this))
164191
}
165192

166193
@Since("1.4.0")
@@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] (
244271
this(uid, bestModel, avgMetrics.asScala.toArray)
245272
}
246273

274+
private var _subModels: Option[Array[Array[Model[_]]]] = None
275+
276+
private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]])
277+
: CrossValidatorModel = {
278+
_subModels = subModels
279+
this
280+
}
281+
282+
/**
283+
* @return submodels represented in two dimension array. The index of outer array is the
284+
* fold index, and the index of inner array corresponds to the ordering of
285+
* estimatorParamMaps
286+
* @throws IllegalArgumentException if subModels are not available. To retrieve subModels,
287+
* make sure to set collectSubModels to true before fitting.
288+
*/
289+
@Since("2.3.0")
290+
def subModels: Array[Array[Model[_]]] = {
291+
require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " +
292+
"to set collectSubModels to true before fitting.")
293+
_subModels.get
294+
}
295+
296+
@Since("2.3.0")
297+
def hasSubModels: Boolean = _subModels.isDefined
298+
247299
@Since("2.0.0")
248300
override def transform(dataset: Dataset[_]): DataFrame = {
249301
transformSchema(dataset.schema, logging = true)
@@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] (
260312
val copied = new CrossValidatorModel(
261313
uid,
262314
bestModel.copy(extra).asInstanceOf[Model[_]],
263-
avgMetrics.clone())
315+
avgMetrics.clone()
316+
).setSubModels(CrossValidatorModel.copySubModels(_subModels))
264317
copyValues(copied, extra).setParent(parent)
265318
}
266319

267320
@Since("1.6.0")
268-
override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
321+
override def write: CrossValidatorModel.CrossValidatorModelWriter = {
322+
new CrossValidatorModel.CrossValidatorModelWriter(this)
323+
}
269324
}
270325

271326
@Since("1.6.0")
272327
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
273328

329+
private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]])
330+
: Option[Array[Array[Model[_]]]] = {
331+
subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])))
332+
}
333+
274334
@Since("1.6.0")
275335
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
276336

277337
@Since("1.6.0")
278338
override def load(path: String): CrossValidatorModel = super.load(path)
279339

280-
private[CrossValidatorModel]
281-
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
340+
/**
341+
* Writer for CrossValidatorModel.
342+
* @param instance CrossValidatorModel instance used to construct the writer
343+
*
344+
* CrossValidatorModelWriter supports an option "persistSubModels", with possible values
345+
* "true" or "false". If you set the collectSubModels Param before fitting, then you can
346+
* set "persistSubModels" to "true" in order to persist the subModels. By default,
347+
* "persistSubModels" will be "true" when subModels are available and "false" otherwise.
348+
* If subModels are not available, then setting "persistSubModels" to "true" will cause
349+
* an exception.
350+
*/
351+
@Since("2.3.0")
352+
final class CrossValidatorModelWriter private[tuning] (
353+
instance: CrossValidatorModel) extends MLWriter {
282354

283355
ValidatorParams.validateParams(instance)
284356

285357
override protected def saveImpl(path: String): Unit = {
358+
val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
359+
if (instance.hasSubModels) "true" else "false")
360+
361+
require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
362+
s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " +
363+
"values are \"true\" or \"false\"")
364+
val persistSubModels = persistSubModelsParam.toBoolean
365+
286366
import org.json4s.JsonDSL._
287-
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
367+
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
368+
("persistSubModels" -> persistSubModels)
288369
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
289370
val bestModelPath = new Path(path, "bestModel").toString
290371
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
372+
if (persistSubModels) {
373+
require(instance.hasSubModels, "When persisting tuning models, you can only set " +
374+
"persistSubModels to true if the tuning was done with collectSubModels set to true. " +
375+
"To save the sub-models, try rerunning fitting with collectSubModels set to true.")
376+
val subModelsPath = new Path(path, "subModels")
377+
for (splitIndex <- 0 until instance.getNumFolds) {
378+
val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
379+
for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
380+
val modelPath = new Path(splitPath, paramIndex.toString).toString
381+
instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
382+
}
383+
}
384+
}
291385
}
292386
}
293387

@@ -301,11 +395,30 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
301395

302396
val (metadata, estimator, evaluator, estimatorParamMaps) =
303397
ValidatorParams.loadImpl(path, sc, className)
398+
val numFolds = (metadata.params \ "numFolds").extract[Int]
304399
val bestModelPath = new Path(path, "bestModel").toString
305400
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
306401
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
402+
val persistSubModels = (metadata.metadata \ "persistSubModels")
403+
.extractOrElse[Boolean](false)
404+
405+
val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) {
406+
val subModelsPath = new Path(path, "subModels")
407+
val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
408+
estimatorParamMaps.length)(null))
409+
for (splitIndex <- 0 until numFolds) {
410+
val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
411+
for (paramIndex <- 0 until estimatorParamMaps.length) {
412+
val modelPath = new Path(splitPath, paramIndex.toString).toString
413+
_subModels(splitIndex)(paramIndex) =
414+
DefaultParamsReader.loadParamsInstance(modelPath, sc)
415+
}
416+
}
417+
Some(_subModels)
418+
} else None
307419

308420
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
421+
.setSubModels(subModels)
309422
model.set(model.estimator, estimator)
310423
.set(model.evaluator, evaluator)
311424
.set(model.estimatorParamMaps, estimatorParamMaps)

0 commit comments

Comments
 (0)