Skip to content

Commit 6aa46f8

Browse files
committed
Apply CrossValidator approach to Driver/Distributed memory tradeoff for
TrainValidationSplit.
1 parent 9a2b65a commit 6aa46f8

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
160160
} (executionContext)
161161
}
162162

163-
// Wait for metrics to be calculated before unpersisting validation dataset
163+
// Wait for metrics to be calculated
164164
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
165+
166+
// Unpersist training & validation set once all metrics have been produced
165167
trainingDataset.unpersist()
166168
validationDataset.unpersist()
167169
foldMetrics

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
143143

144144
// Fit models in a Future for training in parallel
145145
logDebug(s"Train split with multiple sets of parameters.")
146-
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
147-
Future[Model[_]] {
146+
val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
147+
Future[Double] {
148148
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
149149

150150
if (collectSubModelsParam) {
151151
subModels.get(paramIndex) = model
152152
}
153-
model
154-
} (executionContext)
155-
}
156-
157-
// Unpersist training data only when all models have trained
158-
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
159-
.onComplete { _ => trainingDataset.unpersist() } (executionContext)
160-
161-
// Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
162-
val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
163-
modelFuture.map { model =>
164153
// TODO: duplicate evaluator to take extra params from input
165154
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
166155
logDebug(s"Got metric $metric for model trained with $paramMap.")
@@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
171160
// Wait for all metrics to be calculated
172161
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
173162

174-
// Unpersist validation set once all metrics have been produced
163+
// Unpersist training & validation set once all metrics have been produced
164+
trainingDataset.unpersist()
175165
validationDataset.unpersist()
176166

177167
logInfo(s"Train validation split metrics: ${metrics.toSeq}")

0 commit comments

Comments
 (0)