-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API #19627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test build #83290 has finished for PR 19627 at commit
|
|
Jenkins, test this please. |
|
Test build #83953 has finished for PR 19627 at commit
|
4c3a7ea to
9e27f6b
Compare
|
Test build #83954 has finished for PR 19627 at commit
|
|
My local test passed. This test failure looks like test system issue. |
|
What happens when you run |
|
@holdenk Find the reason. There is an empty file in the directory. :) |
|
Test build #83991 has finished for PR 19627 at commit
|
|
Test build #83992 has finished for PR 19627 at commit
|
|
@holdenk Thanks! |
|
Is this still WIP or ready? |
|
@jkbradley I think it is better to review #19857 (fix python model specific optimization) and merge it first and then I rebase & update this PR. :) |
|
@MrBago @yogeshg @jkbradley Updated and ready for review now! |
|
Test build #89111 has finished for PR 19627 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
You'll need to update _from_java and _to_java for CrossValidator and TrainValidationSplit.
Also, please update the PR description.
python/pyspark/ml/tests.py
Outdated
| tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, | ||
| collectSubModels=True) | ||
| tvsModel = tvs.fit(dataset) | ||
| assert len(tvsModel.subModels) == len(grid) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use self.assertEqual here and elsewhere.
| "TypeConverters.toInt"), | ||
| ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", | ||
| "1", "TypeConverters.toInt"), | ||
| ("collectSubModels", "whether to collect a list of sub-models trained during tuning", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to add the full description from Scala.
|
|
||
|
|
||
| class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): | ||
| class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to update _from_java and _to_java as well to pass collectSubModels around. (Same for TrainValidationSplit)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also clarify in the doc for CrossValidatorModel.copy() that it does not copy the extra Params into the subModels. (same for TrainValidationSplitModel)
| cvParallelModel = cv.fit(dataset) | ||
| self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) | ||
|
|
||
| def test_expose_sub_models(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice tests. Can you make one addition: Test the copy() method to make sure it copies the submodels.
|
Test build #89334 has finished for PR 19627 at commit
|
|
LGTM |
What changes were proposed in this pull request?
Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting.
How was this patch tested?
UT added.