1919import numpy as np
2020
2121from pyspark import since
22- from pyspark .ml .param import Params , Param
2322from pyspark .ml import Estimator , Model
23+ from pyspark .ml .param import Params , Param
24+ from pyspark .ml .param .shared import HasSeed
2425from pyspark .ml .util import keyword_only
2526from pyspark .sql .functions import rand
2627
@@ -89,7 +90,7 @@ def build(self):
8990 return [dict (zip (keys , prod )) for prod in itertools .product (* grid_values )]
9091
9192
92- class CrossValidator (Estimator ):
93+ class CrossValidator (Estimator , HasSeed ):
9394 """
9495 K-fold cross validation.
9596
@@ -129,9 +130,11 @@ class CrossValidator(Estimator):
129130 numFolds = Param (Params ._dummy (), "numFolds" , "number of folds for cross validation" )
130131
131132 @keyword_only
132- def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ):
133+ def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
134+ seed = None ):
133135 """
134- __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
136+ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
137+ seed=None)
135138 """
136139 super (CrossValidator , self ).__init__ ()
137140 #: param for estimator to be cross-validated
@@ -151,9 +154,11 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
151154
152155 @keyword_only
153156 @since ("1.4.0" )
154- def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ):
157+ def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
158+ seed = None ):
155159 """
156- setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
160+ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
161+ seed=None):
157162 Sets params for cross validator.
158163 """
159164 kwargs = self .setParams ._input_kwargs
@@ -225,9 +230,10 @@ def _fit(self, dataset):
225230 numModels = len (epm )
226231 eva = self .getOrDefault (self .evaluator )
227232 nFolds = self .getOrDefault (self .numFolds )
233+ seed = self .getOrDefault (self .seed )
228234 h = 1.0 / nFolds
229235 randCol = self .uid + "_rand"
230- df = dataset .select ("*" , rand (0 ).alias (randCol ))
236+ df = dataset .select ("*" , rand (seed ).alias (randCol ))
231237 metrics = np .zeros (numModels )
232238 for i in range (nFolds ):
233239 validateLB = i * h
0 commit comments