1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616#
17-
1817import itertools
1918import numpy as np
19+ from multiprocessing .pool import ThreadPool
2020
2121from pyspark import since , keyword_only
2222from pyspark .ml import Estimator , Model
2323from pyspark .ml .common import _py2java
2424from pyspark .ml .param import Params , Param , TypeConverters
25- from pyspark .ml .param .shared import HasSeed
25+ from pyspark .ml .param .shared import HasParallelism , HasSeed
2626from pyspark .ml .util import *
2727from pyspark .ml .wrapper import JavaParams
2828from pyspark .sql .functions import rand
@@ -170,7 +170,7 @@ def _to_java_impl(self):
170170 return java_estimator , java_epms , java_evaluator
171171
172172
173- class CrossValidator (Estimator , ValidatorParams , MLReadable , MLWritable ):
173+ class CrossValidator (Estimator , ValidatorParams , HasParallelism , MLReadable , MLWritable ):
174174 """
175175
176176 K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -193,7 +193,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
193193 >>> lr = LogisticRegression()
194194 >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
195195 >>> evaluator = BinaryClassificationEvaluator()
196- >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
196+ >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
197+ ... parallelism=2)
197198 >>> cvModel = cv.fit(dataset)
198199 >>> cvModel.avgMetrics[0]
199200 0.5
@@ -208,23 +209,23 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
208209
209210 @keyword_only
210211 def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
211- seed = None ):
212+ seed = None , parallelism = 1 ):
212213 """
213214 __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
214- seed=None)
215+ seed=None, parallelism=1 )
215216 """
216217 super (CrossValidator , self ).__init__ ()
217- self ._setDefault (numFolds = 3 )
218+ self ._setDefault (numFolds = 3 , parallelism = 1 )
218219 kwargs = self ._input_kwargs
219220 self ._set (** kwargs )
220221
221222 @keyword_only
222223 @since ("1.4.0" )
223224 def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , numFolds = 3 ,
224- seed = None ):
225+ seed = None , parallelism = 1 ):
225226 """
226227 setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
227- seed=None):
228+ seed=None, parallelism=1 ):
228229 Sets params for cross validator.
229230 """
230231 kwargs = self ._input_kwargs
@@ -255,18 +256,27 @@ def _fit(self, dataset):
255256 randCol = self .uid + "_rand"
256257 df = dataset .select ("*" , rand (seed ).alias (randCol ))
257258 metrics = [0.0 ] * numModels
259+
260+ pool = ThreadPool (processes = min (self .getParallelism (), numModels ))
261+
258262 for i in range (nFolds ):
259263 validateLB = i * h
260264 validateUB = (i + 1 ) * h
261265 condition = (df [randCol ] >= validateLB ) & (df [randCol ] < validateUB )
262- validation = df .filter (condition )
263- train = df .filter (~ condition )
264- models = est . fit ( train , epm )
265- for j in range ( numModels ):
266- model = models [ j ]
266+ validation = df .filter (condition ). cache ()
267+ train = df .filter (~ condition ). cache ()
268+
269+ def singleTrain ( paramMap ):
270+ model = est . fit ( train , paramMap )
267271 # TODO: duplicate evaluator to take extra params from input
268- metric = eva .evaluate (model .transform (validation , epm [j ]))
269- metrics [j ] += metric / nFolds
272+ metric = eva .evaluate (model .transform (validation , paramMap ))
273+ return metric
274+
275+ currentFoldMetrics = pool .map (singleTrain , epm )
276+ for j in range (numModels ):
277+ metrics [j ] += (currentFoldMetrics [j ] / nFolds )
278+ validation .unpersist ()
279+ train .unpersist ()
270280
271281 if eva .isLargerBetter ():
272282 bestIndex = np .argmax (metrics )
@@ -316,9 +326,10 @@ def _from_java(cls, java_stage):
316326 estimator , epms , evaluator = super (CrossValidator , cls )._from_java_impl (java_stage )
317327 numFolds = java_stage .getNumFolds ()
318328 seed = java_stage .getSeed ()
329+ parallelism = java_stage .getParallelism ()
319330 # Create a new instance of this stage.
320331 py_stage = cls (estimator = estimator , estimatorParamMaps = epms , evaluator = evaluator ,
321- numFolds = numFolds , seed = seed )
332+ numFolds = numFolds , seed = seed , parallelism = parallelism )
322333 py_stage ._resetUid (java_stage .uid ())
323334 return py_stage
324335
@@ -337,6 +348,7 @@ def _to_java(self):
337348 _java_obj .setEstimator (estimator )
338349 _java_obj .setSeed (self .getSeed ())
339350 _java_obj .setNumFolds (self .getNumFolds ())
351+ _java_obj .setParallelism (self .getParallelism ())
340352
341353 return _java_obj
342354
@@ -427,7 +439,7 @@ def _to_java(self):
427439 return _java_obj
428440
429441
430- class TrainValidationSplit (Estimator , ValidatorParams , MLReadable , MLWritable ):
442+ class TrainValidationSplit (Estimator , ValidatorParams , HasParallelism , MLReadable , MLWritable ):
431443 """
432444 .. note:: Experimental
433445
@@ -448,7 +460,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
448460 >>> lr = LogisticRegression()
449461 >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
450462 >>> evaluator = BinaryClassificationEvaluator()
451- >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
463+ >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
464+ ... parallelism=2)
452465 >>> tvsModel = tvs.fit(dataset)
453466 >>> evaluator.evaluate(tvsModel.transform(dataset))
454467 0.8333...
@@ -461,23 +474,23 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
461474
462475 @keyword_only
463476 def __init__ (self , estimator = None , estimatorParamMaps = None , evaluator = None , trainRatio = 0.75 ,
464- seed = None ):
477+ parallelism = 1 , seed = None ):
465478 """
466479 __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
467- seed=None)
480+ parallelism=1, seed=None)
468481 """
469482 super (TrainValidationSplit , self ).__init__ ()
470- self ._setDefault (trainRatio = 0.75 )
483+ self ._setDefault (trainRatio = 0.75 , parallelism = 1 )
471484 kwargs = self ._input_kwargs
472485 self ._set (** kwargs )
473486
474487 @since ("2.0.0" )
475488 @keyword_only
476489 def setParams (self , estimator = None , estimatorParamMaps = None , evaluator = None , trainRatio = 0.75 ,
477- seed = None ):
490+ parallelism = 1 , seed = None ):
478491 """
479492 setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
480- seed=None):
493+ parallelism=1, seed=None):
481494 Sets params for the train validation split.
482495 """
483496 kwargs = self ._input_kwargs
@@ -506,15 +519,20 @@ def _fit(self, dataset):
506519 seed = self .getOrDefault (self .seed )
507520 randCol = self .uid + "_rand"
508521 df = dataset .select ("*" , rand (seed ).alias (randCol ))
509- metrics = [0.0 ] * numModels
510522 condition = (df [randCol ] >= tRatio )
511- validation = df .filter (condition )
512- train = df .filter (~ condition )
513- models = est .fit (train , epm )
514- for j in range (numModels ):
515- model = models [j ]
516- metric = eva .evaluate (model .transform (validation , epm [j ]))
517- metrics [j ] += metric
523+ validation = df .filter (condition ).cache ()
524+ train = df .filter (~ condition ).cache ()
525+
526+ def singleTrain (paramMap ):
527+ model = est .fit (train , paramMap )
528+ metric = eva .evaluate (model .transform (validation , paramMap ))
529+ return metric
530+
531+ pool = ThreadPool (processes = min (self .getParallelism (), numModels ))
532+ metrics = pool .map (singleTrain , epm )
533+ train .unpersist ()
534+ validation .unpersist ()
535+
518536 if eva .isLargerBetter ():
519537 bestIndex = np .argmax (metrics )
520538 else :
@@ -563,9 +581,10 @@ def _from_java(cls, java_stage):
563581 estimator , epms , evaluator = super (TrainValidationSplit , cls )._from_java_impl (java_stage )
564582 trainRatio = java_stage .getTrainRatio ()
565583 seed = java_stage .getSeed ()
584+ parallelism = java_stage .getParallelism ()
566585 # Create a new instance of this stage.
567586 py_stage = cls (estimator = estimator , estimatorParamMaps = epms , evaluator = evaluator ,
568- trainRatio = trainRatio , seed = seed )
587+ trainRatio = trainRatio , seed = seed , parallelism = parallelism )
569588 py_stage ._resetUid (java_stage .uid ())
570589 return py_stage
571590
@@ -584,6 +603,7 @@ def _to_java(self):
584603 _java_obj .setEstimator (estimator )
585604 _java_obj .setTrainRatio (self .getTrainRatio ())
586605 _java_obj .setSeed (self .getSeed ())
606+ _java_obj .setParallelism (self .getParallelism ())
587607
588608 return _java_obj
589609
0 commit comments