Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
new CrossValidator(metadata.uid)
val cv = new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setNumFolds(numFolds)
.setSeed(seed)
DefaultParamsReader.getAndSetParams(cv, metadata,
skipParams = Option(List("estimatorParamMaps")))
cv
}
}
}
Expand Down Expand Up @@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray

val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.numFolds, numFolds)
.set(model.seed, seed)
DefaultParamsReader.getAndSetParams(model, metadata,
skipParams = Option(List("estimatorParamMaps")))
model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
import java.util.{List => JList}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
new TrainValidationSplit(metadata.uid)
val tvs = new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setTrainRatio(trainRatio)
.setSeed(seed)
DefaultParamsReader.getAndSetParams(tvs, metadata,
skipParams = Option(List("estimatorParamMaps")))
tvs
}
}
}
Expand Down Expand Up @@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray

val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.trainRatio, trainRatio)
.set(model.seed, seed)
DefaultParamsReader.getAndSetParams(model, metadata,
skipParams = Option(List("estimatorParamMaps")))
model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
}.toSeq
))

val validatorSpecificParams = instance match {
case cv: CrossValidatorParams =>
List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
case tvs: TrainValidationSplitParams =>
List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
case _ =>
// This should not happen.
throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
instance.getClass.getCanonicalName)
}

val jsonParams = validatorSpecificParams ++ List(
"estimatorParamMaps" -> parse(estimatorParamMapsJson),
"seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
val params = instance.extractParamMap().toSeq
val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
val jsonParams = render(params
.filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))

Expand Down
20 changes: 15 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader {

/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
* This works if all Params (except those listed in `skipParams`) implement
* [[org.apache.spark.ml.param.Param.jsonDecode()]].
*
* @param skipParams The params included in `skipParams` won't be set. This is useful if some
* params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
* and need special handling.
* TODO: Move to [[Metadata]] method
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
def getAndSetParams(
instance: Params,
metadata: Metadata,
skipParams: Option[List[String]] = None): Unit = {
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
if (skipParams == None || !skipParams.get.contains(paramName)) {
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
}
case _ =>
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,15 @@ class CrossValidatorSuite
.setEvaluator(evaluator)
.setNumFolds(20)
.setEstimatorParamMaps(paramMaps)
.setSeed(42L)
.setParallelism(2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the test for the model too please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.


val cv2 = testDefaultReadWrite(cv, testParams = false)

assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
assert(cv.getSeed === cv2.getSeed)
assert(cv.getParallelism === cv2.getParallelism)

assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
Expand Down Expand Up @@ -160,11 +160,13 @@ class TrainValidationSplitSuite
.setTrainRatio(0.5)
.setEstimatorParamMaps(paramMaps)
.setSeed(42L)
.setParallelism(2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update the test for the Model too please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. The model does not own the parallelism parameter. This was discussed before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you're right, thanks


val tvs2 = testDefaultReadWrite(tvs, testParams = false)

assert(tvs.getTrainRatio === tvs2.getTrainRatio)
assert(tvs.getSeed === tvs2.getSeed)
assert(tvs.getParallelism === tvs2.getParallelism)

ValidatorParamsSuiteHelpers
.compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
Expand Down