Skip to content

Commit 3a44aeb

Browse files
mmenestret authored and jkbradley committed
[SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed
Extend CrossValidator with HasSeed in PySpark.

This PR replaces [#7997].

CC: yanboliang thunterdb mmenestret
Would one of you mind taking a look? Thanks!

Author: Joseph K. Bradley <[email protected]>
Author: Martin MENESTRET <[email protected]>

Closes #10268 from jkbradley/pyspark-cv-seed.
1 parent 9657ee8 commit 3a44aeb

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

python/pyspark/ml/tuning.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import numpy as np
2020

2121
from pyspark import since
22-
from pyspark.ml.param import Params, Param
2322
from pyspark.ml import Estimator, Model
23+
from pyspark.ml.param import Params, Param
24+
from pyspark.ml.param.shared import HasSeed
2425
from pyspark.ml.util import keyword_only
2526
from pyspark.sql.functions import rand
2627

@@ -89,7 +90,7 @@ def build(self):
8990
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
9091

9192

92-
class CrossValidator(Estimator):
93+
class CrossValidator(Estimator, HasSeed):
9394
"""
9495
K-fold cross validation.
9596
@@ -129,9 +130,11 @@ class CrossValidator(Estimator):
129130
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
130131

131132
@keyword_only
132-
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
133+
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
134+
seed=None):
133135
"""
134-
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
136+
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
137+
seed=None)
135138
"""
136139
super(CrossValidator, self).__init__()
137140
#: param for estimator to be cross-validated
@@ -151,9 +154,11 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
151154

152155
@keyword_only
153156
@since("1.4.0")
154-
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
157+
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
158+
seed=None):
155159
"""
156-
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
160+
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
161+
seed=None):
157162
Sets params for cross validator.
158163
"""
159164
kwargs = self.setParams._input_kwargs
@@ -225,9 +230,10 @@ def _fit(self, dataset):
225230
numModels = len(epm)
226231
eva = self.getOrDefault(self.evaluator)
227232
nFolds = self.getOrDefault(self.numFolds)
233+
seed = self.getOrDefault(self.seed)
228234
h = 1.0 / nFolds
229235
randCol = self.uid + "_rand"
230-
df = dataset.select("*", rand(0).alias(randCol))
236+
df = dataset.select("*", rand(seed).alias(randCol))
231237
metrics = np.zeros(numModels)
232238
for i in range(nFolds):
233239
validateLB = i * h

0 commit comments

Comments
 (0)