[SPARK-21386] ML LinearRegression supports warm start from user provided initial model

mllib/src/main/scala/org/apache/spark/ml/param/shared/HasInitialModel.scala (new file)
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.param.shared

import org.apache.spark.ml.Model
import org.apache.spark.ml.param._

private[ml] trait HasInitialModel[T <: Model[T]] extends Params {

/** Param for the initial model used to warm-start the fit. */
def initialModel: Param[T]

/** @group getParam */
final def getInitialModel: T = $(initialModel)
}
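
For context, a sketch of the warm-start flow this trait enables, mirroring the LinearRegression test added later in this PR; `training` is an assumed DataFrame of labeled points:

    import org.apache.spark.ml.regression.LinearRegression

    // Fit once from scratch, then reuse the fitted model as the starting point
    // of a second fit; starting near the optimum, L-BFGS should need fewer
    // iterations the second time.
    val lr = new LinearRegression().setSolver("l-bfgs").setRegParam(2.3)
    val first = lr.fit(training)
    val warm = lr.setInitialModel(first).fit(training)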
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

@@ -48,9 +48,9 @@ import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel

/**
- * Params for linear regression.
+ * Params for linear regression model.
*/
-private[regression] trait LinearRegressionParams extends PredictorParams
+private[regression] trait LinearRegressionModelParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
with HasAggregationDepth {
@@ -71,6 +71,22 @@ private[regression] trait LinearRegressionParams extends PredictorParams
ParamValidators.inArray[String](supportedSolvers))
}

/**
* Params for linear regression.
*/
private[regression] trait LinearRegressionParams extends LinearRegressionModelParams
  with HasInitialModel[LinearRegressionModel] {

  /**
   * A LinearRegressionModel to use for warm start.
   *
   * @group param
   */
  @Since("2.3.0")
  final val initialModel: Param[LinearRegressionModel] = new Param[LinearRegressionModel](
    this, "initialModel", "A LinearRegressionModel to use for warm start.")
}

Contributor: It might be cleaner if we just moved the initialModel param into LinearRegression, so we don't have to touch the class hierarchy.

Author: The refactoring here is not only for this PR: since a model usually has fewer params than its estimator, we make estimator params extend model params in other places too, for example ALSParams extends ALSModelParams. I think this makes the class hierarchy clearer.
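
A minimal sketch of that hierarchy pattern (trait names are illustrative, not from this diff):

    import org.apache.spark.ml.param.Params

    // Model params: the subset needed at prediction time, shared by the
    // model and its estimator (analogous to ALSModelParams).
    trait FooModelParams extends Params

    // Estimator params extend the model params with training-only params such
    // as maxIter or initialModel (analogous to ALSParams extends ALSModelParams).
    trait FooParams extends FooModelParams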

/**
* Linear regression.
*
@@ -208,6 +224,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/** @group setParam */
@Since("2.3.0")
def setInitialModel(value: LinearRegressionModel): this.type = set(initialModel, value)

override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -226,6 +246,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String

if (($(solver) == Auto &&
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {

      if (isSet(initialModel)) {
        logWarning("Initial model will be ignored if fitting by normal solver. " +
          "Set solver to l-bfgs to make it take effect.")
      }

      // For low dimensional data, WeightedLeastSquares is more efficient since the
      // training algorithm only requires one pass through the data. (SPARK-10668)

Contributor: The initial model is a pretty important parameter: by setting it, users expect it to take effect, and they may overlook the warning in the overwhelming Spark logs. Maybe we can move the parameter check to transformSchema and throw an exception.

Author: Fair enough, I will update it after collecting all the comments. Thanks.

Contributor: I disagree. The normal solver provides an exact solution, so it is harmless to ignore the initial model. I don't see a reason to do anything more than log a warning.
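
For reference, a sketch of the reviewer's fail-fast alternative, written as a hypothetical transformSchema override on the estimator (the PR keeps the logWarning instead):

    import org.apache.spark.sql.types.StructType

    override def transformSchema(schema: StructType): StructType = {
      // Hypothetical check: reject an initial model whenever the normal solver
      // would silently ignore it, instead of only warning at training time.
      require(!(isSet(initialModel) && $(solver) == "normal"),
        "initialModel is only supported by the l-bfgs solver.")
      super.transformSchema(schema)
    }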

@@ -251,6 +277,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
return lrModel
}

if (isSet(initialModel)) {
val dimOfInitialModel = $(initialModel).coefficients.size
      require(numFeatures == dimOfInitialModel,
        s"The number of features in the training dataset is $numFeatures, " +
          s"which does not match the dimension of the initial model: $dimOfInitialModel.")
}

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

@@ -365,7 +398,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
}

-    val initialCoefficients = Vectors.zeros(numFeatures)
+    val initialCoefficients = if (isSet(initialModel)) {
+      $(initialModel).coefficients
+    } else {
+      Vectors.zeros(numFeatures)
+    }
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
       initialCoefficients.asBreeze.toDenseVector)

Author: Here we only use the coefficients and not the intercept, because the intercept is computed in closed form after the coefficients converge.
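
As a toy illustration of that closed form (all values assumed): the intercept is recovered as the label mean minus the dot product of the coefficients with the feature means, so a warm start never needs to carry the intercept.

    import org.apache.spark.ml.linalg.Vectors

    val coefficients = Vectors.dense(0.5, -1.2)  // converged coefficients
    val featureMeans = Vectors.dense(2.0, 3.0)   // per-feature means (assumed known)
    val labelMean = 4.0                          // label mean (assumed known)
    val intercept = labelMean -
      coefficients.toArray.zip(featureMeans.toArray).map { case (b, m) => b * m }.sum
    // 4.0 - (0.5 * 2.0 + (-1.2) * 3.0) == 6.6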

Expand Down Expand Up @@ -441,16 +478,47 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model
}

@Since("1.6.0")
override def write: MLWriter = new LinearRegression.LinearRegressionWriter(this)

@Since("1.4.0")
override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
}

@Since("1.6.0")
-object LinearRegression extends DefaultParamsReadable[LinearRegression] {
+object LinearRegression extends MLReadable[LinearRegression] {

@Since("1.6.0")
override def load(path: String): LinearRegression = super.load(path)

@Since("1.6.0")
override def read: MLReader[LinearRegression] = new LinearRegressionReader

/** [[MLWriter]] instance for [[LinearRegression]] */
private[LinearRegression] class LinearRegressionWriter(instance: LinearRegression)
extends MLWriter {

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveInitialModel(instance, path)
DefaultParamsWriter.saveMetadata(instance, path, sc)
}
}

private class LinearRegressionReader extends MLReader[LinearRegression] {

override def load(path: String): LinearRegression = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, classOf[LinearRegression].getName)
val instance = new LinearRegression(metadata.uid)

DefaultParamsReader.getAndSetParams(instance, metadata)
DefaultParamsReader.loadInitialModel[LinearRegressionModel](path, sc) match {
case Some(m) => instance.setInitialModel(m)
case None => // initialModel doesn't exist, do nothing
}
instance
}
}
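
A sketch of the round trip these writer and reader classes enable; the path and `priorModel` are assumptions:

    // priorModel: a LinearRegressionModel obtained from an earlier fit.
    val lr = new LinearRegression().setInitialModel(priorModel)
    lr.write.overwrite().save("/tmp/lr-warm")        // params plus an initialModel subdirectory
    val restored = LinearRegression.load("/tmp/lr-warm")
    assert(restored.getInitialModel.uid == priorModel.uid)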

/**
* When using `LinearRegression.solver` == "normal", the solver must limit the number of
* features to at most this number. The entire covariance matrix X^T^X will be collected
@@ -481,7 +549,7 @@ class LinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("1.3.0") val intercept: Double)
extends RegressionModel[Vector, LinearRegressionModel]
with LinearRegressionParams with MLWritable {
with LinearRegressionModelParams with MLWritable {

private var trainingSummary: Option[LinearRegressionTrainingSummary] = None

40 changes: 39 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.param.shared.HasInitialModel
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils
@@ -282,6 +283,8 @@ private[ml] object DefaultParamsWriter {
* Helper for [[saveMetadata()]] which extracts the JSON to save.
* This is useful for ensemble models which need to save metadata for many sub-models.
*
* Note: This function does not handle the `initialModel` param; see [[saveInitialModel()]].
*
* @see [[saveMetadata()]] for details on what this includes.
*/
def getMetadataToSave(
@@ -291,7 +294,8 @@
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.extractParamMap().toSeq
+      .filter(_.param.name != "initialModel").asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
@@ -309,6 +313,23 @@
val metadataJson: String = compact(render(metadata))
metadataJson
}

/**
* Save an estimator's `initialModel` under the given path.
*/
def saveInitialModel[T <: HasInitialModel[_ <: MLWritable with Params]](
instance: T, path: String): Unit = {
if (instance.isDefined(instance.initialModel)) {
val initialModelPath = new Path(path, "initialModel").toString
val initialModel = instance.getOrDefault(instance.initialModel)
// When saving, keep only the direct initialModel: clear any nested initialModel that the
// direct initialModel may itself carry, to avoid unnecessarily deep recursion.
if (initialModel.hasParam("initialModel")) {
initialModel.clear(initialModel.getParam("initialModel"))
}
initialModel.save(initialModelPath)
    }
  }
}

Contributor: To avoid recursion, the model is modified here to cut the recursion short, but that may cause a problem: it mutates a model that might still be used elsewhere.

Author: @WeichenXu123 Can you elaborate on a concrete scenario? Thanks.

Contributor: I suggest that setInitialModel first copy the passed-in model (via Model.copy) and then clear the initialModel param on the copy, to avoid the recursion risk. Because the passed-in model may be used elsewhere, we should avoid modifying its params. Another way is to avoid adding the initialModel param to the Model at all and only add it to the Estimator (which I think is more reasonable).

Author: Actually we did it the latter way: initialModel is a param only for the Estimator, not for the Model. Thanks.

Contributor: Fair enough.
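
For reference, a sketch of the defensive-copy variant suggested above; this is hypothetical, since the PR sidesteps the issue by keeping initialModel on the estimator only:

    import org.apache.spark.ml.param.ParamMap

    // Copy the passed-in model and clear any nested initialModel on the copy,
    // so the caller's instance is never mutated.
    def setInitialModel(value: LinearRegressionModel): this.type = {
      val safeCopy = value.copy(ParamMap.empty)
      if (safeCopy.hasParam("initialModel")) {
        safeCopy.clear(safeCopy.getParam("initialModel"))
      }
      set(initialModel, safeCopy)
    }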

/**
@@ -437,6 +458,23 @@
val cls = Utils.classForName(metadata.className)
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}

/**
* Load an estimator's `initialModel` instance from the given path and return it.
* If the `initialModel` path does not exist, the estimator either lacks the `initialModel`
* param or never set it, so None is returned.
* This assumes the model implements [[MLReadable]].
*/
def loadInitialModel[M <: Model[M]](path: String, sc: SparkContext): Option[M] = {
val hadoopConf = sc.hadoopConfiguration
val initialModelPath = new Path(path, "initialModel")
val fs = initialModelPath.getFileSystem(hadoopConf)
if (fs.exists(initialModelPath)) {
Some(loadParamsInstance[M](initialModelPath.toString, sc))
} else {
None
}
}
}

/**
mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

@@ -984,8 +984,9 @@ class LinearRegressionSuite
assert(model.coefficients === model2.coefficients)
}
val lr = new LinearRegression()
-    testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      LinearRegressionSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(
+      lr, datasetWithWeight, LinearRegressionSuite.estimatorParamSettings,
+      LinearRegressionSuite.modelParamSettings, checkModelData)
}

test("should support all NumericType labels and weights, and not support other types") {
@@ -998,23 +999,60 @@
}
}
}

test("training with initial model") {
val trainer1 = new LinearRegression()
.setRegParam(2.3)
.setFitIntercept(false)
.setSolver("l-bfgs")

val model1 = trainer1.fit(datasetWithSparseFeature)
val model2 = trainer1.setInitialModel(model1).fit(datasetWithSparseFeature)

assert(model1.coefficients ~== model2.coefficients absTol 1E-3)
assert(model1.intercept === model2.intercept)
    assert(model1.summary.objectiveHistory.length > model2.summary.objectiveHistory.length)

// Initial model will be ignored if fitting by normal solver.
val trainer2 = new LinearRegression().setRegParam(2.3).setFitIntercept(false)
val model3 = trainer2.fit(datasetWithDenseFeature)
val model4 = trainer2.setInitialModel(model3).fit(datasetWithDenseFeature)

assert(model3.coefficients ~== model4.coefficients absTol 1E-3)
assert(model3.intercept === model4.intercept)
assert(model3.summary.objectiveHistory.length === model4.summary.objectiveHistory.length)

    // A dimension mismatch with the training dataset should fail fast.
intercept[IllegalArgumentException] {
val initialModel = new LinearRegressionModel("linReg", Vectors.dense(1.0), 1.0)
new LinearRegression().setInitialModel(initialModel).fit(datasetWithSparseFeature)
}
}
}

object LinearRegressionSuite {

/**
- * Mapping from all Params to valid settings which differ from the defaults.
+ * Mapping from estimator Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
-  val allParamSettings: Map[String, Any] = Map(
+  val estimatorParamSettings: Map[String, Any] = Map(
     "predictionCol" -> "myPrediction",
     "regParam" -> 0.01,
     "elasticNetParam" -> 0.1,
     "maxIter" -> 2, // intentionally small
     "fitIntercept" -> true,
     "tol" -> 0.8,
     "standardization" -> false,
-    "solver" -> "l-bfgs"
+    "solver" -> "l-bfgs",
+    "initialModel" -> new LinearRegressionModel("linReg", Vectors.dense(1.0, 1.0), 1.0)
)

/**
* Mapping from model Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val modelParamSettings: Map[String, Any] = Map("predictionCol" -> "myPrediction")
}
mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala

@@ -59,15 +59,17 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
assert(newInstance.uid === instance.uid)
if (testParams) {
     instance.params.foreach { p =>
-      if (instance.isDefined(p)) {
-        (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
-          case (Array(values), Array(newValues)) =>
-            assert(values === newValues, s"Values do not match on param ${p.name}.")
-          case (value, newValue) =>
-            assert(value === newValue, s"Values do not match on param ${p.name}.")
-        }
-      } else {
-        assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
-      }
+      if (p.name != "initialModel") {
+        if (instance.isDefined(p)) {
+          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+            case (Array(values), Array(newValues)) =>
+              assert(values === newValues, s"Values do not match on param ${p.name}.")
+            case (value, newValue) =>
+              assert(value === newValue, s"Values do not match on param ${p.name}.")
+          }
+        } else {
+          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
+        }
+      }
     }
   }
@@ -113,7 +115,14 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
val estimator2 = testDefaultReadWrite(estimator)
testEstimatorParams.foreach { case (p, v) =>
val param = estimator.getParam(p)
-      assert(estimator.get(param).get === estimator2.get(param).get)
+      if (param.name == "initialModel") {
+        // The estimator's `initialModel` has the same type as the model this estimator
+        // produces, so we can use `checkModelData` to check equality of `initialModel` too.
+        checkModelData(estimator.get(param).get.asInstanceOf[M],
+          estimator2.get(param).get.asInstanceOf[M])
+      } else {
+        assert(estimator.get(param).get === estimator2.get(param).get)
+      }
     }

Author: This assumes the initialModel is always the same type as the model the estimator produces. We could enforce that by declaring the trait as HasInitialModel[T <: Model[T]], or by adding a checkInitialModelData argument to testEstimatorAndModelReadWrite. I'm open to others' thoughts and will update after collecting feedback. Thanks.

// Test Model save/load
7 changes: 7 additions & 0 deletions project/MimaExcludes.scala
@@ -1012,6 +1012,13 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
) ++ Seq(
// [SPARK-21386] ML LinearRegression supports warm start from user provided initial model.
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModelParams.org$apache$spark$ml$regression$LinearRegressionModelParams$_setter_$solver_="),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.initialModel"),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasInitialModel.getInitialModel"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.LinearRegression$"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.LinearRegressionModel")
)
}
