-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-21911][ML][PySpark] Parallel Model Evaluation for ML Tuning in PySpark #19122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
44f4332
init pr
WeichenXu123 b321534
update
WeichenXu123 d5209c4
improve code in thread
WeichenXu123 6c3debd
update
WeichenXu123 849b675
update
WeichenXu123 b03499a
add serial parallel cmp testcase
WeichenXu123 fb0ac04
fix py style
WeichenXu123 dbe66fb
update
WeichenXu123 93ab39a
update
WeichenXu123 67ad3d2
improve unit test
WeichenXu123 8b3ef97
Merge branch 'master' into par-ml-tuning-py
WeichenXu123 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,15 +14,15 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| import itertools | ||
| import numpy as np | ||
| from multiprocessing.pool import ThreadPool | ||
|
|
||
| from pyspark import since, keyword_only | ||
| from pyspark.ml import Estimator, Model | ||
| from pyspark.ml.common import _py2java | ||
| from pyspark.ml.param import Params, Param, TypeConverters | ||
| from pyspark.ml.param.shared import HasSeed | ||
| from pyspark.ml.param.shared import HasParallelism, HasSeed | ||
| from pyspark.ml.util import * | ||
| from pyspark.ml.wrapper import JavaParams | ||
| from pyspark.sql.functions import rand | ||
|
|
@@ -170,7 +170,7 @@ def _to_java_impl(self): | |
| return java_estimator, java_epms, java_evaluator | ||
|
|
||
|
|
||
| class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): | ||
| class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): | ||
| """ | ||
|
|
||
| K-fold cross validation performs model selection by splitting the dataset into a set of | ||
|
|
@@ -193,7 +193,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): | |
| >>> lr = LogisticRegression() | ||
| >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() | ||
| >>> evaluator = BinaryClassificationEvaluator() | ||
| >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) | ||
| >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, | ||
| ... parallelism=2) | ||
| >>> cvModel = cv.fit(dataset) | ||
| >>> cvModel.avgMetrics[0] | ||
| 0.5 | ||
|
|
@@ -208,23 +209,23 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): | |
|
|
||
| @keyword_only | ||
| def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, | ||
| seed=None): | ||
| seed=None, parallelism=1): | ||
| """ | ||
| __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ | ||
| seed=None) | ||
| seed=None, parallelism=1) | ||
| """ | ||
| super(CrossValidator, self).__init__() | ||
| self._setDefault(numFolds=3) | ||
| self._setDefault(numFolds=3, parallelism=1) | ||
|
||
| kwargs = self._input_kwargs | ||
| self._set(**kwargs) | ||
|
|
||
| @keyword_only | ||
| @since("1.4.0") | ||
| def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, | ||
| seed=None): | ||
| seed=None, parallelism=1): | ||
| """ | ||
| setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ | ||
| seed=None): | ||
| seed=None, parallelism=1): | ||
| Sets params for cross validator. | ||
| """ | ||
| kwargs = self._input_kwargs | ||
|
|
@@ -255,18 +256,27 @@ def _fit(self, dataset): | |
| randCol = self.uid + "_rand" | ||
| df = dataset.select("*", rand(seed).alias(randCol)) | ||
| metrics = [0.0] * numModels | ||
|
|
||
| pool = ThreadPool(processes=min(self.getParallelism(), numModels)) | ||
|
|
||
| for i in range(nFolds): | ||
| validateLB = i * h | ||
| validateUB = (i + 1) * h | ||
| condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) | ||
| validation = df.filter(condition) | ||
| train = df.filter(~condition) | ||
| models = est.fit(train, epm) | ||
| for j in range(numModels): | ||
| model = models[j] | ||
| validation = df.filter(condition).cache() | ||
| train = df.filter(~condition).cache() | ||
|
|
||
| def singleTrain(paramMap): | ||
| model = est.fit(train, paramMap) | ||
| # TODO: duplicate evaluator to take extra params from input | ||
| metric = eva.evaluate(model.transform(validation, epm[j])) | ||
| metrics[j] += metric/nFolds | ||
| metric = eva.evaluate(model.transform(validation, paramMap)) | ||
| return metric | ||
|
|
||
| currentFoldMetrics = pool.map(singleTrain, epm) | ||
| for j in range(numModels): | ||
| metrics[j] += (currentFoldMetrics[j] / nFolds) | ||
| validation.unpersist() | ||
| train.unpersist() | ||
|
|
||
| if eva.isLargerBetter(): | ||
| bestIndex = np.argmax(metrics) | ||
|
|
@@ -316,9 +326,10 @@ def _from_java(cls, java_stage): | |
| estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) | ||
| numFolds = java_stage.getNumFolds() | ||
| seed = java_stage.getSeed() | ||
| parallelism = java_stage.getParallelism() | ||
| # Create a new instance of this stage. | ||
| py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, | ||
| numFolds=numFolds, seed=seed) | ||
| numFolds=numFolds, seed=seed, parallelism=parallelism) | ||
| py_stage._resetUid(java_stage.uid()) | ||
| return py_stage | ||
|
|
||
|
|
@@ -337,6 +348,7 @@ def _to_java(self): | |
| _java_obj.setEstimator(estimator) | ||
| _java_obj.setSeed(self.getSeed()) | ||
| _java_obj.setNumFolds(self.getNumFolds()) | ||
| _java_obj.setParallelism(self.getParallelism()) | ||
|
|
||
| return _java_obj | ||
|
|
||
|
|
@@ -427,7 +439,7 @@ def _to_java(self): | |
| return _java_obj | ||
|
|
||
|
|
||
| class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): | ||
| class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): | ||
| """ | ||
| .. note:: Experimental | ||
|
|
||
|
|
@@ -448,7 +460,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): | |
| >>> lr = LogisticRegression() | ||
| >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() | ||
| >>> evaluator = BinaryClassificationEvaluator() | ||
| >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) | ||
| >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, | ||
| ... parallelism=2) | ||
| >>> tvsModel = tvs.fit(dataset) | ||
| >>> evaluator.evaluate(tvsModel.transform(dataset)) | ||
| 0.8333... | ||
|
|
@@ -461,23 +474,23 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): | |
|
|
||
| @keyword_only | ||
| def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, | ||
| seed=None): | ||
| parallelism=1, seed=None): | ||
| """ | ||
| __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ | ||
| seed=None) | ||
| parallelism=1, seed=None) | ||
| """ | ||
| super(TrainValidationSplit, self).__init__() | ||
| self._setDefault(trainRatio=0.75) | ||
| self._setDefault(trainRatio=0.75, parallelism=1) | ||
| kwargs = self._input_kwargs | ||
| self._set(**kwargs) | ||
|
|
||
| @since("2.0.0") | ||
| @keyword_only | ||
| def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, | ||
| seed=None): | ||
| parallelism=1, seed=None): | ||
| """ | ||
| setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ | ||
| seed=None): | ||
| parallelism=1, seed=None): | ||
| Sets params for the train validation split. | ||
| """ | ||
| kwargs = self._input_kwargs | ||
|
|
@@ -506,15 +519,20 @@ def _fit(self, dataset): | |
| seed = self.getOrDefault(self.seed) | ||
| randCol = self.uid + "_rand" | ||
| df = dataset.select("*", rand(seed).alias(randCol)) | ||
| metrics = [0.0] * numModels | ||
| condition = (df[randCol] >= tRatio) | ||
| validation = df.filter(condition) | ||
| train = df.filter(~condition) | ||
| models = est.fit(train, epm) | ||
| for j in range(numModels): | ||
| model = models[j] | ||
| metric = eva.evaluate(model.transform(validation, epm[j])) | ||
| metrics[j] += metric | ||
| validation = df.filter(condition).cache() | ||
| train = df.filter(~condition).cache() | ||
|
|
||
| def singleTrain(paramMap): | ||
| model = est.fit(train, paramMap) | ||
| metric = eva.evaluate(model.transform(validation, paramMap)) | ||
| return metric | ||
|
|
||
| pool = ThreadPool(processes=min(self.getParallelism(), numModels)) | ||
| metrics = pool.map(singleTrain, epm) | ||
| train.unpersist() | ||
| validation.unpersist() | ||
|
|
||
| if eva.isLargerBetter(): | ||
| bestIndex = np.argmax(metrics) | ||
| else: | ||
|
|
@@ -563,9 +581,10 @@ def _from_java(cls, java_stage): | |
| estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) | ||
| trainRatio = java_stage.getTrainRatio() | ||
| seed = java_stage.getSeed() | ||
| parallelism = java_stage.getParallelism() | ||
| # Create a new instance of this stage. | ||
| py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, | ||
| trainRatio=trainRatio, seed=seed) | ||
| trainRatio=trainRatio, seed=seed, parallelism=parallelism) | ||
| py_stage._resetUid(java_stage.uid()) | ||
| return py_stage | ||
|
|
||
|
|
@@ -584,6 +603,7 @@ def _to_java(self): | |
| _java_obj.setEstimator(estimator) | ||
| _java_obj.setTrainRatio(self.getTrainRatio()) | ||
| _java_obj.setSeed(self.getSeed()) | ||
| _java_obj.setParallelism(self.getParallelism()) | ||
|
|
||
| return _java_obj | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you planning on adding a unit test to verify that parallel evaluation produces the same results as serial evaluation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A test comparing serial and parallel evaluation results has been added.