Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class CrossValidator(Estimator):
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make CrossValidator extend HasSeed, then put seed=None here.

"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update this line too.

"""
Expand All @@ -136,6 +136,8 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
self, "evaluator",
"evaluator used to select hyper-parameters that maximize the cross-validated metric")
#: param for number of folds for cross validation
self._setDefault(seed=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put all defaults in one call to _setDefault.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value of seed is set here: https://github.com/apache/spark/blob/master/python/pyspark/ml/param/shared.py#L365. This line is not necessary.

self.seed = Param(self, "seed", "seed value used for k-fold")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also not necessary if we make CrossValidator extend HasSeed.

self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
self._setDefault(numFolds=3)
kwargs = self.__init__._input_kwargs
Expand Down Expand Up @@ -210,7 +212,7 @@ def _fit(self, dataset):
nFolds = self.getOrDefault(self.numFolds)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
df = dataset.select("*", rand(self.getOrDefault(self.seed)).alias(randCol))
metrics = np.zeros(numModels)
for i in range(nFolds):
validateLB = i * h
Expand Down