-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test build #81685 has finished for PR 19208 at commit
|
|
cc @jkbradley |
| def save(path: String, persistSubModels: Boolean): Unit = { | ||
| write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] | ||
| .persistSubModels(persistSubModels).save(path) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I add this method because the CrossValidatorModelWriter is private. User cannot use it. But I don't know whether there is better solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think users can still access CrossValidatorModelWriter through CrossValidatorModel.write, so the save method is unnecessary.
The private[CrossValidatorModel] annotation on the CrossValidatorModelWriter constructor only means that users can't create instances of the class e.g. via new CrossValidatorModel.CrossValidatorModelWriter(...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried model.write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] but cannot pass complier, it is inaccessible.
Do you have some other ways ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Discussion: Another way I think is adding an interface def option(key: String, value: String) into Writer. cc @jkbradley
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the last suggestion of adding def option(key: String, value: String) to mimic the SQL datasource API.
| .map { case ParamPair(p, v) => | ||
| p.name -> parse(p.jsonEncode(v)) | ||
| }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson)) | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improve code here. So that we don't need to add code for each parameter. Now we have 3 new added parameter: (parallelism, collectSubModels, persistSubModelPath), all added only in CV/TVS estimator. The old code here is easy to cause bugs if we forgot to update it when we add new params.
| .setEstimatorParamMaps(estimatorParamMaps) | ||
| .setNumFolds(numFolds) | ||
| .setSeed(seed) | ||
| DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use getAndSetParams instead of setting all params manually. This simplify code, and it can keep read/write compatibility.
|
Test build #81686 has finished for PR 19208 at commit
|
|
oh...sorry for that, I integrate @hhbyyh's old PR into this new one, because I found the code "dump models to disk" and "collect models" seem to be cohesive and split them will cause some conflicts when merging. @jkbradley |
|
Synced offline: I hadn't looked carefully and seen the 2 issues had been merged. @WeichenXu123 said he will split the work in 2, adding one parameter first. |
|
@jkbradley I split this PR, removed the code for "dump models to disk", so the PR will be smaller and easier to review. When this PR merged, I will create follow-up PR for "dump models to disk". Thanks! |
|
Test build #81767 has finished for PR 19208 at commit
|
|
Jenkins, test this please. |
|
Test build #81772 has finished for PR 19208 at commit
|
|
It's OK to me to include the "dump model to disk" #18313 in this or other PR (or not). After reading the discussion, I feel it's an overkill to support a feature like this in two ways (keeping in memory and dumping to disk). Allowing user to register a custom action after each batch of If you want to stick to this way which I'm not a fan of, I would only suggest to add the logic to estimate the memory the models will cost and stops the application if OOM is foreseeable. |
smurching
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the work, just a couple of comments!
| ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), | ||
| isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) | ||
| isValid = "ParamValidators.gtEq(2)", isExpertParam = true), | ||
| ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: reword "whether to collect sub models when tuning fitting" --> "whether to collect a list of sub-models trained during tuning"
|
|
||
| val collectSubModelsParam = $(collectSubModels) | ||
|
|
||
| var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps use an Option[Array[Model[_]]] instead of setting subModels to null?
| /** A Python-friendly auxiliary constructor. */ | ||
| private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { | ||
| this(uid, bestModel, avgMetrics.asScala.toArray) | ||
| this(uid, bestModel, avgMetrics.asScala.toArray, null) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See earlier suggestion, use an Option set to None instead of setting the Array to null
| def save(path: String, persistSubModels: Boolean): Unit = { | ||
| write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] | ||
| .persistSubModels(persistSubModels).save(path) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think users can still access CrossValidatorModelWriter through CrossValidatorModel.write, so the save method is unnecessary.
The private[CrossValidatorModel] annotation on the CrossValidatorModelWriter constructor only means that users can't create instances of the class e.g. via new CrossValidatorModel.CrossValidatorModelWriter(...)
| val subModelsPath = new Path(path, "subModels") | ||
| for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { | ||
| val modelPath = new Path(subModelsPath, paramIndex.toString).toString | ||
| instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we clean up/remove the partially-persisted subModels if any of these save() calls fail? E.g. let's say we have four subModels and the first three save() calls succeed but the fourth fails - should we delete the folders for the first three submodels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@WeichenXu123 Actually I don't think we have to worry about this; Pipeline persistence doesn't clean up if a stage fails to persist (see Pipeline.scala)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, its a good point. But currently model saving code do not have some exception handling code. e.g, overwrite saving, when save failed, it do not recover the old directory.
I think these things can be done in separated PRs.
cc @jkbradley What' your opinion ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question about cleaning up partially saved models. I agree it'd be nice to do in the future, rather than now.
|
@smurching Thanks! I will update later. And note that I will separate part of this PR to a new PR (the separated part will be a bugfix for #16774 ) |
|
@smurching I will update this PR after #19278 merged. Because now this PR depend on that one. Thanks! |
…st/load bug ## What changes were proposed in this pull request? Currently the param of CrossValidator/TrainValidationSplit persist/loading is hardcoding, which is different with other ML estimators. This cause persist bug for new added `parallelism` param. I refactor related code, avoid hardcoding persist/load param. And in the same time, it solve the `parallelism` persisting bug. This refactoring is very useful because we will add more new params in #19208 , hardcoding param persisting/loading making the thing adding new params very troublesome. ## How was this patch tested? Test added. Author: WeichenXu <[email protected]> Closes #19278 from WeichenXu123/fix-tuning-param-bug.
|
I will update this PR after #19350 get merged. We need to address another issue first. Thanks! |
|
Test build #82233 has finished for PR 19208 at commit
|
77d05f6 to
e009ee1
Compare
|
cc @smurching code updated, thanks! |
|
Test build #82246 has finished for PR 19208 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sending some comments, but I'm not done yet.
One issue is that users will have a hard time discovered the persistSubModels option. I'd recommend we do the following:
- Make the CrossValidatorModelWriter (and TVS writer) public, and add Scala doc to them to describe the option.
- Override the write method in CrossValidatorModel so it return type CrossValidatorModelWriter (rather than a generic MLWriter). That should make it a little easier for users to find the writer option.
- Add a note in the setCollectSubModels method about persistSubModels (to help with discoverability).
| this(uid, bestModel, avgMetrics, null) | ||
| private var _subModels: Option[Array[Array[Model[_]]]] = None | ||
|
|
||
| @Since("2.3.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only use Since annotations for public APIs
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def subModels: Array[Array[Model[_]]] = _subModels.get |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add Scala doc. We'll need to explain what the inner and outer array are and which one corresponds to the ordering of estimatorParamsMaps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, can you please add a better Exception message? If submodels are not available, then we should tell users to set the collectSubModels Param before fitting.
| /** | ||
| * Set option for persist sub models. | ||
| * Extra options for CrossValidatorModelWriter, current support "persistSubModels". | ||
| * if sub models exsit, the default value for option "persistSubModels" is "true". |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: exsit -> exist
| * `option()` handles extra options. If subclasses need to support extra options, override this | ||
| * method. | ||
| */ | ||
| @Since("2.3.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than overriding this in each subclass, let's have this option() method collect the specified options in a map which is consumed by the subclass when saveImpl() is called.
| @Since("1.6.0") | ||
| object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | ||
|
|
||
| private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: state return value explicitly
| object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | ||
|
|
||
| private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) = { | ||
| subModels.map { subModels => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be simplified using map?
subModels.map(_.map(_.map(_.copy(...).asInstanceOf[...])))
| import org.json4s.JsonDSL._ | ||
| val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq | ||
| val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ | ||
| ("shouldPersistSubModels" -> shouldPersistSubModels) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's have 1 name for this argument: "persistSubModels"
| val bestModelPath = new Path(path, "bestModel").toString | ||
| instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) | ||
| if (shouldPersistSubModels) { | ||
| require(instance.hasSubModels, "Cannot get sub models to persist.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This error message may be unclear. How about adding: "When persisting tuning models, you can only set persistSubModels to true if the tuning was done with collectSubModels set to true. To save the sub-models, try rerunning fitting with collectSubModels set to true."
| require(instance.hasSubModels, "Cannot get sub models to persist.") | ||
| val subModelsPath = new Path(path, "subModels") | ||
| for (splitIndex <- 0 until instance.getNumFolds) { | ||
| val splitPath = new Path(subModelsPath, splitIndex.toString) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about naming this with the string "fold":
splitIndex.toString --> "fold" + splitIndex.toString?
|
Done with review. I mainly reviewed CrossValidator since some comments will apply to TrainValidationSplit as well. Thanks for the PR! |
|
Test build #83469 has finished for PR 19208 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some more small comments, thanks!
| /** | ||
| * @return submodels represented in two dimension array. The index of outer array is the | ||
| * fold index, and the index of inner array corresponds to the ordering of | ||
| * estimatorParamsMaps |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: estimatorParamMaps
| * fold index, and the index of inner array corresponds to the ordering of | ||
| * estimatorParamsMaps | ||
| * | ||
| * Note: If submodels not available, exception will be thrown. only when we set collectSubModels |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reword, and use @throws scaladoc:
@throws IllegalArgumentException if subModels are not available. To retrieve subModels, make sure to set collectSubModels to true before fitting.
(Please fix wording in the error message too.)
| * "persistSubModels" will cause exception. | ||
| */ | ||
| @Since("2.3.0") | ||
| class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although we're making this public, let's not make all of its APIs public. Can you please make the constructor private and make this class final?
| * @param instance CrossValidatorModel instance used to construct the writer | ||
| * | ||
| * Options: | ||
| * CrossValidatorModelWriter support an option "persistSubModels", available value is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix wording:
CrossValidatorModelWriter supports an option "persistSubModels", with possible values "true" or "false". If you set the collectSubModels Param before fitting, then you can set "persistSubModels" to "true" in order to persist the submodels. By default, "persistSubModels" will be "true" when submodels are available and "false" otherwise. If submodels are not available, then setting "persistSubModels" to "true" will cause an exception.
| /** | ||
| * `option()` handles extra options. If subclasses need to support extra options, override this | ||
| * method. | ||
| * Map store extra options for this writer. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Map to store"
| protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() | ||
|
|
||
| /** | ||
| * `option()` handles extra options. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Adds an option to the underlying MLWriter. See the documentation for the specific model's writer for possible options. The option name (key) is case-insensitive."
| } | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| val persistSubModels = optionMap.getOrElse("persistsubmodels", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update this so that, when the valid is not convertible to a Boolean, the user sees an error message which states the invalid value and the possible valid values.
| * Note: If set this param, when you save the returned model, you can set an option | ||
| * "persistSubModels" to be "true" before saving, in order to save these submodels. | ||
| * You can check documents of | ||
| * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't checked through TrainValidationSplit yet, but please do make sure updates to CrossValidator get applied here (and that the updates are checked for copy errors like this line). Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have done some search to make sure everywhere is checked.
|
Test build #83530 has finished for PR 19208 at commit
|
|
Test build #83534 has finished for PR 19208 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few tiny items left, thanks!
| * If subModels are not available, then setting "persistSubModels" to "true" will cause | ||
| * an exception. | ||
| */ | ||
| final class TrainValidationSplitModelWriter private[tuning] ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since annotation
|
|
||
| @Since("2.0.0") | ||
| override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this) | ||
| override def write: TrainValidationSplit.TrainValidationSplitWriter = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this meant to be for TrainValidationSplitModel, not the Estimator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, there is two write method, one for Estimator and another for Model.
We only need to change the return type of write method for model.
| val eval = new BinaryClassificationEvaluator | ||
| val numFolds = 3 | ||
| val subPath = new File(tempDir, "testCrossValidatorSubModels") | ||
| val persistSubModelsPath = new File(subPath, "subModels").toString |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not used
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 more comments about backwards compatibility. Would you mind testing this manually, saving a model from spark 2.2 and then loading it with a build of this PR?
| val bestModelPath = new Path(path, "bestModel").toString | ||
| val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) | ||
| val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray | ||
| val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized this will not be backwards compatible. Let's make this persistSubModels optional so that we assume it is false if it is not in the metadata.
| val bestModelPath = new Path(path, "bestModel").toString | ||
| val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) | ||
| val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray | ||
| val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here; let's make this optional
|
Test build #83823 has finished for PR 19208 at commit
|
2bb6835 to
7e997da
Compare
|
I manually tested backwards compatibility and it works fine. I paste the test code for Run following code in spark-2.2 shell first: and then run following code on current PR: (in spark-shell) |
|
Test build #83824 has finished for PR 19208 at commit
|
|
Jenkins, test this please. |
|
Test build #83835 has finished for PR 19208 at commit
|
|
Test build #83834 has finished for PR 19208 at commit
|
|
Awesome, thanks for the updates and for checking backwards compatibility! |
What changes were proposed in this pull request?
We add a parameter whether to collect the full model list when CrossValidator/TrainValidationSplit training (Default is NOT), avoid the change cause OOM)
Add a method in CrossValidatorModel/TrainValidationSplitModel, allow user to get the model list
CrossValidatorModelWriter add a “option”, allow user to control whether to persist the model list to disk (will persist by default).
Note: when persisting the model list, use indices as the sub-model path
How was this patch tested?
Test cases added.