diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 1247882d6c1bd..40d9fe2fb1d76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -18,10 +18,14 @@ package org.apache.spark.ml import scala.annotation.varargs +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.Dataset +import org.apache.spark.util.ThreadUtils /** * :: DeveloperApi :: @@ -82,5 +86,49 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { paramMaps.map(fit(dataset, _)) } + /** + * (Java-specific) + */ + @Since("2.3.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap], + unpersistDatasetAfterFitting: Boolean, executionContext: ExecutionContext, + modelCallback: VoidFunction2[Model[_], Int]): Unit = { + // Fit models in a Future for training in parallel + val modelFutures = paramMaps.map { paramMap => + Future[Model[_]] { + fit(dataset, paramMap).asInstanceOf[Model[_]] + } (executionContext) + } + + if (unpersistDatasetAfterFitting) { + // Unpersist training data only when all models have trained + Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) + .onComplete { _ => dataset.unpersist() } (executionContext) + } + + val modelCallbackFutures = modelFutures.zipWithIndex.map { + case (modelFuture, paramMapIndex) => + modelFuture.map { model => + modelCallback.call(model, paramMapIndex) + } (executionContext) + } + modelCallbackFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + } + + /** + * (Scala-specific) + */ + @Since("2.3.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap], + unpersistDatasetAfterFitting: Boolean, executionContext: ExecutionContext, + modelCallback: (Model[_], Int) => Unit): Unit = { + fit(dataset, paramMaps, unpersistDatasetAfterFitting, executionContext, + new VoidFunction2[Model[_], Int] { + override def call(model: Model[_], paramMapIndex: Int): Unit = { + modelCallback(model, paramMapIndex) + } + }) + } + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 1682ca91bf832..730eacf1668ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -146,34 +146,20 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() logDebug(s"Train split $splitIndex with multiple sets of parameters.") - // Fit models in a Future for training in parallel - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] - + val foldMetrics = new Array[Double](epm.length) + est.fit(trainingDataset, epm, true, executionContext, + (model: Model[_], paramMapIndex: Int) => { + val paramMap = epm(paramMapIndex) if (collectSubModelsParam) { - subModels.get(splitIndex)(paramIndex) = model + subModels.get(splitIndex)(paramMapIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") - metric - } (executionContext) - } + foldMetrics(paramMapIndex) = metric + } + ) - // Wait for metrics to be calculated before unpersisting validation dataset - val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) validationDataset.unpersist() foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..424e941da2221 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,33 +143,20 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + val metrics = new Array[Double](epm.length) + est.fit(trainingDataset, epm, true, executionContext, + (model: Model[_], paramMapIndex: Int) => { + val paramMap = epm(paramMapIndex) if (collectSubModelsParam) { - subModels.get(paramIndex) = model + subModels.get(paramMapIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") - metric - } (executionContext) - } - - // Wait for all metrics to be calculated - val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + metrics(paramMapIndex) = metric + } + ) // Unpersist validation set once all metrics have been produced validationDataset.unpersist()