1919import numpy as np
2020
2121from pyspark import since
22- from pyspark .ml .param import Params , Param
22+ from pyspark .ml .param import HasSeed , Params , Param
2323from pyspark .ml import Estimator , Model
2424from pyspark .ml .util import keyword_only
2525from pyspark .sql .functions import rand
@@ -89,7 +89,7 @@ def build(self):
8989 return [dict (zip (keys , prod )) for prod in itertools .product (* grid_values )]
9090
9191
92- class CrossValidator (Estimator ):
92+ class CrossValidator (Estimator , HasSeed ):
9393 """
9494 K-fold cross validation.
9595
@@ -106,7 +106,7 @@ class CrossValidator(Estimator):
106106 >>> lr = LogisticRegression()
107107 >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
108108 >>> evaluator = BinaryClassificationEvaluator()
109- >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
109+ >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, seed=42 )
110110 >>> cvModel = cv.fit(dataset)
111111 >>> evaluator.evaluate(cvModel.transform(dataset))
112112 0.8333...
@@ -129,9 +129,11 @@ class CrossValidator(Estimator):
129129 numFolds = Param (Params ._dummy (), "numFolds" , "number of folds for cross validation" )
130130
131131 @keyword_only
132- def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 , seed = 0 ):
132+ def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
133+ seed = None ):
133134 """
134- __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
135+ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
136+ seed=None)
135137 """
136138 super (CrossValidator , self ).__init__ ()
137139 #: param for estimator to be cross-validated
@@ -144,18 +146,18 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
144146 self , "evaluator" ,
145147 "evaluator used to select hyper-parameters that maximize the cross-validated metric" )
146148 #: param for number of folds for cross validation
147- self ._setDefault (seed = 0 )
148- self .seed = Param (self , "seed" , "seed value used for k-fold" )
149149 self .numFolds = Param (self , "numFolds" , "number of folds for cross validation" )
150150 self ._setDefault (numFolds = 3 )
151151 kwargs = self .__init__ ._input_kwargs
152152 self ._set (** kwargs )
153153
154154 @keyword_only
155155 @since ("1.4.0" )
156- def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ):
156+ def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
157+ seed = None ):
157158 """
158- setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
159+ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
160+ seed=None):
159161 Sets params for cross validator.
160162 """
161163 kwargs = self .setParams ._input_kwargs
@@ -227,9 +229,10 @@ def _fit(self, dataset):
227229 numModels = len (epm )
228230 eva = self .getOrDefault (self .evaluator )
229231 nFolds = self .getOrDefault (self .numFolds )
232+ seed = self .getOrDefault (self .seed )
230233 h = 1.0 / nFolds
231234 randCol = self .uid + "_rand"
232- df = dataset .select ("*" , rand (self . getOrDefault ( self . seed ) ).alias (randCol ))
235+ df = dataset .select ("*" , rand (seed ).alias (randCol ))
233236 metrics = np .zeros (numModels )
234237 for i in range (nFolds ):
235238 validateLB = i * h
0 commit comments