Commit 6a9b735

added note about parallelism capped by Scala collection thread pool, adjusted comments
1 parent 8126710 commit 6a9b735

3 files changed: 9 additions and 3 deletions

docs/ml-tuning.md

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ method in each of these evaluators.
 
 To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility.
 Sets of parameters from the parameter grid can be evaluated in parallel by setting `numParallelEval` with a value of 2 or more (a value of 1 will evaluate in serial) before running model selection with `CrossValidator` or `TrainValidationSplit`.
-The value of `numParallelEval` should be chosen carefully to maximize concurrency without exceeding cluster resources. Generally speaking, a value up to 10 should be sufficient for most clusters.
+The value of `numParallelEval` should be chosen carefully to maximize concurrency without exceeding cluster resources, and will be capped at the number of cores in the driver system. Generally speaking, a value up to 10 should be sufficient for most clusters.
 
 
 # Cross-Validation
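
The cap described in the changed documentation line can be illustrated with a small sketch. It assumes the effective parallelism is simply the smaller of the requested value and the driver's core count; `EffectiveParallelism` and `effective` are hypothetical names used only for illustration and are not part of Spark.

```scala
object EffectiveParallelism {
  // Hypothetical helper (not a Spark API): estimates how many parameter sets
  // could actually be evaluated at once, assuming the cap equals the number
  // of cores visible to the driver JVM.
  def effective(
      requested: Int,
      cores: Int = Runtime.getRuntime.availableProcessors()): Int = {
    require(requested >= 1, "requested parallelism must be at least 1")
    require(cores >= 1, "core count must be at least 1")
    math.min(requested, cores)
  }
}
```

Under this assumption, requesting 10 on a 4-core driver would behave like requesting 4, and a value of 1 stays serial regardless of core count.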

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 3 additions & 1 deletion
@@ -105,6 +105,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
+    // Barrier to limit parallelism during model fit/evaluation
+    // NOTE: will be capped by size of thread pool used in Scala parallel collections, which is
+    // number of cores in the system by default
     val numParBarrier = new Semaphore($(numParallelEval))
 
     val instr = Instrumentation.create(this, dataset)
@@ -118,7 +121,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
-      // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
 
       // Fit models concurrently, limited by a barrier with '$numParallelEval' permits

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 5 additions & 1 deletion
@@ -99,6 +99,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val est = $(estimator)
     val eval = $(evaluator)
     val epm = $(estimatorParamMaps)
+    // Barrier to limit parallelism during model fit/evaluation
+    // NOTE: will be capped by size of thread pool used in Scala parallel collections, which is
+    // number of cores in the system by default
     val numParBarrier = new Semaphore($(numParallelEval))
     logDebug(s"Running validation with level of parallelism: ${numParBarrier.availablePermits()}.")
 
@@ -111,15 +114,16 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     trainingDataset.cache()
     validationDataset.cache()
 
-    logDebug(s"Train split with multiple sets of parameters.")
     // Fit models concurrently, limited by a barrier with '$numParallelEval' permits
+    logDebug(s"Train split with multiple sets of parameters.")
     val models = epm.par.map { paramMap =>
       numParBarrier.acquire()
       val model = est.fit(trainingDataset, paramMap)
       numParBarrier.release()
       model.asInstanceOf[Model[_]]
     }.seq
     trainingDataset.unpersist()
+
     // Evaluate models concurrently, limited by a barrier with '$numParallelEval' permits
     val metrics = models.zip(epm).par.map { case (model, paramMap) =>
       numParBarrier.acquire()
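
The NOTE added in this commit says the semaphore's permit count is not the only limit: the thread pool running the tasks caps concurrency as well. The sketch below is one way to observe that interaction. It is not the code from this commit: it swaps Scala parallel collections for `Future`s on a fixed-size pool so that it runs without extra dependencies, and `ParallelismCapSketch`, `peakConcurrency`, and the 50 ms sleep standing in for a model fit are all assumptions made for illustration. It also releases the permit in a `finally` block so a failing task cannot leak a permit.

```scala
import java.util.concurrent.{Executors, Semaphore}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object ParallelismCapSketch {
  // Hypothetical helper: runs `numTasks` dummy tasks and returns the highest
  // number observed running at the same time. Concurrency is limited both by
  // the semaphore (`permits`) and by the thread pool (`poolSize`).
  def peakConcurrency(numTasks: Int, permits: Int, poolSize: Int): Int = {
    val pool = Executors.newFixedThreadPool(poolSize)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    val barrier = new Semaphore(permits)
    val lock = new Object
    var running = 0
    var peak = 0
    try {
      val futures = (1 to numTasks).map { _ =>
        Future {
          barrier.acquire()
          try {
            lock.synchronized {
              running += 1
              if (running > peak) peak = running
            }
            Thread.sleep(50) // stand-in for fitting or evaluating one model
            lock.synchronized { running -= 1 }
          } finally {
            barrier.release() // always return the permit, even on failure
          }
        }
      }
      Await.result(Future.sequence(futures), 30.seconds)
      lock.synchronized { peak }
    } finally {
      pool.shutdown()
    }
  }
}
```

With 4 permits and a pool of 2 threads, the observed peak cannot exceed 2, which mirrors the situation the NOTE describes: raising the permit count beyond the pool size gains nothing.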
