Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
} (executionContext)
}

// Wait for metrics to be calculated before unpersisting validation dataset
// Wait for metrics to be calculated
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

// Unpersist training & validation set once all metrics have been produced
trainingDataset.unpersist()
validationDataset.unpersist()
foldMetrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St

// Fit models in a Future for training in parallel
logDebug(s"Train split with multiple sets of parameters.")
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Model[_]] {
val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]

if (collectSubModelsParam) {
subModels.get(paramIndex) = model
}
model
} (executionContext)
}

// Unpersist training data only when all models have trained
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
.onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calulate a metric and allow model to be cleaned up
val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
Expand All @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
// Wait for all metrics to be calculated
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

// Unpersist validation set once all metrics have been produced
// Unpersist training & validation set once all metrics have been produced
trainingDataset.unpersist()
validationDataset.unpersist()

logInfo(s"Train validation split metrics: ${metrics.toSeq}")
Expand Down