Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import numpy as np

from pyspark import since
from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand

Expand Down Expand Up @@ -89,7 +90,7 @@ def build(self):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


class CrossValidator(Estimator):
class CrossValidator(Estimator, HasSeed):
"""
K-fold cross validation.

Expand Down Expand Up @@ -129,9 +130,11 @@ class CrossValidator(Estimator):
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
seed=None):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None)
"""
super(CrossValidator, self).__init__()
#: param for estimator to be cross-validated
Expand All @@ -151,9 +154,11 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF

@keyword_only
@since("1.4.0")
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
seed=None):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None):
Sets params for cross validator.
"""
kwargs = self.setParams._input_kwargs
Expand Down Expand Up @@ -225,9 +230,10 @@ def _fit(self, dataset):
numModels = len(epm)
eva = self.getOrDefault(self.evaluator)
nFolds = self.getOrDefault(self.numFolds)
seed = self.getOrDefault(self.seed)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
df = dataset.select("*", rand(seed).alias(randCol))
metrics = np.zeros(numModels)
for i in range(nFolds):
validateLB = i * h
Expand Down