Closed
@@ -91,6 +91,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Optional parameter. If set, all models trained during cross validation will be
* saved under the specified path. By default the models are not preserved.
*
* @group expertSetParam
*/
@Since("2.3.0")
def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
@@ -113,15 +122,28 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
trainingDataset.unpersist()

var i = 0
while (i < numModels) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
if (isDefined(modelPreservePath)) {
models(i) match {
case w: MLWritable =>
// e.g. maxIter-5-regParam-0.001-split0-0.859
val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
.mkString("-") + s"-split$splitIndex-${math.rint(metric * 1000) / 1000}"
w.save(new Path($(modelPreservePath), fileName).toString)
case _ =>
// for third-party algorithms
logWarning(models(i).uid + " did not implement MLWritable. Serialization omitted.")
}
}
metrics(i) += metric
Contributor

ping @hhbyyh @jkbradley
I am sorry, but I think this code is incorrect. Saving the list of fitted models inside the fit method does not help.
What we need is for CrossValidatorModel/TrainValidationSplitModel to preserve the full list of fitted models. Then, when saving a CrossValidatorModel/TrainValidationSplitModel, we can either save the list of models or choose not to, controlled by a parameter.

Contributor Author

So you want to keep all the trained models in CrossValidatorModel?

Contributor

Yes I think so.
In order to save time, I would like to take over this feature, if you don't mind.
ping @jkbradley

Contributor Author

@hhbyyh hhbyyh Jul 24, 2017

That's interesting...
If we follow your suggestion, then with numFolds = 10 and a param grid of size 8 you'll be holding 80 models in driver memory (in CrossValidatorModel) at the same time. I would be very surprised to see anyone go along with your suggestion.
I'll happily close the PR if your suggestion turns out to be good.

Contributor

@hhbyyh So we would need to set the parameter's default value to false, so that only users who really need this turn the feature on. I will discuss it with @jkbradley later.

Member

I agree that 80 models in driver memory sounds like a lot. However, we already are holding that many in driver memory at once in val models = est.fit(trainingDataset, epm), so that should not be a problem for current use cases.

Scaling to large models which do not fit in memory is a different problem, but your PR does bring up the issue that exposing something like models: Seq[...] could cause problems in the future if we want to scale more. I'd suggest 2 things:

  • The models could be exposed via a getter, rather than a val. In the future, if the models are not available, the getter could throw a nice exception.
  • In the future, we could add the Param which you are suggesting for dumping the models to some directory during training. Feel free to preserve this PR for that, but I think this PR is overkill for most users' needs.
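
The first suggestion above, exposing the models through a getter rather than a val, could look roughly like the following. This is a minimal sketch in plain Scala, not Spark code: TunedModel, hasFittedModels and fittedModels are hypothetical names chosen for illustration, and the actual CrossValidatorModel API may differ.

```scala
// Hypothetical stand-in for a tuning result; this is NOT Spark's CrossValidatorModel.
final class TunedModel[M](val bestModel: M, private val fitted: Option[Seq[M]]) {

  /** True when the per-parameter-map models were retained at fit time. */
  def hasFittedModels: Boolean = fitted.isDefined

  /** Returns the retained models, or fails with an explanatory message. */
  def fittedModels: Seq[M] = fitted.getOrElse(
    throw new NoSuchElementException(
      "Fitted models are not available; they were not retained when this model was fitted."))
}

object TunedModelDemo {
  def main(args: Array[String]): Unit = {
    // Models were retained: the getter simply returns them.
    val kept = new TunedModel("best", Some(Seq("m0", "m1", "best")))
    println(kept.fittedModels.size)

    // Models were not retained: the getter raises a readable error.
    val dropped = new TunedModel[String]("best", None)
    try {
      dropped.fittedModels
      ()
    } catch {
      case e: NoSuchElementException => println(e.getMessage)
    }
  }
}
```

Keeping the storage private and optional leaves room to stop holding the models in memory later without changing the public signature.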

i += 1
}
trainingDataset.unpersist()
validationDataset.unpersist()
}
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
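
For reference, the file-name scheme used in the MLWritable branch of this hunk (sorted "name-value" pairs, then the split index, then the metric rounded to three decimals) can be reproduced outside Spark. The sketch below is an illustration only: it assumes a plain Map[String, Any] as a stand-in for a ParamMap, and ModelFileName, round3 and fileName are hypothetical names, not part of the PR.

```scala
object ModelFileName {
  // Round to three decimal places, mirroring math.rint(metric * 1000) / 1000 in the diff.
  def round3(metric: Double): Double = math.rint(metric * 1000) / 1000

  // `params` is a hypothetical plain-Scala stand-in for a Spark ParamMap.
  def fileName(params: Map[String, Any], splitIndex: Int, metric: Double): String = {
    // Sort the "name-value" pairs so the name does not depend on map iteration order.
    val paramPart = params.toSeq
      .map { case (name, value) => s"$name-$value" }
      .sorted
      .mkString("-")
    s"$paramPart-split$splitIndex-${round3(metric)}"
  }

  def main(args: Array[String]): Unit = {
    // prints "maxIter-5-regParam-0.001-split0-0.859"
    println(fileName(Map("regParam" -> 0.001, "maxIter" -> 5), splitIndex = 0, metric = 0.85912))
  }
}
```

Because the split index is part of the name, the same parameter combination fitted on different folds maps to distinct paths.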
@@ -87,6 +87,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Optional parameter. If set, all models fitted during training will be saved
* under the specified directory path. By default the models are not saved.
*
* @group expertSetParam
*/
@Since("2.3.0")
def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
@@ -109,15 +118,27 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
trainingDataset.unpersist()

var i = 0
while (i < numModels) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
if (isDefined(modelPreservePath)) {
models(i) match {
case w: MLWritable =>
// e.g. maxIter-5-regParam-0.001-0.859
val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
.mkString("-") + s"-${math.rint(metric * 1000) / 1000}"
w.save(new Path($(modelPreservePath), fileName).toString)
case _ =>
logWarning(models(i).uid + " did not implement MLWritable. Serialization omitted.")
}
}
metrics(i) += metric
i += 1
}
trainingDataset.unpersist()
validationDataset.unpersist()

logInfo(s"Train validation split metrics: ${metrics.toSeq}")
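
Both fit methods in this PR save a model only when it matches MLWritable and log a warning otherwise. A rough, self-contained sketch of that pattern follows. Fitted and Savable are hypothetical traits standing in for Spark's Model and MLWritable, and trySave is an illustrative helper that does not appear in the PR.

```scala
// Hypothetical traits standing in for Spark's Model and MLWritable.
trait Fitted { def uid: String }
trait Savable { def save(path: String): Unit }

object SaveIfPossible {
  // Returns true when the model was saved, false when it has no save support.
  def trySave(model: Fitted, path: String): Boolean = model match {
    case s: Savable =>
      s.save(path)
      true
    case _ =>
      // Mirrors the logWarning branch: skip the model instead of failing the whole fit.
      Console.err.println(s"${model.uid} does not support saving; skipped.")
      false
  }

  def main(args: Array[String]): Unit = {
    val writable = new Fitted with Savable {
      val uid = "lr_1"
      def save(path: String): Unit = println(s"saved $uid to $path")
    }
    val thirdParty = new Fitted { val uid = "custom_1" }
    println(trySave(writable, "/tmp/models/lr_1"))
    println(trySave(thirdParty, "/tmp/models/custom_1"))
  }
}
```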
@@ -22,6 +22,7 @@ import org.json4s.{DefaultFormats, _}
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
@@ -67,6 +68,21 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)


/**
* Optional parameter. If set, all models trained during the tuning grid search will be
* saved under the specified path. By default the models are not preserved.
*
* @group expertParam
*/
val modelPreservePath: Param[String] = new Param(this, "modelPath",
"Optional parameter. If set, all the models fitted during the cross validation will be" +
" saved in the path")

/** @group expertGetParam */
@Since("2.3.0")
def getModelPreservePath: String = $(modelPreservePath)

protected def transformSchemaImpl(schema: StructType): StructType = {
require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
val firstEstimatorParamMap = $(estimatorParamMaps).head
@@ -23,14 +23,15 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -56,6 +57,7 @@ class CrossValidatorSuite
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
assert(!cv.isDefined(cv.modelPreservePath))
val cvModel = cv.fit(dataset)

MLTestingUtils.checkCopyAndUids(cv, cvModel)
@@ -242,6 +244,29 @@ class CrossValidatorSuite
}
}

test("cross validation with model path to save trained models") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 5))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
.setModelPreservePath(path)
try {
cv.fit(dataset)
assert(tempDir.list().length === 3 * 2 * 2)
} finally {
Utils.deleteRecursively(tempDir)
}
}
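
The expected count 3 * 2 * 2 in the test above is the number of folds times the size of the parameter grid, assuming each saved model produces exactly one entry under the target directory. The following sketch reproduces that count in plain Scala. GridCount and grid are hypothetical names, and the Cartesian product here only imitates what ParamGridBuilder is assumed to produce for two parameters with two values each.

```scala
object GridCount {
  // Cartesian product of per-parameter value lists; each result is one parameter combination.
  def grid(axes: Seq[Seq[Any]]): Seq[Seq[Any]] =
    axes.foldLeft(Seq(Seq.empty[Any])) { (acc, values) =>
      for (combo <- acc; v <- values) yield combo :+ v
    }

  def main(args: Array[String]): Unit = {
    // Same values as the test: regParam in {0.001, 1000.0}, maxIter in {0, 5}.
    val combos = grid(Seq(Seq(0.001, 1000.0), Seq(0, 5)))
    val numFolds = 3
    // One model per (fold, combination) pair.
    println(combos.size * numFolds) // prints 12
  }
}
```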

test("read/write: CrossValidatorModel") {
val lr = new LogisticRegression()
.setThreshold(0.6)
@@ -30,6 +30,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class TrainValidationSplitSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -53,6 +54,7 @@ class TrainValidationSplitSuite
.setSeed(42L)
val tvsModel = tvs.fit(dataset)
val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(!tvs.isDefined(tvs.modelPreservePath))
assert(tvs.getTrainRatio === 0.5)
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
@@ -117,6 +119,32 @@ class TrainValidationSplitSuite
}
}

test("train validation with modelPath to save trained models") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 10))
.build()
val eval = new BinaryClassificationEvaluator
val tvs = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
.setSeed(42L)
.setModelPreservePath(path)
try {
tvs.fit(dataset)
assert(tempDir.list().length === 2 * 2)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("read/write: TrainValidationSplit") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()