From c871909aeb43a4eee4e3e9677bae6a1e625a6d52 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 12 Jul 2017 16:45:32 +0800
Subject: [PATCH 1/2] ML LinearRegression supports warm start from a
 user-provided initial model.

---
 .../shared/sharedGeneralTypeParams.scala      | 29 +++++++
 .../ml/regression/LinearRegression.scala      | 84 +++++++++++++++++--
 .../org/apache/spark/ml/util/ReadWrite.scala  | 40 ++++++++-
 .../ml/regression/LinearRegressionSuite.scala | 48 +++++++++--
 .../spark/ml/util/DefaultReadWriteTest.scala  | 27 ++++--
 5 files changed, 205 insertions(+), 23 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
new file mode 100644
index 0000000000000..c67380edaa60e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedGeneralTypeParams.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param._
+
+private[ml] trait HasInitialModel[T <: Model[T]] extends Params {
+
+  def initialModel: Param[T]
+
+  /** @group getParam */
+  final def getInitialModel: T = $(initialModel)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 91cd229704a37..707079f649f45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -48,12 +48,12 @@ import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
 /**
- * Params for linear regression.
+ * Params for linear regression model.
  */
-private[regression] trait LinearRegressionParams extends PredictorParams
-    with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
-    with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-    with HasAggregationDepth {
+private[regression] trait LinearRegressionModelParams extends PredictorParams
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+  with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
+  with HasAggregationDepth {
 
   import LinearRegression._
 
@@ -71,6 +71,22 @@ private[regression] trait LinearRegressionParams extends PredictorParams
     ParamValidators.inArray[String](supportedSolvers))
 }
 
+/**
+ * Params for linear regression.
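+ *
+ * Adds the `initialModel` param for warm start: when it is set and the solver is
+ * "l-bfgs", the initial model's coefficients seed the optimizer; the normal solver
+ * ignores it (see `train` below).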
+ */
+private[regression] trait LinearRegressionParams extends LinearRegressionModelParams
+  with HasInitialModel[LinearRegressionModel] {
+
+  /**
+   * A LinearRegressionModel to use for warm start.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  final val initialModel: Param[LinearRegressionModel] = new Param[LinearRegressionModel](
+    this, "initialModel", "A LinearRegressionModel to use for warm start.")
+}
+
 /**
  * Linear regression.
  *
@@ -208,6 +224,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInitialModel(value: LinearRegressionModel): this.type = set(initialModel, value)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -226,6 +246,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
     if (($(solver) == Auto &&
       numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {
+
+      if (isSet(initialModel)) {
+        logWarning("The initial model will be ignored when fitting with the normal " +
+          "solver. Set solver to l-bfgs for it to take effect.")
+      }
+
       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
 
@@ -251,6 +277,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       return lrModel
     }
 
+    if (isSet(initialModel)) {
+      val dimOfInitialModel = $(initialModel).coefficients.size
+      require(numFeatures == dimOfInitialModel,
+        s"The number of features in the training dataset is $numFeatures, " +
+          s"which does not match the dimension of the initial model: $dimOfInitialModel.")
+    }
+
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
@@ -365,7 +398,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
     }
 
-    val initialCoefficients = Vectors.zeros(numFeatures)
+    val initialCoefficients = if (isSet(initialModel)) {
+      $(initialModel).coefficients
+    } else {
+      Vectors.zeros(numFeatures)
+    }
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
       initialCoefficients.asBreeze.toDenseVector)
 
@@ -441,16 +478,47 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     model
   }
 
+  @Since("1.6.0")
+  override def write: MLWriter = new LinearRegression.LinearRegressionWriter(this)
+
   @Since("1.4.0")
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
 }
 
 @Since("1.6.0")
-object LinearRegression extends DefaultParamsReadable[LinearRegression] {
+object LinearRegression extends MLReadable[LinearRegression] {
 
   @Since("1.6.0")
   override def load(path: String): LinearRegression = super.load(path)
 
+  @Since("1.6.0")
+  override def read: MLReader[LinearRegression] = new LinearRegressionReader
+
+  /** [[MLWriter]] instance for [[LinearRegression]] */
+  private[LinearRegression] class LinearRegressionWriter(instance: LinearRegression)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveInitialModel(instance, path)
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+    }
+  }
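+
+  // Layout written by LinearRegressionWriter and read back by the reader below (sketch):
+  //   <path>/metadata      - params JSON, with `initialModel` filtered out
+  //   <path>/initialModel  - the initial model, saved via its own MLWriter, if set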
+
+  private class LinearRegressionReader extends MLReader[LinearRegression] {
+
+    override def load(path: String): LinearRegression = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[LinearRegression].getName)
+      val instance = new LinearRegression(metadata.uid)
+
+      DefaultParamsReader.getAndSetParams(instance, metadata)
+      DefaultParamsReader.loadInitialModel[LinearRegressionModel](path, sc) match {
+        case Some(m) => instance.setInitialModel(m)
+        case None => // initialModel doesn't exist, do nothing
+      }
+      instance
+    }
+  }
+
   /**
    * When using `LinearRegression.solver` == "normal", the solver must limit the number of
    * features to at most this number. The entire covariance matrix X^T^X will be collected
@@ -481,7 +549,7 @@ class LinearRegressionModel private[ml] (
     @Since("2.0.0") val coefficients: Vector,
     @Since("1.3.0") val intercept: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
-  with LinearRegressionParams with MLWritable {
+  with LinearRegressionModelParams with MLWritable {
 
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index b54e258cff2f8..ae6f790ee16b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
 import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasInitialModel
 import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils
@@ -282,6 +283,8 @@ private[ml] object DefaultParamsWriter {
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many sub-models.
    *
+   * Note: this function does not handle the `initialModel` param; see [[saveInitialModel()]].
+   *
    * @see [[saveMetadata()]] for details on what this includes.
    */
   def getMetadataToSave(
@@ -291,7 +294,8 @@ private[ml] object DefaultParamsWriter {
       paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.extractParamMap().toSeq
+      .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
     val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
     }.toList))
@@ -309,6 +313,23 @@ private[ml] object DefaultParamsWriter {
     val metadataJson: String = compact(render(metadata))
     metadataJson
   }
+
+  /**
+   * Save the estimator's `initialModel`, if set, to the corresponding path.
+   */
+  def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
+      instance: T, path: String): Unit = {
+    if (instance.isDefined(instance.initialModel)) {
+      val initialModelPath = new Path(path, "initialModel").toString
+      val initialModel = instance.getOrDefault(instance.initialModel)
+      // When saving, keep only the direct initialModel: clear any initialModel it
+      // carries itself, to avoid recursively persisting a chain of initial models.
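+      // For example, if estimator E has initialModel M1 and M1 itself carries an
+      // initialModel M2, only M1 is persisted here, with its initialModel cleared.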
+      if (initialModel.hasParam("initialModel")) {
+        initialModel.clear(initialModel.getParam("initialModel"))
+      }
+      initialModel.save(initialModelPath)
+    }
+  }
 }
 
 /**
@@ -437,6 +458,23 @@ private[ml] object DefaultParamsReader {
     val cls = Utils.classForName(metadata.className)
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
+
+  /**
+   * Load the estimator's `initialModel` instance from the given path and return it.
+   * If the `initialModel` path does not exist, the estimator either has no param
+   * `initialModel` or never set it, so return None.
+   * This assumes the model implements [[MLReadable]].
+   */
+  def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
+    val hadoopConf = sc.hadoopConfiguration
+    val initialModelPath = new Path(path, "initialModel")
+    val fs = initialModelPath.getFileSystem(hadoopConf)
+    if (fs.exists(initialModelPath)) {
+      Some(loadParamsInstance[M](initialModelPath.toString, sc))
+    } else {
+      None
+    }
+  }
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e7bd4eb9e0adf..ba4ad2e489320 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -984,8 +984,9 @@ class LinearRegressionSuite
       assert(model.coefficients === model2.coefficients)
     }
     val lr = new LinearRegression()
-    testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      LinearRegressionSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(
+      lr, datasetWithWeight, LinearRegressionSuite.estimatorParamSettings,
+      LinearRegressionSuite.modelParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
@@ -998,16 +999,45 @@ class LinearRegressionSuite
       }
     }
   }
+
+  test("training with initial model") {
+    val trainer1 = new LinearRegression()
+      .setRegParam(2.3)
+      .setFitIntercept(false)
+      .setSolver("l-bfgs")
+
+    val model1 = trainer1.fit(datasetWithSparseFeature)
+    val model2 = trainer1.setInitialModel(model1).fit(datasetWithSparseFeature)
+
+    assert(model1.coefficients ~== model2.coefficients absTol 1E-3)
+    assert(model1.intercept === model2.intercept)
+    assert(model1.summary.objectiveHistory.length > model2.summary.objectiveHistory.length)
+
+    // The initial model is ignored when fitting with the normal solver.
+    val trainer2 = new LinearRegression().setRegParam(2.3).setFitIntercept(false)
+    val model3 = trainer2.fit(datasetWithDenseFeature)
+    val model4 = trainer2.setInitialModel(model3).fit(datasetWithDenseFeature)
+
+    assert(model3.coefficients ~== model4.coefficients absTol 1E-3)
+    assert(model3.intercept === model4.intercept)
+    assert(model3.summary.objectiveHistory.length === model4.summary.objectiveHistory.length)
+
+    // The initial model's dimension does not match the training dataset.
+    intercept[IllegalArgumentException] {
+      val initialModel = new LinearRegressionModel("linReg", Vectors.dense(1.0), 1.0)
+      new LinearRegression().setInitialModel(initialModel).fit(datasetWithSparseFeature)
+    }
+  }
 }
 
 object LinearRegressionSuite {
 
   /**
-   * Mapping from all Params to valid settings which differ from the defaults.
+   * Mapping from estimator Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
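+   * It includes `initialModel`, so estimator save/load also exercises the warm-start path.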
    * This excludes input columns to simplify some tests.
    */
-  val allParamSettings: Map[String, Any] = Map(
+  val estimatorParamSettings: Map[String, Any] = Map(
     "predictionCol" -> "myPrediction",
     "regParam" -> 0.01,
     "elasticNetParam" -> 0.1,
@@ -1015,6 +1045,14 @@ object LinearRegressionSuite {
     "fitIntercept" -> true,
     "tol" -> 0.8,
     "standardization" -> false,
-    "solver" -> "l-bfgs"
+    "solver" -> "l-bfgs",
+    "initialModel" -> new LinearRegressionModel("linReg", Vectors.dense(1.0, 1.0), 1.0)
   )
+
+  /**
+   * Mapping from model Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val modelParamSettings: Map[String, Any] = Map("predictionCol" -> "myPrediction")
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 27d606cb05dc2..f2be02443cc09 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -59,15 +59,17 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     assert(newInstance.uid === instance.uid)
     if (testParams) {
       instance.params.foreach { p =>
-        if (instance.isDefined(p)) {
-          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
-            case (Array(values), Array(newValues)) =>
-              assert(values === newValues, s"Values do not match on param ${p.name}.")
-            case (value, newValue) =>
-              assert(value === newValue, s"Values do not match on param ${p.name}.")
+        if (p.name != "initialModel") {
+          if (instance.isDefined(p)) {
+            (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+              case (Array(values), Array(newValues)) =>
+                assert(values === newValues, s"Values do not match on param ${p.name}.")
+              case (value, newValue) =>
+                assert(value === newValue, s"Values do not match on param ${p.name}.")
+            }
+          } else {
+            assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
           }
-        } else {
-          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
         }
       }
     }
@@ -113,7 +115,14 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val estimator2 = testDefaultReadWrite(estimator)
     testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
-      assert(estimator.get(param).get === estimator2.get(param).get)
+      if (param.name == "initialModel") {
+        // The estimator's `initialModel` has the same type as the model produced by this
+        // estimator, so we can use `checkModelData` to check equality of `initialModel` too.
+        checkModelData(estimator.get(param).get.asInstanceOf[M],
+          estimator2.get(param).get.asInstanceOf[M])
+      } else {
+        assert(estimator.get(param).get === estimator2.get(param).get)
+      }
     }
 
     // Test Model save/load

From e49d45271c923eeeba3f3bbbd79dc5cbf0c91ab1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 12 Jul 2017 17:28:50 +0800
Subject: [PATCH 2/2] Fix alignment and add MiMa excludes.
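
Moving the shared params into the new LinearRegressionModelParams parent trait and
making the companion object extend MLReadable changes inherited members, which MiMa
flags as binary-compatibility breaks; the added excludes whitelist these intentional
changes.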
--- .../org/apache/spark/ml/regression/LinearRegression.scala | 6 +++--- project/MimaExcludes.scala | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 707079f649f45..e8f9aef030a4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -51,9 +51,9 @@ import org.apache.spark.storage.StorageLevel * Params for linear regression model. */ private[regression] trait LinearRegressionModelParams extends PredictorParams - with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver - with HasAggregationDepth { + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver + with HasAggregationDepth { import LinearRegression._ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1793da03a2c3e..030dc89bc69b0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1012,6 +1012,13 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") + ) ++ Seq( + // [SPARK-21386] ML LinearRegression supports warm start from user provided initial model. + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModelParams.org$apache$spark$ml$regression$LinearRegressionModelParams$_setter_$solver_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.initialModel"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.getInitialModel"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.LinearRegression$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.LinearRegressionModel") ) }
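
Usage sketch (reviewer note, not part of the patch series): a minimal warm-start
example against the API added above. It assumes a DataFrame `train` with the default
"features" and "label" columns; the dataset name and iteration counts are
illustrative only.

  import org.apache.spark.ml.regression.LinearRegression

  // First pass: a few cheap l-bfgs iterations to get rough coefficients.
  val coarse = new LinearRegression()
    .setSolver("l-bfgs")
    .setMaxIter(5)
    .fit(train)

  // Second pass: l-bfgs starts from the coarse coefficients instead of zeros,
  // so it typically needs fewer iterations (a shorter objectiveHistory) to
  // converge. The normal solver would log a warning and ignore initialModel.
  val refined = new LinearRegression()
    .setSolver("l-bfgs")
    .setMaxIter(100)
    .setInitialModel(coarse)
    .fit(train)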