From 980c8ec87ddbc9f938942e78bb4cfe9753722bd2 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 30 Nov 2017 18:08:55 +0800
Subject: [PATCH 1/2] init pr

---
 python/pyspark/ml/base.py   | 16 +++++++++++
 python/pyspark/ml/tuning.py | 57 +++++++++++++++++++++++--------------
 2 files changed, 52 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a6767cee9bf2..5bb19a128064 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -71,6 +71,22 @@ def fit(self, dataset, params=None):
             raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                              "but got %s." % type(params))
 
+    @since("2.3.0")
+    def parallelFit(self, dataset, paramMaps, threadPool, modelCallback):
+        """
+        Fits models in parallel to the input dataset with a list of param maps.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :param paramMaps: a list of param maps
+        :param threadPool: a thread pool used to run parallel fitting
+        :param modelCallback: each fitted model, together with the index of its param map,
+                              is passed to this callback function.
+        """
+        def singleTrain(paramMapIndex):
+            model = self.fit(dataset, paramMaps[paramMapIndex])
+            modelCallback(model, paramMapIndex)
+
+        threadPool.map(singleTrain, range(len(paramMaps)))
 
 @inherit_doc
 class Transformer(Params):
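For context, here is a rough sketch of how the new `parallelFit` hook is meant to be driven; the tuning code below wires it up the same way, with a `ThreadPool` and an index-based callback. The `df` DataFrame and the param grid here are illustrative, not part of the patch:

```python
from multiprocessing.pool import ThreadPool

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder

lr = LogisticRegression()
epm = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()

# Collect fitted models by param-map index; the callback runs on pool threads,
# so it should only do cheap, thread-safe bookkeeping like this.
models = [None] * len(epm)

def modelCallback(model, paramMapIndex):
    models[paramMapIndex] = model

pool = ThreadPool(processes=2)
lr.parallelFit(df, epm, pool, modelCallback)  # df: an assumed labeled DataFrame
```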
""" + sc = SparkContext._active_spark_context bestModel = JavaParams._from_java(java_stage.bestModel()) + avgMetrics = _java2py(sc, java_stage.avgMetrics()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) - py_stage = cls(bestModel=bestModel).setEstimator(estimator) + py_stage = cls(bestModel = bestModel, avgMetrics = avgMetrics).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) py_stage._resetUid(java_stage.uid()) @@ -426,11 +434,10 @@ def _to_java(self): """ sc = SparkContext._active_spark_context - # TODO: persist average metrics as well _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", self.uid, self.bestModel._to_java(), - _py2java(sc, [])) + _py2java(sc, self.avgMetrics)) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -512,9 +519,14 @@ def getTrainRatio(self): def _fit(self, dataset): est = self.getOrDefault(self.estimator) + eva = self.getOrDefault(self.evaluator) + + if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator): + java_model = self._to_java().fit(dataset._jdf) + return TrainValidationSplitModel._from_java(java_model) + epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) - eva = self.getOrDefault(self.evaluator) tRatio = self.getOrDefault(self.trainRatio) seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" @@ -523,13 +535,14 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric + def modelCallback(model, paramMapIndex): + metric = eva.evaluate(model.transform(validation, epm[paramMapIndex])) + metrics[paramMapIndex] = metric pool = ThreadPool(processes=min(self.getParallelism(), numModels)) - metrics = pool.map(singleTrain, epm) + metrics = [0.0] * numModels + est.parallelFit(train, epm, pool, modelCallback) + train.unpersist() validation.unpersist() @@ -663,12 +676,15 @@ def _from_java(cls, java_stage): Used for ML persistence. """ + sc = SparkContext._active_spark_context # Load information from java_stage to the instance. bestModel = JavaParams._from_java(java_stage.bestModel()) + validationMetrics = _java2py(sc, java_stage.validationMetrics()) estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
From c6f225025a1ba002b6aa4ce83fb67dbe742395b1 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Fri, 1 Dec 2017 15:11:31 +0800
Subject: [PATCH 2/2] fix python style

---
 python/pyspark/ml/tuning.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index c10e6e91c6f0..102482b71e24 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -262,7 +262,6 @@ def _fit(self, dataset):
         randCol = self.uid + "_rand"
         df = dataset.select("*", rand(seed).alias(randCol))
         metrics = [0.0] * numModels
-
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
 
         for i in range(nFolds):
@@ -273,9 +272,11 @@ def _fit(self, dataset):
             train = df.filter(~condition).cache()
 
             currentFoldMetrics = [0.0] * numModels
+
             def modelCallback(model, paramMapIndex):
                 metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
                 currentFoldMetrics[paramMapIndex] = metric
+
             est.parallelFit(train, epm, pool, modelCallback)
 
             for j in range(numModels):
@@ -420,7 +421,7 @@ def _from_java(cls, java_stage):
         avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel = bestModel, avgMetrics = avgMetrics).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
         py_stage._resetUid(java_stage.uid())
 
@@ -683,8 +684,8 @@ def _from_java(cls, java_stage):
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel = bestModel, validationMetrics = validationMetrics)\
-            .setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel,
+                       validationMetrics=validationMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
         py_stage._resetUid(java_stage.uid())
 
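Style fixes aside, the net effect on persistence is that the metrics fields now round-trip through save/load instead of being replaced with an empty list. A small sketch, reusing `cvModel` from the example above and an assumed writable `path`:

```python
from pyspark.ml.tuning import CrossValidatorModel

cvModel.write().overwrite().save(path)
loaded = CrossValidatorModel.load(path)
# Before this patch, _to_java passed an empty list, so this came back as [].
print(loaded.avgMetrics)
```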