[SPARK-21386] ML LinearRegression supports warm start from user provided initial model. #18610
New file in package org.apache.spark.ml.param.shared (the shared HasInitialModel trait, +29 lines):

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.param.shared

import org.apache.spark.ml.Model
import org.apache.spark.ml.param._

private[ml] trait HasInitialModel[T <: Model[T]] extends Params {

  def initialModel: Param[T]

  /** @group getParam */
  final def getInitialModel: T = $(initialModel)
}
```
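The type parameter ties `initialModel` to the concrete model class, so `getInitialModel` is strongly typed and generic code can handle any warm-startable estimator. A minimal sketch of such generic use (the helper is illustrative, not part of the PR, and would have to live under `org.apache.spark.ml` since the trait is `private[ml]`):

```scala
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.HasInitialModel

// Works for any estimator mixing in HasInitialModel, whatever its model type.
def describeWarmStart[M <: Model[M]](est: Params with HasInitialModel[M]): String = {
  if (est.isSet(est.initialModel)) s"warm start from model uid=${est.getInitialModel.uid}"
  else "cold start"
}
```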
Changes to LinearRegression.scala (package org.apache.spark.ml.regression):
```diff
@@ -48,9 +48,9 @@ import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel

 /**
- * Params for linear regression.
+ * Params for linear regression model.
  */
-private[regression] trait LinearRegressionParams extends PredictorParams
+private[regression] trait LinearRegressionModelParams extends PredictorParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
   with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
   with HasAggregationDepth {
```
```diff
@@ -71,6 +71,22 @@ private[regression] trait LinearRegressionParams extends PredictorParams
     ParamValidators.inArray[String](supportedSolvers))
 }

+/**
+ * Params for linear regression.
+ */
+private[regression] trait LinearRegressionParams extends LinearRegressionModelParams
+  with HasInitialModel[LinearRegressionModel] {
+
+  /**
+   * A LinearRegressionModel to use for warm start.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  final val initialModel: Param[LinearRegressionModel] = new Param[LinearRegressionModel](
+    this, "initialModel", "A LinearRegressionModel to use for warm start.")
+}
+
 /**
  * Linear regression.
  *
```
```diff
@@ -208,6 +224,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setInitialModel(value: LinearRegressionModel): this.type = set(initialModel, value)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
```
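For callers, warm start then becomes a one-line change. A minimal usage sketch (data paths and settings are illustrative, not from the PR):

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("warm-start-sketch").getOrCreate()
// Two hypothetical batches of libsvm-format training data.
val batch1 = spark.read.format("libsvm").load("data/batch1")
val batch2 = spark.read.format("libsvm").load("data/batch2")

val lr = new LinearRegression().setSolver("l-bfgs").setMaxIter(100)
val model1 = lr.fit(batch1)

// Seed L-BFGS with model1's coefficients when fitting on the new batch;
// on similar data this should converge in fewer iterations.
val model2 = lr.setInitialModel(model1).fit(batch2)
```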
```diff
@@ -226,6 +246,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String

     if (($(solver) == Auto &&
       numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {
+
+      if (isSet(initialModel)) {
+        logWarning("Initial model will be ignored when fitting with the normal solver. " +
+          "Set solver to l-bfgs for it to take effect.")
+      }
+
       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
```

Contributor: Since the initial model is a pretty important parameter, a user who sets it would expect it to work and may overlook the warning in the overwhelming Spark logs.

Contributor (author): Fair enough, I will update it after collecting all comments. Thanks.

Contributor: I disagree. The normal solver provides an exact solution, so it is harmless to ignore the initial model. I don't see a reason to do anything more than log a warning.
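For background on that last comment (standard least-squares algebra, not part of the diff): with the normal-equation solver, the L2-regularized weighted least-squares objective has a closed-form minimizer,

$$\hat{\beta} = (X^\top W X + \lambda I)^{-1} X^\top W y,$$

which depends only on the data and the regularization, never on a starting vector, so an initial model can neither change the result nor speed the solver up.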
```diff
@@ -251,6 +277,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       return lrModel
     }

+    if (isSet(initialModel)) {
+      val dimOfInitialModel = $(initialModel).coefficients.size
+      require(numFeatures == dimOfInitialModel,
+        s"The number of features in the training dataset is $numFeatures, " +
+        s"which does not match the dimension of the initial model: $dimOfInitialModel.")
+    }
+
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
```
```diff
@@ -365,7 +398,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
     }

-    val initialCoefficients = Vectors.zeros(numFeatures)
+    val initialCoefficients = if (isSet(initialModel)) {
+      $(initialModel).coefficients
+    } else {
+      Vectors.zeros(numFeatures)
+    }
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
       initialCoefficients.asBreeze.toDenseVector)
```

Contributor (author), on the `$(initialModel).coefficients` line: Here we only use the coefficients of the initial model.
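To make the mechanics concrete, here is a small self-contained Breeze sketch (plain Breeze, not Spark code) of what the starting vector means to the optimizer: iterations simply begin at the supplied point, which is the whole effect of a warm start.

```scala
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{DiffFunction, LBFGS}

object WarmStartSketch extends App {
  // f(x) = ||x - 3||^2, minimized at x = (3, ..., 3).
  val f = new DiffFunction[BDV[Double]] {
    override def calculate(x: BDV[Double]): (Double, BDV[Double]) = {
      val d = x - 3.0
      (d dot d, d * 2.0)
    }
  }
  val lbfgs = new LBFGS[BDV[Double]](maxIter = 100, m = 10)
  // Cold start: begin from the zero vector, as LinearRegression did before this PR.
  val cold = lbfgs.minimize(f, BDV.zeros[Double](5))
  // Warm start: begin near the optimum, as a good initial model would provide.
  val warm = lbfgs.minimize(f, BDV.fill(5)(2.9))
  println(s"cold: $cold, warm: $warm") // both converge to ~(3.0, ..., 3.0)
}
```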
```diff
@@ -441,16 +478,47 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     model
   }

+  @Since("1.6.0")
+  override def write: MLWriter = new LinearRegression.LinearRegressionWriter(this)
+
   @Since("1.4.0")
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
 }

 @Since("1.6.0")
-object LinearRegression extends DefaultParamsReadable[LinearRegression] {
+object LinearRegression extends MLReadable[LinearRegression] {

   @Since("1.6.0")
   override def load(path: String): LinearRegression = super.load(path)

+  @Since("1.6.0")
+  override def read: MLReader[LinearRegression] = new LinearRegressionReader
+
+  /** [[MLWriter]] instance for [[LinearRegression]] */
+  private[LinearRegression] class LinearRegressionWriter(instance: LinearRegression)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveInitialModel(instance, path)
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
+
+  private class LinearRegressionReader extends MLReader[LinearRegression] {
+
+    override def load(path: String): LinearRegression = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[LinearRegression].getName)
+      val instance = new LinearRegression(metadata.uid)
+
+      DefaultParamsReader.getAndSetParams(instance, metadata)
+      DefaultParamsReader.loadInitialModel[LinearRegressionModel](path, sc) match {
+        case Some(m) => instance.setInitialModel(m)
+        case None => // initialModel doesn't exist, so do nothing
+      }
+      instance
+    }
+  }
+
   /**
    * When using `LinearRegression.solver` == "normal", the solver must limit the number of
    * features to at most this number. The entire covariance matrix X^T^X will be collected
```
```diff
@@ -481,7 +549,7 @@ class LinearRegressionModel private[ml] (
     @Since("2.0.0") val coefficients: Vector,
     @Since("1.3.0") val intercept: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
-  with LinearRegressionParams with MLWritable {
+  with LinearRegressionModelParams with MLWritable {

   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
```
Changes to ReadWrite.scala (package org.apache.spark.ml.util), covering DefaultParamsWriter and DefaultParamsReader:
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasInitialModel
 import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
```
```diff
@@ -282,6 +283,8 @@ private[ml] object DefaultParamsWriter {
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
    *
+   * Note: This function does not handle param `initialModel`; see [[saveInitialModel()]].
+   *
    * @see [[saveMetadata()]] for details on what this includes.
    */
   def getMetadataToSave(
```
```diff
@@ -291,7 +294,8 @@ private[ml] object DefaultParamsWriter {
       paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.extractParamMap().toSeq
+      .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
     val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
     }.toList))
```
```diff
@@ -309,6 +313,23 @@ private[ml] object DefaultParamsWriter {
     val metadataJson: String = compact(render(metadata))
     metadataJson
   }

+  /**
+   * Save an estimator's `initialModel` under the given path.
+   */
+  def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
+      instance: T, path: String): Unit = {
+    if (instance.isDefined(instance.initialModel)) {
+      val initialModelPath = new Path(path, "initialModel").toString
+      val initialModel = instance.getOrDefault(instance.initialModel)
+      // When saving, keep only the direct initialModel: clear any initialModel the
+      // direct initialModel may itself carry, to avoid arbitrarily deep recursion.
+      if (initialModel.hasParam("initialModel")) {
+        initialModel.clear(initialModel.getParam("initialModel"))
+      }
+      initialModel.save(initialModelPath)
+    }
+  }

   /**
```

Contributor: To avoid recursion, this modifies the model to cut the chain, but mutating a model that may also be in use elsewhere could cause problems.

Contributor (author): @WeichenXu123 Can you elaborate on the concrete scenario? Thanks.

Contributor: I suggest that `setInitialModel` first copy the passed-in model. Another way is to avoid adding the `initialModel` param to the Model at all and only add it to the Estimator (this way, I think, is more reasonable).

Contributor (author): Actually we did it in the latter way: `initialModel` lives only on the estimator's `LinearRegressionParams`, not on `LinearRegressionModelParams`.

Contributor: Fair enough.
```diff
@@ -437,6 +458,23 @@
     val cls = Utils.classForName(metadata.className)
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }

+  /**
+   * Load an estimator's `initialModel` instance from the given path and return it.
+   * If the `initialModel` path does not exist, the estimator either does not have
+   * the `initialModel` param or never set it, so return None.
+   * This assumes the model implements [[MLReadable]].
+   */
+  def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
+    val hadoopConf = sc.hadoopConfiguration
+    val initialModelPath = new Path(path, "initialModel")
+    val fs = initialModelPath.getFileSystem(hadoopConf)
+    if (fs.exists(initialModelPath)) {
+      Some(loadParamsInstance[M](initialModelPath.toString, sc))
+    } else {
+      None
+    }
+  }
 }

 /**
```
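A round-trip sketch of the persistence path (the directory is arbitrary; `model1` is the fitted model from the usage sketch earlier):

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression().setInitialModel(model1)
lr.write.overwrite().save("/tmp/lr-with-initial-model")
// Layout (sketch):
//   /tmp/lr-with-initial-model/metadata/      <- params JSON, initialModel filtered out
//   /tmp/lr-with-initial-model/initialModel/  <- a full LinearRegressionModel save

// The reader restores plain params from metadata/ and, because the
// initialModel/ subdirectory exists, re-attaches the initial model.
val restored = LinearRegression.load("/tmp/lr-with-initial-model")
assert(restored.isSet(restored.initialModel))
```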
Changes to DefaultReadWriteTest.scala (MLlib test utilities):
```diff
@@ -59,15 +59,17 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     assert(newInstance.uid === instance.uid)
     if (testParams) {
       instance.params.foreach { p =>
-        if (instance.isDefined(p)) {
-          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
-            case (Array(values), Array(newValues)) =>
-              assert(values === newValues, s"Values do not match on param ${p.name}.")
-            case (value, newValue) =>
-              assert(value === newValue, s"Values do not match on param ${p.name}.")
+        if (p.name != "initialModel") {
+          if (instance.isDefined(p)) {
+            (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+              case (Array(values), Array(newValues)) =>
+                assert(values === newValues, s"Values do not match on param ${p.name}.")
+              case (value, newValue) =>
+                assert(value === newValue, s"Values do not match on param ${p.name}.")
+            }
+          } else {
+            assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
           }
-        } else {
-          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
         }
       }
     }
```
```diff
@@ -113,7 +115,14 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val estimator2 = testDefaultReadWrite(estimator)
     testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
-      assert(estimator.get(param).get === estimator2.get(param).get)
+      if (param.name == "initialModel") {
+        // An estimator's `initialModel` has the same type as the model produced by
+        // that estimator, so we can use `checkModelData` to check its equality too.
+        checkModelData(estimator.get(param).get.asInstanceOf[M],
+          estimator2.get(param).get.asInstanceOf[M])
+      } else {
+        assert(estimator.get(param).get === estimator2.get(param).get)
+      }
     }

     // Test Model save/load
```
Contributor: It may be cleaner to just move the param `initialModel` into LinearRegression, so we don't have to touch the class hierarchy.

Contributor (author): The refactoring here is not only for this PR. Since a model usually has fewer params than its estimator, we make estimator params extend model params in other places too, for example ALSParams extends ALSModelParams. I think the class hierarchy is clearer this way.
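Schematically, the layering being defended (illustrative trait skeletons; the real definitions are in the LinearRegression.scala hunks above):

```scala
import org.apache.spark.ml.param.Params

// Params the fitted model still needs (regParam, solver, ...): mixed into
// both the model class and the estimator.
trait MyModelParams extends Params

// Training-only extras such as initialModel: layered on for the estimator
// alone, so the model never carries params that are meaningless after fitting.
trait MyEstimatorParams extends MyModelParams
```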