Skip to content

Commit d5625a6

Browse files
committed
init pr
1 parent 20adf9a commit d5625a6

File tree

3 files changed

+42
-39
lines changed

3 files changed

+42
-39
lines changed

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
package org.apache.spark.ml
1919

2020
import scala.annotation.varargs
21+
import scala.concurrent.{ExecutionContext, Future}
22+
import scala.concurrent.duration.Duration
2123

2224
import org.apache.spark.annotation.{DeveloperApi, Since}
2325
import org.apache.spark.ml.param.{ParamMap, ParamPair}
2426
import org.apache.spark.sql.Dataset
27+
import org.apache.spark.util.ThreadUtils
2528

2629
/**
2730
* :: DeveloperApi ::
@@ -82,5 +85,32 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
8285
paramMaps.map(fit(dataset, _))
8386
}
8487

88+
/**
 * Fits one model per entry of `paramMaps`, training them concurrently, and hands each
 * fitted model to `modelCallback` as soon as it is available.
 *
 * Every fit is started right away as a `Future` on `executionContext`. This method then
 * blocks until every callback has finished, so any state the callback writes is complete
 * when the method returns.
 *
 * @param dataset input dataset used for every fit
 * @param paramMaps parameter maps to fit; one model is trained per element
 * @param unpersistDatasetAfterFitting if true, `dataset.unpersist()` is triggered once the
 *                                     combined fit future completes. That call is scheduled
 *                                     asynchronously and is not awaited by this method.
 * @param executionContext context on which the fits, the callbacks and the unpersist
 *                         hook are run
 * @param modelCallback invoked as `(model, paramMap, index)` where `index` is the position
 *                      of `paramMap` inside `paramMaps`; it may be called from several
 *                      threads at once, depending on `executionContext`
 */
@Since("2.3.0")
def fit(
    dataset: Dataset[_],
    paramMaps: Array[ParamMap],
    unpersistDatasetAfterFitting: Boolean,
    executionContext: ExecutionContext,
    modelCallback: (Model[_], ParamMap, Int) => Unit): Unit = {
  // Kick off all fits up front so they can train in parallel. No cast is needed here:
  // M <: Model[M] already conforms to Model[_].
  val modelFutures = paramMaps.map { paramMap =>
    Future[Model[_]] {
      fit(dataset, paramMap)
    }(executionContext)
  }

  if (unpersistDatasetAfterFitting) {
    // Release the training data once the combined future completes, whether it
    // succeeded or failed.
    Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
      .onComplete { _ => dataset.unpersist() }(executionContext)
  }

  // Chain the callback onto each fit so a model can be consumed (and later garbage
  // collected) without waiting for the slower fits to finish.
  val modelCallbackFutures = modelFutures.zip(paramMaps).zipWithIndex.map {
    case ((modelFuture, paramMap), paramMapIndex) =>
      modelFuture.map { model =>
        modelCallback(model, paramMap, paramMapIndex)
      }(executionContext)
  }

  // Block until every callback is done. `foreach` rather than `map`: the results are
  // Unit, so building an Array[Unit] only to discard it is wasted work.
  modelCallbackFutures.foreach(ThreadUtils.awaitResult(_, Duration.Inf))
}
114+
85115
override def copy(extra: ParamMap): Estimator[M]
86116
}

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,30 +124,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
124124
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
125125
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
126126

127-
// Fit models in a Future for training in parallel
128-
val modelFutures = epm.map { paramMap =>
129-
Future[Model[_]] {
130-
val model = est.fit(trainingDataset, paramMap)
131-
model.asInstanceOf[Model[_]]
132-
} (executionContext)
133-
}
134-
135-
// Unpersist training data only when all models have trained
136-
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
137-
.onComplete { _ => trainingDataset.unpersist() } (executionContext)
138-
139-
// Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
140-
val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
141-
modelFuture.map { model =>
127+
val foldMetrics = new Array[Double](epm.length)
128+
est.fit(trainingDataset, epm, true, executionContext,
129+
(model: Model[_], paramMap: ParamMap, paramMapIndex: Int) => {
142130
// TODO: duplicate evaluator to take extra params from input
143131
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
144132
logDebug(s"Got metric $metric for model trained with $paramMap.")
145-
metric
146-
} (executionContext)
147-
}
133+
foldMetrics(paramMapIndex) = metric
134+
}
135+
)
148136

149-
// Wait for metrics to be calculated before unpersisting validation dataset
150-
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
151137
validationDataset.unpersist()
152138
foldMetrics
153139
}.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,29 +123,16 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
123123

124124
// Fit models in a Future for training in parallel
125125
logDebug(s"Train split with multiple sets of parameters.")
126-
val modelFutures = epm.map { paramMap =>
127-
Future[Model[_]] {
128-
val model = est.fit(trainingDataset, paramMap)
129-
model.asInstanceOf[Model[_]]
130-
} (executionContext)
131-
}
132-
133-
// Unpersist training data only when all models have trained
134-
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
135-
.onComplete { _ => trainingDataset.unpersist() } (executionContext)
136126

137-
// Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
138-
val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
139-
modelFuture.map { model =>
127+
val metrics = new Array[Double](epm.length)
128+
est.fit(trainingDataset, epm, true, executionContext,
129+
(model: Model[_], paramMap: ParamMap, paramMapIndex: Int) => {
140130
// TODO: duplicate evaluator to take extra params from input
141131
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
142132
logDebug(s"Got metric $metric for model trained with $paramMap.")
143-
metric
144-
} (executionContext)
145-
}
146-
147-
// Wait for all metrics to be calculated
148-
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
133+
metrics(paramMapIndex) = metric
134+
}
135+
)
149136

150137
// Unpersist validation set once all metrics have been produced
151138
validationDataset.unpersist()

0 commit comments

Comments
 (0)