Skip to content

Commit 85aae68

Browse files
committed
Fixed to use HasSeed
1 parent 8d69a32 commit 85aae68

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

python/pyspark/ml/tuning.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from pyspark import since
22-
from pyspark.ml.param import Params, Param
22+
from pyspark.ml.param import HasSeed, Params, Param
2323
from pyspark.ml import Estimator, Model
2424
from pyspark.ml.util import keyword_only
2525
from pyspark.sql.functions import rand
@@ -89,7 +89,7 @@ def build(self):
8989
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
9090

9191

92-
class CrossValidator(Estimator):
92+
class CrossValidator(Estimator, HasSeed):
9393
"""
9494
K-fold cross validation.
9595
@@ -106,7 +106,7 @@ class CrossValidator(Estimator):
106106
>>> lr = LogisticRegression()
107107
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
108108
>>> evaluator = BinaryClassificationEvaluator()
109-
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
109+
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, seed=42)
110110
>>> cvModel = cv.fit(dataset)
111111
>>> evaluator.evaluate(cvModel.transform(dataset))
112112
0.8333...
@@ -129,9 +129,11 @@ class CrossValidator(Estimator):
129129
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
130130

131131
@keyword_only
132-
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=0):
132+
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
133+
seed=None):
133134
"""
134-
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
135+
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
136+
seed=None)
135137
"""
136138
super(CrossValidator, self).__init__()
137139
#: param for estimator to be cross-validated
@@ -144,18 +146,18 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
144146
self, "evaluator",
145147
"evaluator used to select hyper-parameters that maximize the cross-validated metric")
146148
#: param for number of folds for cross validation
147-
self._setDefault(seed=0)
148-
self.seed = Param(self, "seed", "seed value used for k-fold")
149149
self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
150150
self._setDefault(numFolds=3)
151151
kwargs = self.__init__._input_kwargs
152152
self._set(**kwargs)
153153

154154
@keyword_only
155155
@since("1.4.0")
156-
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
156+
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
157+
seed=None):
157158
"""
158-
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
159+
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
160+
seed=None):
159161
Sets params for cross validator.
160162
"""
161163
kwargs = self.setParams._input_kwargs
@@ -227,9 +229,10 @@ def _fit(self, dataset):
227229
numModels = len(epm)
228230
eva = self.getOrDefault(self.evaluator)
229231
nFolds = self.getOrDefault(self.numFolds)
232+
seed = self.getOrDefault(self.seed)
230233
h = 1.0 / nFolds
231234
randCol = self.uid + "_rand"
232-
df = dataset.select("*", rand(self.getOrDefault(self.seed)).alias(randCol))
235+
df = dataset.select("*", rand(seed).alias(randCol))
233236
metrics = np.zeros(numModels)
234237
for i in range(nFolds):
235238
validateLB = i * h

0 commit comments

Comments
 (0)