16 changes: 16 additions & 0 deletions python/pyspark/ml/base.py
@@ -71,6 +71,22 @@ def fit(self, dataset, params=None):
             raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                              "but got %s." % type(params))
 
+    @since("2.3.0")
+    def parallelFit(self, dataset, paramMaps, threadPool, modelCallback):
+        """
+        Fits a model to the input dataset for each param map, in parallel on the given thread pool.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :param paramMaps: a list of param maps
+        :param threadPool: the thread pool on which the individual fits are run
+        :param modelCallback: callback invoked with each fitted model and the index of the
+                              param map it was fitted with
+        """
+        def singleTrain(paramMapIndex):
+            model = self.fit(dataset, paramMaps[paramMapIndex])
+            modelCallback(model, paramMapIndex)
+        threadPool.map(singleTrain, range(len(paramMaps)))
+
 
 @inherit_doc
 class Transformer(Params):
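
A minimal sketch of how a caller might drive `parallelFit`, assuming this patch is applied; the DataFrame `train_df` and the chosen params are illustrative stand-ins, not part of the change:

```python
from multiprocessing.pool import ThreadPool

from pyspark.ml.classification import LogisticRegression

# `train_df` is assumed to be an existing DataFrame with "features"/"label" columns.
lr = LogisticRegression()
param_maps = [{lr.regParam: r} for r in (0.0, 0.1, 1.0)]
models = [None] * len(param_maps)

def collect(model, index):
    # Invoked from a pool thread once per fitted model.
    models[index] = model

pool = ThreadPool(processes=2)
lr.parallelFit(train_df, param_maps, pool, collect)  # map() blocks until all fits finish
pool.close()
```
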
60 changes: 38 additions & 22 deletions python/pyspark/ml/tuning.py
@@ -20,11 +20,12 @@

 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.evaluation import JavaEvaluator
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasParallelism, HasSeed
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.wrapper import JavaEstimator, JavaParams
 from pyspark.sql.functions import rand
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -247,16 +248,20 @@ def getNumFolds(self):

     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return CrossValidatorModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         nFolds = self.getOrDefault(self.numFolds)
         seed = self.getOrDefault(self.seed)
         h = 1.0 / nFolds
         randCol = self.uid + "_rand"
         df = dataset.select("*", rand(seed).alias(randCol))
         metrics = [0.0] * numModels
 
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
 
         for i in range(nFolds):
@@ -266,15 +271,17 @@ def _fit(self, dataset):
             validation = df.filter(condition).cache()
             train = df.filter(~condition).cache()
 
-            def singleTrain(paramMap):
-                model = est.fit(train, paramMap)
-                # TODO: duplicate evaluator to take extra params from input
-                metric = eva.evaluate(model.transform(validation, paramMap))
-                return metric
+            currentFoldMetrics = [0.0] * numModels
+
+            def modelCallback(model, paramMapIndex):
+                metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+                currentFoldMetrics[paramMapIndex] = metric
+
+            est.parallelFit(train, epm, pool, modelCallback)
 
-            currentFoldMetrics = pool.map(singleTrain, epm)
             for j in range(numModels):
                 metrics[j] += (currentFoldMetrics[j] / nFolds)
 
             validation.unpersist()
             train.unpersist()

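The fold loop above accumulates a running mean: each `modelCallback` fills one slot of `currentFoldMetrics`, and each model's entry is then averaged into `metrics` across folds. A plain-Python sketch of the same arithmetic, with made-up metric values:

```python
nFolds, numModels = 3, 2
# Hypothetical per-fold metrics, indexed [fold][param-map index].
fold_metrics = [[0.80, 0.70],
                [0.82, 0.74],
                [0.78, 0.72]]

metrics = [0.0] * numModels
for currentFoldMetrics in fold_metrics:      # one entry per fold
    for j in range(numModels):
        metrics[j] += currentFoldMetrics[j] / nFolds   # running mean over folds

assert [round(m, 2) for m in metrics] == [0.80, 0.72]
```
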
@@ -409,10 +416,12 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -426,11 +435,10 @@ def _to_java(self):
"""

sc = SparkContext._active_spark_context
# TODO: persist average metrics as well
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
self.uid,
self.bestModel._to_java(),
_py2java(sc, []))
_py2java(sc, self.avgMetrics))
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

_java_obj.set("evaluator", evaluator)
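
The fix above hinges on `_py2java` carrying the real metrics into the JVM object instead of an empty list, and `_java2py` recovering them on load. A minimal round-trip sketch, assuming an active `SparkContext` named `sc`:

```python
from pyspark.ml.common import _java2py, _py2java

avg_metrics = [0.80, 0.72]                 # hypothetical values
java_metrics = _py2java(sc, avg_metrics)   # Python list -> JVM object
assert _java2py(sc, java_metrics) == avg_metrics   # JVM object -> Python list
```
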
@@ -512,9 +520,14 @@ def getTrainRatio(self):

     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return TrainValidationSplitModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         tRatio = self.getOrDefault(self.trainRatio)
         seed = self.getOrDefault(self.seed)
         randCol = self.uid + "_rand"
@@ -523,13 +536,14 @@ def _fit(self, dataset):
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        def singleTrain(paramMap):
-            model = est.fit(train, paramMap)
-            metric = eva.evaluate(model.transform(validation, paramMap))
-            return metric
+        def modelCallback(model, paramMapIndex):
+            metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+            metrics[paramMapIndex] = metric
 
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
-        metrics = pool.map(singleTrain, epm)
+        metrics = [0.0] * numModels
+        est.parallelFit(train, epm, pool, modelCallback)
 
         train.unpersist()
         validation.unpersist()

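The split that `trainRatio` drives is just a filter on a uniform random column, as in the `randCol` lines above; a self-contained sketch of the same idea, with column name and ratio chosen here for illustration:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand

spark = SparkSession.builder.getOrCreate()
df = spark.range(1000)

tRatio, seed = 0.75, 42
df = df.select("*", rand(seed).alias("_rand"))  # uniform [0, 1) per row
condition = df["_rand"] >= tRatio
validation = df.filter(condition)    # ~25% of rows
train = df.filter(~condition)        # ~75% of rows
```
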
@@ -663,12 +677,15 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         # Load information from java_stage to the instance.
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel,
+                       validationMetrics=validationMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -681,12 +698,11 @@ def _to_java(self):
"""

sc = SparkContext._active_spark_context
# TODO: persst validation metrics as well
_java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplitModel",
self.uid,
self.bestModel._to_java(),
_py2java(sc, []))
_py2java(sc, self.validationMetrics))
estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()

_java_obj.set("evaluator", evaluator)
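
With real metrics now passed through `_py2java`, a persisted model should keep them across a save/load round trip; a hedged sketch, assuming `tvs_model` is an already fitted `TrainValidationSplitModel` and `path` is a writable location:

```python
from pyspark.ml.tuning import TrainValidationSplitModel

# `tvs_model` and `path` are assumed to exist; after this change the
# loaded model carries validationMetrics instead of an empty list.
tvs_model.save(path)
loaded = TrainValidationSplitModel.load(path)
print(loaded.validationMetrics)
```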