From ad6034d9cdb21ec563d784eea3a9dd92c839365a Mon Sep 17 00:00:00 2001 From: John Brock Date: Mon, 24 Jul 2017 17:33:12 -0700 Subject: [PATCH] Change ML LogisticRegression setInitialModel from private to public and add getInitialModel method --- .../spark/ml/classification/LogisticRegression.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 65b09e571264..09f2d227effd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -477,11 +477,19 @@ class LogisticRegression @Since("1.2.0") ( private var optInitialModel: Option[LogisticRegressionModel] = None - private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { + /** + * Set the initial logistic regression parameters, bypassing the default parameter initialization. + */ + def setInitialModel(model: LogisticRegressionModel): this.type = { this.optInitialModel = Some(model) this } + /** + * Return the user-supplied initial logistic regression model, if supplied + */ + def getInitialModel(): Option[LogisticRegressionModel] = this.optInitialModel + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE train(dataset, handlePersistence)