From 480a5d081fe34bea42dc3c474a57a811947f483f Mon Sep 17 00:00:00 2001 From: Martin MENESTRET Date: Thu, 6 Aug 2015 18:28:48 +0200 Subject: [PATCH] SPARK-9690 Adding the possibility to set the seed of the rand in the CrossValidator fold --- python/pyspark/ml/tuning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0bf988fd72f14..584efa7d35fe7 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -121,7 +121,7 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=0): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) """ @@ -136,6 +136,8 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self, "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") #: param for number of folds for cross validation + self._setDefault(seed=0) + self.seed = Param(self, "seed", "seed value used for k-fold") self.numFolds = Param(self, "numFolds", "number of folds for cross validation") self._setDefault(numFolds=3) kwargs = self.__init__._input_kwargs @@ -210,7 +212,7 @@ def _fit(self, dataset): nFolds = self.getOrDefault(self.numFolds) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(self.getOrDefault(self.seed)).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h