diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a6767cee9bf28..5bb19a1280647 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -71,6 +71,22 @@ def fit(self, dataset, params=None):
             raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                              "but got %s." % type(params))
 
+    @since("2.3.0")
+    def parallelFit(self, dataset, paramMaps, threadPool, modelCallback):
+        """
+        Fits models to the input dataset in parallel, one model per param map.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :param paramMaps: a list of param maps
+        :param threadPool: a thread pool used to run the fits in parallel
+        :param modelCallback: a callback invoked with each fitted model and the index of
+            the param map that produced it
+        """
+        def singleTrain(paramMapIndex):
+            model = self.fit(dataset, paramMaps[paramMapIndex])
+            modelCallback(model, paramMapIndex)
+        threadPool.map(singleTrain, range(len(paramMaps)))
+
 
 @inherit_doc
 class Transformer(Params):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 47351133524e7..102482b71e24d 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -20,11 +20,12 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.evaluation import JavaEvaluator
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasParallelism, HasSeed
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.wrapper import JavaEstimator, JavaParams
 from pyspark.sql.functions import rand
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -247,16 +248,20 @@ def getNumFolds(self):
 
     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return CrossValidatorModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         nFolds = self.getOrDefault(self.numFolds)
         seed = self.getOrDefault(self.seed)
         h = 1.0 / nFolds
         randCol = self.uid + "_rand"
         df = dataset.select("*", rand(seed).alias(randCol))
         metrics = [0.0] * numModels
-
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
 
         for i in range(nFolds):
@@ -266,15 +271,17 @@ def _fit(self, dataset):
             validation = df.filter(condition).cache()
             train = df.filter(~condition).cache()
 
-            def singleTrain(paramMap):
-                model = est.fit(train, paramMap)
-                # TODO: duplicate evaluator to take extra params from input
-                metric = eva.evaluate(model.transform(validation, paramMap))
-                return metric
+            currentFoldMetrics = [0.0] * numModels
+
+            def modelCallback(model, paramMapIndex):
+                metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+                currentFoldMetrics[paramMapIndex] = metric
+
+            est.parallelFit(train, epm, pool, modelCallback)
 
-            currentFoldMetrics = pool.map(singleTrain, epm)
             for j in range(numModels):
                 metrics[j] += (currentFoldMetrics[j] / nFolds)
+
             validation.unpersist()
             train.unpersist()
@@ -409,10 +416,12 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
""" + sc = SparkContext._active_spark_context bestModel = JavaParams._from_java(java_stage.bestModel()) + avgMetrics = _java2py(sc, java_stage.avgMetrics()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) - py_stage = cls(bestModel=bestModel).setEstimator(estimator) + py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) py_stage._resetUid(java_stage.uid()) @@ -426,11 +435,10 @@ def _to_java(self): """ sc = SparkContext._active_spark_context - # TODO: persist average metrics as well _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", self.uid, self.bestModel._to_java(), - _py2java(sc, [])) + _py2java(sc, self.avgMetrics)) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -512,9 +520,14 @@ def getTrainRatio(self): def _fit(self, dataset): est = self.getOrDefault(self.estimator) + eva = self.getOrDefault(self.evaluator) + + if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator): + java_model = self._to_java().fit(dataset._jdf) + return TrainValidationSplitModel._from_java(java_model) + epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) - eva = self.getOrDefault(self.evaluator) tRatio = self.getOrDefault(self.trainRatio) seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" @@ -523,13 +536,14 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric + def modelCallback(model, paramMapIndex): + metric = eva.evaluate(model.transform(validation, epm[paramMapIndex])) + metrics[paramMapIndex] = metric pool = ThreadPool(processes=min(self.getParallelism(), numModels)) - metrics = pool.map(singleTrain, epm) + metrics = [0.0] * numModels + est.parallelFit(train, epm, pool, modelCallback) + train.unpersist() validation.unpersist() @@ -663,12 +677,15 @@ def _from_java(cls, java_stage): Used for ML persistence. """ + sc = SparkContext._active_spark_context # Load information from java_stage to the instance. bestModel = JavaParams._from_java(java_stage.bestModel()) + validationMetrics = _java2py(sc, java_stage.validationMetrics()) estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. - py_stage = cls(bestModel=bestModel).setEstimator(estimator) + py_stage = cls(bestModel=bestModel, + validationMetrics=validationMetrics).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) py_stage._resetUid(java_stage.uid()) @@ -681,12 +698,11 @@ def _to_java(self): """ sc = SparkContext._active_spark_context - # TODO: persst validation metrics as well _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), - _py2java(sc, [])) + _py2java(sc, self.validationMetrics)) estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator)