@@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
 
     // Fit models in a Future for training in parallel
     logDebug(s"Train split with multiple sets of parameters.")
-    val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
-      Future[Model[_]] {
+    val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
+      Future[Double] {
         val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
 
         if (collectSubModelsParam) {
           subModels.get(paramIndex) = model
         }
-        model
-      } (executionContext)
-    }
-
-    // Unpersist training data only when all models have trained
-    Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
-      .onComplete { _ => trainingDataset.unpersist() } (executionContext)
-
-    // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
-    val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
-      modelFuture.map { model =>
         // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(model.transform(validationDataset, paramMap))
         logDebug(s"Got metric $metric for model trained with $paramMap.")
@@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     // Wait for all metrics to be calculated
     val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
 
-    // Unpersist validation set once all metrics have been produced
+    // Unpersist training & validation set once all metrics have been produced
+    trainingDataset.unpersist()
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")
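
The net effect of the diff is that training and evaluation for each parameter map now run inside a single `Future[Double]` that returns the metric directly, instead of a `Future[Model[_]]` chained to a second future plus a `Future.sequence` callback for cleanup; the cached training data can then be unpersisted synchronously once all metrics have been awaited. Below is a minimal, self-contained Scala sketch of that pattern only, not the Spark implementation: `params`, `fitModel`, and `evaluate` are hypothetical stand-ins for Spark's `epm`, `est.fit`, and `eval.evaluate`.

```scala
// Sketch of the fused fit-and-evaluate pattern adopted by this commit (assumed names,
// not Spark code): each unit of work trains *and* scores inside one Future[Double],
// and shared resources are released once, synchronously, after all metrics are awaited.
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object FusedFitAndEval {
  def main(args: Array[String]): Unit = {
    val executionContext =
      ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(4))

    val params: Seq[Double] = Seq(0.01, 0.1, 1.0)               // stand-in for epm
    def fitModel(p: Double): Double = p * 2.0                    // stand-in for est.fit(...)
    def evaluate(model: Double): Double = 1.0 / (1.0 + model)    // stand-in for eval.evaluate(...)

    // One future per parameter setting that yields the metric directly, so no
    // reference to the trained model needs to outlive its own future.
    val metricFutures = params.map { p =>
      Future[Double] {
        val model = fitModel(p)
        evaluate(model)
      }(executionContext)
    }

    // Wait for all metrics, then clean up shared state in one place
    // (the equivalent of trainingDataset.unpersist() / validationDataset.unpersist()).
    val metrics = metricFutures.map(Await.result(_, Duration.Inf))
    println(s"Metrics: $metrics")

    executionContext.shutdown()
  }
}
```

Returning the metric rather than the model from each future keeps model references scoped to their own futures, and it turns dataset cleanup into a single synchronous step after the awaits instead of an `onComplete` callback on a sequenced future.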